Preface


I am writing a longer [book] than usual because there is not enough time to write a short one. 
(Blaise Pascal, paraphrased.) 


This book is a sequel to [Mur22]. That book mostly focused on techniques for learning functions
$f : \mathcal{X} \rightarrow \mathcal{Y}$, where $f$ is some nonlinear model, such as a deep neural network, $\mathcal{X}$ is the set of possible
inputs (typically $\mathcal{X} = \mathbb{R}^D$), and $\mathcal{Y} = \{1, \ldots, C\}$ represents the set of labels for classification problems
or $\mathcal{Y} = \mathbb{R}$ for regression problems. Judea Pearl, a well known AI researcher, has called this kind of
ML a form of “glorified curve fitting” (quoted in [Har18]).

In this book, we expand the scope of ML to encompass more challenging problems. For example, we 
consider training and testing under different distributions; we consider generation of high dimensional 
outputs, such as images, text and graphs, so the output space is, say, $\mathcal{Y} = \mathbb{R}^{256 \times 256}$; we discuss
methods for discovering “insights” about data, based on latent variable models; and we discuss how 
to use probabilistic models for causal inference and decision making under uncertainty. 

We assume the reader has some prior exposure to ML and other relevant mathematical topics 
(e.g., probability, statistics, linear algebra, optimization). This background material is covered in the 
prequel to this book, [Mur22], amongst other sources (e.g., [Lin+21b; DFO20]). 

Python code (mostly in JAX) to reproduce nearly all of the figures can be found online. In 
particular, if a figure caption says “Generated by gauss_plot_2d.ipynb”, then you can find the 
corresponding Jupyter notebook at probml.github.io/notebooks#gauss_plot_2d.ipynb. Clicking on
the figure link in the pdf version of the book will take you to this list of notebooks. Clicking on 
the notebook link will open it inside Google Colab, which will let you easily reproduce the figure 
for yourself, and modify the underlying source code to gain a deeper understanding of the methods. 
(Colab gives you access to a free GPU, which is useful for some of the more computationally heavy 
demos.) 

In addition to the online code, at probml.github.io/supp there is some additional supplementary 
online content. This contains additional material which was excluded from the main book for space 
reasons. For exercises (and solutions) related to the topics in this book, see [Gut22]. 
1 Introduction


“Intelligence is not just about pattern recognition and function approximation. It’s about 
modeling the world”. — Josh Tenenbaum, NeurIPS 2021. 


Much of current machine learning focuses on the task of mapping inputs to outputs (i.e., approximating
functions of the form $f : \mathcal{X} \rightarrow \mathcal{Y}$), often using “deep learning” (see e.g., [LBH15; Sch14;
Sej20; BLH21]). Judea Pearl, a well known AI researcher, has called this “glorified curve fitting”
(quoted in [Har18]). This is a little unfair, since when $\mathcal{X}$ and/or $\mathcal{Y}$ are high-dimensional spaces
(such as images, sentences, graphs, or sequences of decisions/actions), the term “curve fitting” is
rather misleading, since one-dimensional intuitions often do not work in higher dimensional settings
(see e.g., [BPL21a]). Nevertheless, the quote gets at what many feel is lacking in current attempts
to “solve AI” using machine learning techniques, namely that they are too focused on prediction
of observable patterns, and not focused enough on “understanding” the underlying latent structure
behind these patterns.

Gaining a “deep understanding” of the structure behind the observed data is necessary for advancing 
science, as well as for certain applications, such as healthcare (see e.g., [DD22]), where identifying 
the root causes or mechanisms behind various diseases is the key to developing cures. In addition, 
such “deep understanding” is necessary in order to develop robust and efficient systems. By “robust” 
we mean methods that work well even if there are unexpected changes to the data distribution to 
which the system is applied, which is an important concern in many areas, such as robotics (see e.g., 
[Roy+21]). By “efficient” we generally mean data efficient or statistically efficient, i.e., methods that can learn
quickly from small amounts of data (cf. [Lu+21b]). This is important since data can be limited
in some domains, such as healthcare and robotics, even though it is abundant in other domains, 
such as language and vision, due to the ability to scrape the internet. We are also interested in 
computationally efficient methods, although this is a secondary concern as computing power continues 
to grow. (We also note that this trend has been instrumental to much of the recent progress in AI, 
as noted in [Sut19].) 

To develop robust and efficient systems, this book adopts a model-based approach, in which we try
to learn parsimonious representations of the underlying “data generating process” (DGP) given
samples from one or more datasets (cf. [Lak+17; Win+19; Sch20; Ben+21a; Cun22; MTS22]). This
is in fact similar to the scientific method, where we try to explain (features of) the observations by
developing theories or models. One way to formalize this process is in terms of Bayesian inference
applied to probabilistic models, as argued in [Jay03; Box80; GS13]. We discuss inference algorithms
in detail in Part II of the book.¹ But before we get there, in Part I we cover some relevant background
material that will be needed. (This part can be skipped by readers who are already familiar with
these basics.)

Once we have a set of inference methods in our toolbox (some of which may be as simple as
computing a maximum likelihood estimate using an optimization method, such as stochastic gradient
descent), we can turn our focus to discussing different kinds of models. The choice of model depends
on our task, the kind and amount of data we have, and our metric(s) of success. We will broadly
consider four main kinds of task: prediction (e.g., classification and regression), generation (e.g., of
images or text), discovery (of “meaningful structure” in data), and control (optimal decision making).
We give more details below.

In Part III, we discuss models for prediction. These models are conditional distributions of the
form $p(y|x)$, where $x \in \mathcal{X}$ is some input (often high dimensional), and $y \in \mathcal{Y}$ is the desired output
(often low dimensional). In this part of the book, we assume there is one right answer that we want
to predict, although we may be uncertain about it.

In Part IV, we discuss models for generation. These models are distributions of the form $p(x)$ or
$p(x|c)$, where $c$ are optional conditioning inputs, and where there may be multiple valid outputs.
For example, given a text prompt $c$, we may want to generate a diverse set of images $x$ that “match”
the caption. Evaluating such models is harder than in the prediction setting, since it is less clear
what the desired output should be.

In Part V, we discuss latent variable models, which are joint models of the form $p(z, x) = p(z)p(x|z)$,
where $z$ is the hidden state and $x$ are the observations that are assumed to be generated from $z$.
The goal is to compute $p(z|x)$, in order to uncover some (hopefully meaningful / useful) underlying
state or patterns in the observed data. We also consider methods for trying to discover patterns
learned implicitly by predictive models of the form $p(y|x)$, without relying on an explicit generative
model of the data.

Finally, in Part VI, we discuss models and algorithms which can be used to make decisions under
uncertainty. This naturally leads into the very important topic of causality, with which we close the
book.

In view of the broad scope of the book, we cannot go into detail on every topic. However, we
have attempted to cover all the basics. In some cases, we also provide a “deeper dive” into the
research frontier (as of 2022). We hope that by bringing all these topics together, you will find it
easier to make connections between all these seemingly disparate areas, and can thereby deepen your
understanding of the field of machine learning.


1. Note that, in the deep learning community, the term “inference” means applying a function to some inputs to
compute the output. This is unrelated to Bayesian inference, which is concerned with the much harder task of inverting
a function, and working backwards from observed outputs to possible hidden inputs (causes). The latter is more closely
related to what the deep learning community calls “training”.
PART I


Fundamentals 


2 Probability 


The mathematical rules of probability theory are not merely rules for calculating frequencies 
of ‘random variables’; they are also the unique rules for conducting inference (i.e. plausible 
reasoning) of any kind. (E. T. Jaynes [Jay03])


2.1 Introduction 


We assume the reader is already familiar with basic probability theory. For example, see [Cha21], 
or chapter 2 of the prequel to this book, [Mur22]. In this chapter, we briefly review some of this 
material, to make the book self-contained. 


2.2 Some common probability distributions 


There is a wide variety of probability distributions that are used for various kinds of models. We
summarize some of the more commonly used ones in the sections below. See Supplementary Chapter 2
for more information, and https://ben18785.shinyapps.io/distribution-zoo/ for an interactive
visualization.


2.2.1 Discrete distributions 


In this section, we discuss some discrete distributions defined on subsets of the (non-negative) integers. 


2.2.1.1 Bernoulli and binomial distributions 


Let $x \in \{0, 1, \ldots, N\}$. The binomial distribution is defined by

$$\mathrm{Bin}(x|N, \mu) \triangleq \binom{N}{x} \mu^x (1-\mu)^{N-x} \qquad (2.1)$$

where $\binom{N}{k} \triangleq \frac{N!}{(N-k)!\,k!}$ is the number of ways to choose $k$ items from $N$ (this is known as the binomial
coefficient, and is pronounced “N choose k”).
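If $N = 1$, so $x \in \{0, 1\}$, the binomial distribution reduces to the Bernoulli distribution:

$$\mathrm{Ber}(x|\mu) = \begin{cases} 1 - \mu & \text{if } x = 0 \\ \mu & \text{if } x = 1 \end{cases} \qquad (2.2)$$

where $\mu = \mathbb{E}[x] = p(x = 1)$ is the mean.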
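As a quick sanity check of Equations (2.1) and (2.2), here is a minimal sketch in plain Python rather than the JAX used in the book's notebooks; `binom_pmf` and `bernoulli_pmf` are our own illustrative names. It checks that the binomial pmf sums to one, has mean $N\mu$, and reduces to the Bernoulli when $N = 1$.

```python
from math import comb


def binom_pmf(x: int, N: int, mu: float) -> float:
    """Bin(x | N, mu) = C(N, x) mu^x (1 - mu)^(N - x), Eq. (2.1)."""
    if not 0 <= x <= N:
        return 0.0
    return comb(N, x) * mu**x * (1.0 - mu) ** (N - x)


def bernoulli_pmf(x: int, mu: float) -> float:
    """Ber(x | mu), Eq. (2.2): 1 - mu if x = 0, mu if x = 1."""
    if x not in (0, 1):
        return 0.0
    return mu if x == 1 else 1.0 - mu


N, mu = 10, 0.3
probs = [binom_pmf(x, N, mu) for x in range(N + 1)]
total = sum(probs)                                # 1 up to rounding
mean = sum(x * p for x, p in enumerate(probs))    # N * mu = 3
print(total, mean)
print(binom_pmf(1, 1, mu), bernoulli_pmf(1, mu))  # N = 1 recovers the Bernoulli
```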
2.2.1.2 Categorical and multinomial distributions


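If the variable is discrete-valued, $x \in \{1, \ldots, K\}$, we can use the categorical distribution:

$$\mathrm{Cat}(x|\boldsymbol{\theta}) \triangleq \prod_{k=1}^K \theta_k^{\mathbb{I}(x=k)} \qquad (2.3)$$

Alternatively, we can represent the K-valued variable $x$ with the one-hot binary vector $\boldsymbol{x}$, which lets
us write

$$\mathrm{Cat}(\boldsymbol{x}|\boldsymbol{\theta}) \triangleq \prod_{k=1}^K \theta_k^{x_k} \qquad (2.4)$$

If the k’th element of $\boldsymbol{x}$ counts the number of times the value $k$ is seen in $N = \sum_{k=1}^K x_k$ trials, then
we get the multinomial distribution:

$$\mathcal{M}(\boldsymbol{x}|N, \boldsymbol{\theta}) \triangleq \binom{N}{x_1 \ldots x_K} \prod_{k=1}^K \theta_k^{x_k} \qquad (2.5)$$

where the multinomial coefficient is defined as

$$\binom{N}{k_1 \ldots k_m} \triangleq \frac{N!}{k_1! \ldots k_m!} \qquad (2.6)$$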
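A minimal sketch of Equations (2.5) and (2.6) in plain Python; the helper names are our own. It shows that a one-hot count vector ($N = 1$) gives back the categorical probability, and that $K = 2$ gives back the binomial.

```python
from math import comb, factorial, prod


def multinomial_coef(counts):
    """N! / (x_1! ... x_K!) with N = sum(counts), Eq. (2.6)."""
    return factorial(sum(counts)) // prod(factorial(c) for c in counts)


def multinomial_pmf(counts, theta):
    """M(x | N, theta) = coef * prod_k theta_k^{x_k}, Eq. (2.5)."""
    return multinomial_coef(counts) * prod(t**c for t, c in zip(theta, counts))


theta = [0.2, 0.5, 0.3]
print(multinomial_coef([2, 1, 1]))        # 4! / (2! 1! 1!) = 12
print(multinomial_pmf([0, 1, 0], theta))  # one-hot x, N = 1: Cat(x | theta) = theta_2
# With K = 2 the multinomial reduces to the binomial of Eq. (2.1).
print(multinomial_pmf([3, 7], [0.3, 0.7]), comb(10, 3) * 0.3**3 * 0.7**7)
```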


2.2.1.3 Poisson distribution 


Suppose $X \in \{0, 1, 2, \ldots\}$. We say that a random variable has a Poisson distribution with parameter
$\lambda > 0$, written $X \sim \mathrm{Poi}(\lambda)$, if its pmf (probability mass function) is

$$\mathrm{Poi}(x|\lambda) = e^{-\lambda} \frac{\lambda^x}{x!} \qquad (2.7)$$

where $\lambda$ is the mean (and variance) of $x$.
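To see numerically that $\lambda$ is both the mean and the variance, the following sketch sums Equation (2.7) over a truncated support (plain Python; `poisson_pmf` is our own name, and the truncation at 100 is an assumption that is harmless for $\lambda = 4$).

```python
from math import exp, factorial


def poisson_pmf(x: int, lam: float) -> float:
    """Poi(x | lam) = e^{-lam} lam^x / x!, Eq. (2.7)."""
    return exp(-lam) * lam**x / factorial(x)


lam = 4.0
xs = range(100)  # truncate the infinite support; the mass beyond x = 99 is negligible for lam = 4
p = [poisson_pmf(x, lam) for x in xs]
mean = sum(x * px for x, px in zip(xs, p))
var = sum((x - mean) ** 2 * px for x, px in zip(xs, p))
print(sum(p), mean, var)  # approximately 1, lam, lam
```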


2.2.1.4 Negative binomial distribution 


Suppose we have an “urn” with N balls, R of which are red and B of which are blue. Suppose we 
perform sampling with replacement until we get n ≥ 1 balls. Let X be the number of these that
are blue. It can be shown that X ~ Bin(n,p), where p = B/N is the fraction of blue balls; thus X 
follows the binomial distribution, discussed in Section 2.2.1.1. 

Now suppose we consider drawing a red ball a “failure”, and drawing a blue ball a “success”. Suppose 
we keep drawing balls until we observe r failures. Let X be the resulting number of successes (blue 
balls); it can be shown that X ~ NegBinom(r,p), which is the negative binomial distribution 
defined by 


$$\mathrm{NegBinom}(x|r, p) \triangleq \binom{x+r-1}{x} (1-p)^r p^x \qquad (2.8)$$


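for $x \in \{0, 1, 2, \ldots\}$. (If $r$ is real-valued, we replace $\binom{x+r-1}{x}$ with $\frac{\Gamma(x+r)}{x!\,\Gamma(r)}$, exploiting the fact that
$(x-1)! = \Gamma(x)$.)
This distribution has the following moments:

$$\mathbb{E}[x] = \frac{pr}{1-p}, \qquad \mathbb{V}[x] = \frac{pr}{(1-p)^2} \qquad (2.9)$$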

This two parameter family has more modeling flexibility than the Poisson distribution, since it
can represent the mean and variance separately. This is useful e.g., for modeling “contagious” events,
which have positively correlated occurrences, causing a larger variance than if the occurrences were
independent. In fact, the Poisson distribution is a special case of the negative binomial, since it
can be shown that $\mathrm{Poi}(\lambda) = \lim_{r \to \infty} \mathrm{NegBinom}(r, \frac{\lambda}{r + \lambda})$. Another special case is when $r = 1$; this is
called the geometric distribution.
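The moments in Equation (2.9) and the Poisson limit can be checked numerically with a short sketch (plain Python; `negbinom_pmf` and `pmf_moments` are our own helper names, and the cutoff `xmax=2000` is an assumption that is ample for these parameters). The pmf uses the real-valued-$r$ form with gamma functions.

```python
from math import exp, lgamma, log


def negbinom_pmf(x: int, r: float, p: float) -> float:
    """NegBinom(x | r, p), Eq. (2.8), via Gamma(x + r) / (x! Gamma(r)) so r may be real-valued."""
    log_coef = lgamma(x + r) - lgamma(x + 1) - lgamma(r)
    return exp(log_coef + r * log(1 - p) + x * log(p))


def pmf_moments(pmf, xmax=2000):
    """Mean and variance of a pmf on {0, ..., xmax - 1}; fine when the tail beyond xmax is tiny."""
    ps = [pmf(x) for x in range(xmax)]
    m = sum(x * q for x, q in enumerate(ps))
    v = sum((x - m) ** 2 * q for x, q in enumerate(ps))
    return m, v


r, p = 3.0, 0.4
m, v = pmf_moments(lambda x: negbinom_pmf(x, r, p))
print(m, p * r / (1 - p))       # mean, Eq. (2.9)
print(v, p * r / (1 - p) ** 2)  # variance, larger than the mean

# Poisson limit: with p = lam / (r + lam) the mean is lam for every r, and the pmf approaches Poi(lam).
lam, x = 4.0, 3
poi = exp(x * log(lam) - lam - lgamma(x + 1))
nb_large_r = negbinom_pmf(x, r=1e6, p=lam / (1e6 + lam))
print(poi, nb_large_r)
```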


2.2.2 Continuous distributions on R 


In this section, we discuss some univariate distributions defined on the reals, $p(x)$ for $x \in \mathbb{R}$.


2.2.2.1 Gaussian (Normal) 


The most widely used univariate distribution is the Gaussian distribution, also called the normal 
distribution. (See [Mur22, Sec 2.6.4] for a discussion of these names.) The pdf (probability density 
function) of the Gaussian is given by 


$$\mathcal{N}(x|\mu, \sigma^2) \triangleq \frac{1}{\sqrt{2\pi\sigma^2}} e^{-\frac{1}{2\sigma^2}(x-\mu)^2} \qquad (2.10)$$

where $\sqrt{2\pi\sigma^2}$ is the normalization constant needed to ensure the density integrates to 1. The
parameter $\mu$ encodes the mean of the distribution, which is the same as the mode, since the
distribution is unimodal. The parameter $\sigma^2$ encodes the variance. Sometimes we talk about the
precision of a Gaussian, by which we mean the inverse variance: $\lambda = 1/\sigma^2$. A high precision means
a narrow distribution (low variance) centered on $\mu$.

The cumulative distribution function or cdf of the Gaussian is defined as

$$\Phi(x; \mu, \sigma^2) \triangleq \int_{-\infty}^x \mathcal{N}(z|\mu, \sigma^2)\, dz \qquad (2.11)$$

If $\mu = 0$ and $\sigma = 1$ (known as the standard normal distribution), we just write $\Phi(x)$.
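A minimal sketch of Equations (2.10) and (2.11) in plain Python. The cdf has no closed form in elementary functions, so this sketch assumes the standard identity $\Phi(z) = \frac{1}{2}(1 + \mathrm{erf}(z/\sqrt{2}))$ and uses `math.erf`; the helper names are our own.

```python
from math import erf, exp, pi, sqrt


def gauss_pdf(x, mu=0.0, sigma2=1.0):
    """N(x | mu, sigma^2), Eq. (2.10)."""
    return exp(-((x - mu) ** 2) / (2 * sigma2)) / sqrt(2 * pi * sigma2)


def gauss_cdf(x, mu=0.0, sigma2=1.0):
    """Phi(x; mu, sigma^2), Eq. (2.11), via Phi(z) = (1 + erf(z / sqrt(2))) / 2."""
    z = (x - mu) / sqrt(sigma2)
    return 0.5 * (1.0 + erf(z / sqrt(2.0)))


print(gauss_cdf(0.0), gauss_cdf(1.96))  # 0.5 and roughly 0.975
# Higher precision (smaller sigma2) gives a taller, narrower peak at mu.
print(gauss_pdf(0.0, sigma2=1.0), gauss_pdf(0.0, sigma2=0.25))
# The derivative of the cdf matches the pdf (central finite difference).
h = 1e-5
deriv = (gauss_cdf(0.7 + h) - gauss_cdf(0.7 - h)) / (2 * h)
print(deriv, gauss_pdf(0.7))
```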


2.2.2.2 Half-normal 


For some problems, we want a distribution over non-negative reals. One way to create such a 
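distribution is to define $Y = |X|$, where $X \sim \mathcal{N}(0, \sigma^2)$. The induced distribution for $Y$ is called the
half-normal distribution, which has the pdf

$$\mathcal{N}_+(y|\sigma) \triangleq 2\mathcal{N}(y|0, \sigma^2) = \frac{\sqrt{2}}{\sigma\sqrt{\pi}} \exp\left(-\frac{y^2}{2\sigma^2}\right), \quad y \geq 0 \qquad (2.12)$$

This can be thought of as the $\mathcal{N}(0, \sigma^2)$ distribution “folded over” onto itself.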
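The “folding” view can be checked with a small numerical integration of Equation (2.12). This is a sketch in plain Python using the trapezoid rule on $[0, 10\sigma]$ (our own choice of grid); the mean $\sigma\sqrt{2/\pi}$ is a standard half-normal result that the text above does not state.

```python
from math import exp, pi, sqrt


def half_normal_pdf(y, sigma=1.0):
    """N_+(y | sigma) = 2 N(y | 0, sigma^2) on y >= 0, Eq. (2.12)."""
    if y < 0:
        return 0.0
    return sqrt(2.0) / (sigma * sqrt(pi)) * exp(-(y**2) / (2 * sigma**2))


# Trapezoid rule on [0, 10 sigma]: folding keeps the total mass at 1.
sigma, n = 2.0, 20_000
h = 10 * sigma / n
grid = [i * h for i in range(n + 1)]
weights = [0.5 if i in (0, n) else 1.0 for i in range(n + 1)]
mass = h * sum(w * half_normal_pdf(y, sigma) for w, y in zip(weights, grid))
mean = h * sum(w * y * half_normal_pdf(y, sigma) for w, y in zip(weights, grid))
print(mass, mean, sigma * sqrt(2 / pi))
```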
Figure 2.1: (a) The pdfs for $\mathcal{N}(0,1)$, $\mathcal{T}_1(0,1)$ and $\mathrm{Laplace}(0, 1/\sqrt{2})$. The mean is 0 and the variance is 1
for both the Gaussian and Laplace. The mean and variance of the Student distribution are undefined when $\nu = 1$.
(b) Log of these pdfs. Note that the Student distribution is not log-concave for any parameter value, unlike
the Laplace distribution. Nevertheless, both are unimodal. Generated by student_laplace_pdf_plot.ipynb.


2.2.2.3 Student t distribution 


One problem with the Gaussian distribution is that it is sensitive to outliers, since the probability 
decays exponentially fast with the (squared) distance from the center. A more robust distribution is 
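the Student $t$-distribution, which we shall call the Student distribution for short. Its pdf is as
follows:

$$\mathcal{T}_\nu(x|\mu, \sigma^2) = \frac{1}{Z}\left[1 + \frac{1}{\nu}\left(\frac{x-\mu}{\sigma}\right)^2\right]^{-\left(\frac{\nu+1}{2}\right)} \qquad (2.13)$$

$$Z = \frac{\sqrt{\nu\pi\sigma^2}\,\Gamma(\frac{\nu}{2})}{\Gamma(\frac{\nu+1}{2})} = \sqrt{\nu}\,\sigma\,B\left(\frac{1}{2}, \frac{\nu}{2}\right) \qquad (2.14)$$

where $\mu$ is the mean (when $\nu > 1$), $\sigma > 0$ is the scale parameter (not the standard deviation), and $\nu > 0$ is called
the degrees of freedom (although a better term would be the degree of normality [Kru13], since
large values of $\nu$ make the distribution act like a Gaussian). Here $\Gamma(a)$ is the gamma function
defined by

$$\Gamma(a) \triangleq \int_0^\infty x^{a-1} e^{-x}\, dx \qquad (2.15)$$

and $B(a, b)$ is the beta function, defined by

$$B(a, b) \triangleq \frac{\Gamma(a)\Gamma(b)}{\Gamma(a+b)} \qquad (2.16)$$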
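A minimal sketch of Equations (2.13) to (2.16) in plain Python (helper names are our own). It works in log space with `math.lgamma` to avoid overflow for large $\nu$, checks that the two expressions for $Z$ agree, and shows the two limits mentioned in the text: $\nu = 1$ gives a Cauchy density and very large $\nu$ approaches a Gaussian.

```python
from math import exp, gamma, lgamma, log, pi, sqrt


def student_logpdf(x, nu, mu=0.0, sigma=1.0):
    """log T_nu(x | mu, sigma^2) from Eqs. (2.13)-(2.14), computed with lgamma for stability."""
    log_z = 0.5 * log(nu * pi * sigma**2) + lgamma(nu / 2) - lgamma((nu + 1) / 2)
    return -log_z - (nu + 1) / 2 * log(1 + ((x - mu) / sigma) ** 2 / nu)


def beta_fn(a, b):
    """B(a, b) = Gamma(a) Gamma(b) / Gamma(a + b), Eq. (2.16)."""
    return exp(lgamma(a) + lgamma(b) - lgamma(a + b))


nu, sigma = 3.0, 1.5
z_gamma = sqrt(nu * pi * sigma**2) * gamma(nu / 2) / gamma((nu + 1) / 2)
z_beta = sqrt(nu) * sigma * beta_fn(0.5, nu / 2)
print(z_gamma, z_beta)  # the two expressions for Z in Eq. (2.14) agree

# nu = 1 gives the Cauchy density 1 / (pi (1 + x^2)); very large nu approaches N(x | 0, 1).
print(exp(student_logpdf(2.0, nu=1.0)), 1 / (pi * (1 + 2.0**2)))
print(exp(student_logpdf(0.5, nu=1e6)), exp(-0.5 * 0.5**2) / sqrt(2 * pi))
```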


2.2.2.4 Cauchy distribution 


If $\nu = 1$, the Student distribution is known as the Cauchy or Lorentz distribution. Its pdf is defined
by

$$\mathcal{C}(x|\mu, \gamma) = \frac{1}{Z}\left[1 + \left(\frac{x-\mu}{\gamma}\right)^2\right]^{-1} \qquad (2.17)$$

where $Z = \gamma B(\frac{1}{2}, \frac{1}{2}) = \gamma\pi$. This distribution is notable for having such heavy tails that the integral
that defines the mean does not converge.
The half Cauchy distribution is a version of the Cauchy (with location $\mu = 0$) that is “folded over” on
itself, so all its probability density is on the positive reals. Thus it has the form

$$\mathcal{C}_+(x|\gamma) \triangleq \frac{2}{\pi\gamma}\left[1 + \left(\frac{x}{\gamma}\right)^2\right]^{-1} \qquad (2.18)$$


2.2.2.5 Laplace distribution


Another distribution with heavy tails is the Laplace distribution, also known as the double sided
exponential distribution. This has the following pdf:

$$\mathrm{Laplace}(x|\mu, b) \triangleq \frac{1}{2b} \exp\left(-\frac{|x-\mu|}{b}\right) \qquad (2.19)$$

Here $\mu$ is a location parameter and $b > 0$ is a scale parameter. See Figure 2.1 for a plot.
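To make “heavy tails” concrete, this sketch evaluates the standard Gaussian, the Laplace of Equation (2.19) and the Cauchy of Equation (2.17) at a few points (plain Python; helper names are our own). Following Figure 2.1, we set $b = 1/\sqrt{2}$, using the standard fact that the Laplace variance is $2b^2$, so that the Laplace and Gaussian both have unit variance.

```python
from math import exp, pi, sqrt


def gauss_pdf(x):
    """Standard normal N(x | 0, 1)."""
    return exp(-(x**2) / 2) / sqrt(2 * pi)


def laplace_pdf(x, mu=0.0, b=1.0):
    """Laplace(x | mu, b) = exp(-|x - mu| / b) / (2 b), Eq. (2.19)."""
    return exp(-abs(x - mu) / b) / (2 * b)


def cauchy_pdf(x, mu=0.0, gam=1.0):
    """C(x | mu, gamma) = 1 / (pi gamma [1 + ((x - mu) / gamma)^2]), Eq. (2.17)."""
    return 1.0 / (pi * gam * (1 + ((x - mu) / gam) ** 2))


b = 1 / sqrt(2)  # Laplace variance 2 b^2 = 1, matching the unit-variance Gaussian of Figure 2.1
for x in [0.0, 2.0, 4.0, 6.0]:
    print(f"x={x}: gauss={gauss_pdf(x):.2e} laplace={laplace_pdf(x, b=b):.2e} cauchy={cauchy_pdf(x):.2e}")
```

Near the center the Laplace is the most peaked, while far from the center the Gaussian density is orders of magnitude below the other two.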


2.2.2.6 Sub-Gaussian and super-Gaussian distributions 


There are two main variants of the Gaussian distribution, known as super-Gaussian or leptokurtic
(“Lepto” is Greek for “narrow”) and sub-Gaussian or platykurtic (“Platy” is Greek for “broad”).
These distributions differ in terms of their kurtosis, which is a measure of how heavy or light their
tails are (i.e., how fast the density dies off to zero away from its mean). More precisely, the kurtosis
is defined as

$$\mathrm{kurt}(z) \triangleq \frac{\mu_4}{\sigma^4} = \frac{\mathbb{E}[(Z-\mu)^4]}{(\mathbb{E}[(Z-\mu)^2])^2} \qquad (2.20)$$

where $\sigma$ is the standard deviation, and $\mu_4$ is the 4’th central moment. (Here $\mu$ is the mean,
and $\mu_2 = \sigma^2$ is the variance.) For a standard Gaussian, the kurtosis is 3, so some authors define the
excess kurtosis as the kurtosis minus 3.

A super-Gaussian distribution (e.g., the Laplace) has positive excess kurtosis, and hence heavier 
tails than the Gaussian. A sub-Gaussian distribution, such as the uniform, has negative excess 
kurtosis, and hence lighter tails than the Gaussian. See Figure 2.2 for an illustration. 
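A Monte Carlo sketch of the sample version of Equation (2.20) minus 3, using NumPy's random `Generator` (the seed and sample size are arbitrary choices of ours). The reference values, 3 for the Laplace and $-1.2$ for the uniform, are standard results rather than something derived in the text.

```python
import numpy as np


def excess_kurtosis(z):
    """Sample estimate of Eq. (2.20) minus 3: E[(Z - mu)^4] / sigma^4 - 3."""
    c = np.asarray(z, dtype=float)
    c = c - c.mean()
    return np.mean(c**4) / np.mean(c**2) ** 2 - 3.0


rng = np.random.default_rng(0)
n = 1_000_000
samples = {
    "gaussian": rng.normal(size=n),
    "laplace": rng.laplace(size=n),             # super-Gaussian: excess kurtosis 3
    "uniform": rng.uniform(-1.0, 1.0, size=n),  # sub-Gaussian: excess kurtosis -1.2
}
for name, z in samples.items():
    print(f"{name:>8}: {excess_kurtosis(z):+.3f}")
```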


2.2.3 Continuous distributions on R+


In this section, we discuss some univariate distributions defined on the positive reals, $p(x)$ for $x \in \mathbb{R}^+$.
Figure 2.2: Illustration of Gaussian (blue), sub-Gaussian (uniform, green) and super-Gaussian (Laplace, red)
distributions in 1d and 2d. Generated by sub_super_gauss_plot.ipynb.


[Figure 2.3 panels: (a) “Gamma distributions” for a ∈ {1.0, 1.5, 2.0} and b ∈ {1.0, 2.0}; (b) “Beta distributions” on [0, 1].]


Figure 2.3: (a) Some gamma distributions. If a ≤ 1, the mode is at 0, otherwise the mode is away from 0.
As we increase the rate b, we reduce the horizontal scale, thus squeezing everything leftwards and upwards.
Generated by gamma_dist_plot.ipynb. (b) Some beta distributions. If a < 1, we get a “spike” on the left,
and if b < 1, we get a “spike” on the right. If a = b = 1, the distribution is uniform. If a > 1 and b > 1, the
distribution is unimodal. Generated by beta_dist_plot.ipynb.


2.2.3.1 Gamma distribution 


The gamma distribution is a flexible distribution for positive real valued rv’s, $x > 0$. It is defined
in terms of two parameters, called the shape $a > 0$ and the rate $b > 0$:

$$\mathrm{Ga}(x|\text{shape} = a, \text{rate} = b) \triangleq \frac{b^a}{\Gamma(a)} x^{a-1} e^{-xb} \qquad (2.21)$$

Sometimes the distribution is parameterized in terms of the shape $a$ and the scale $s = 1/b$:

$$\mathrm{Ga}(x|\text{shape} = a, \text{scale} = s) \triangleq \frac{1}{s^a \Gamma(a)} x^{a-1} e^{-x/s} \qquad (2.22)$$


See Figure 2.3a. 
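The two parameterizations describe the same density once $s = 1/b$. Here is a minimal sketch in plain Python (our own helper names, valid for $x > 0$ only).

```python
from math import exp, lgamma, log


def gamma_pdf_rate(x, a, b):
    """Ga(x | shape=a, rate=b) = b^a x^(a-1) e^(-x b) / Gamma(a), Eq. (2.21), for x > 0."""
    return exp(a * log(b) - lgamma(a) + (a - 1) * log(x) - x * b)


def gamma_pdf_scale(x, a, s):
    """Ga(x | shape=a, scale=s) = x^(a-1) e^(-x/s) / (s^a Gamma(a)), Eq. (2.22), for x > 0."""
    return exp(-a * log(s) - lgamma(a) + (a - 1) * log(x) - x / s)


a, b = 1.5, 2.0
for x in [0.1, 1.0, 3.0]:
    print(gamma_pdf_rate(x, a, b), gamma_pdf_scale(x, a, 1 / b))  # equal, since s = 1 / b
```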


2.2.3.2 Exponential distribution 

The exponential distribution is a special case of the gamma distribution and is defined by

$$\mathrm{Expon}(x|\lambda) \triangleq \mathrm{Ga}(x|\text{shape} = 1, \text{rate} = \lambda) \qquad (2.23)$$

This distribution describes the times between events in a Poisson process, i.e., a process in which
events occur continuously and independently at a constant average rate $\lambda$.
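The link to the Poisson distribution of Section 2.2.1.3 can be seen by simulation: if the gaps between events are drawn from $\mathrm{Expon}(\lambda)$, the number of events in a unit time window should behave like $\mathrm{Poi}(\lambda)$. This NumPy sketch relies on that standard property rather than deriving it; `max_events` is our own cap, chosen so that more events per window are essentially impossible.

```python
import numpy as np

rng = np.random.default_rng(0)
lam, n_windows, max_events = 3.0, 100_000, 40
# Inter-arrival gaps ~ Expon(lam); NumPy parameterizes the exponential by scale = 1 / rate.
gaps = rng.exponential(scale=1.0 / lam, size=(n_windows, max_events))
arrival_times = np.cumsum(gaps, axis=1)
counts = (arrival_times <= 1.0).sum(axis=1)  # number of events in a unit time window
print(counts.mean(), counts.var())           # both close to lam, as for Poi(lam)
print((counts == 0).mean(), np.exp(-lam))    # P(no events) close to Poi(0 | lam) = e^{-lam}
```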

2.2.3.3 Chi-squared distribution 


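The chi-squared distribution is a special case of the gamma distribution and is defined by

$$\chi^2_\nu(x) \triangleq \mathrm{Ga}\left(x|\text{shape} = \tfrac{\nu}{2}, \text{rate} = \tfrac{1}{2}\right) \qquad (2.24)$$

where $\nu$ is called the degrees of freedom. This is the distribution of the sum of squared Gaussian
random variables. More precisely, if $Z_i \sim \mathcal{N}(0, 1)$, and $S = \sum_{i=1}^{\nu} Z_i^2$, then $S \sim \chi^2_\nu$. Hence if
$X \sim \mathcal{N}(0, \sigma^2)$ then $X^2 \sim \sigma^2\chi^2_1$. Since $\mathbb{E}[\chi^2_1] = 1$ and $\mathbb{V}[\chi^2_1] = 2$, we have

$$\mathbb{E}[X^2] = \sigma^2, \qquad \mathbb{V}[X^2] = 2\sigma^4 \qquad (2.25)$$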
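A Monte Carlo sketch of Equations (2.24) and (2.25) with NumPy (seed and sample sizes are arbitrary). For the sum of $\nu$ squared standard normals we compare against the gamma moments shape/rate $= \nu$ and shape/rate$^2 = 2\nu$.

```python
import numpy as np

rng = np.random.default_rng(0)
n, sigma = 1_000_000, 1.5
x2 = rng.normal(0.0, sigma, size=n) ** 2  # X^2 with X ~ N(0, sigma^2), i.e., sigma^2 chi^2_1
print(x2.mean(), sigma**2)                # E[X^2] = sigma^2, Eq. (2.25)
print(x2.var(), 2 * sigma**4)             # V[X^2] = 2 sigma^4

nu = 5
s = (rng.normal(size=(n, nu)) ** 2).sum(axis=1)  # sum of nu squared standard normals
# Ga(shape = nu/2, rate = 1/2) has mean shape/rate = nu and variance shape/rate^2 = 2 nu.
print(s.mean(), nu, s.var(), 2 * nu)
```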


2.2.3.4 Inverse gamma 

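The inverse gamma distribution, denoted $Y \sim \mathrm{IG}(a, b)$, is the distribution of $Y = 1/X$ assuming
$X \sim \mathrm{Ga}(a, b)$. This pdf is defined by

$$\mathrm{IG}(x|\text{shape} = a, \text{scale} = b) \triangleq \frac{b^a}{\Gamma(a)} x^{-(a+1)} e^{-b/x} \qquad (2.26)$$

The mean only exists if $a > 1$. The variance only exists if $a > 2$.


The scaled inverse chi-squared distribution is a reparameterization of the inverse gamma
distribution:

$$\begin{aligned}
\chi^{-2}(x|\nu, \sigma^2) &= \mathrm{IG}\left(x|\text{shape} = \tfrac{\nu}{2}, \text{scale} = \tfrac{\nu\sigma^2}{2}\right) && (2.27)\\
&= \frac{1}{\Gamma(\nu/2)}\left(\frac{\nu\sigma^2}{2}\right)^{\nu/2} x^{-\frac{\nu}{2}-1} \exp\left(-\frac{\nu\sigma^2}{2x}\right) && (2.28)
\end{aligned}$$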
[Figure 2.4 panels: (a) “Pareto Distribution”, p(x|m, k) vs x; (b) “Log Pareto Distribution”, log p(x|m, k) vs log(x).]


Figure 2.4: (a) The Pareto pdf Pareto(x|k, m). (b) Same distribution on a log-log plot. Generated by
pareto_dist_plot.ipynb.


The regular inverse chi-squared distribution, written $\chi^{-2}_\nu(x)$, is the special case where $\nu\sigma^2 = 1$ (i.e.,
$\sigma^2 = 1/\nu$). This corresponds to $\mathrm{IG}(x|\text{shape} = \nu/2, \text{scale} = \tfrac{1}{2})$.
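A sampling sketch of the inverse gamma construction $Y = 1/X$ with NumPy. It assumes the standard closed forms $\mathbb{E}[Y] = b/(a-1)$ and $\mathbb{V}[Y] = b^2/((a-1)^2(a-2))$, which are consistent with the existence conditions above but are not stated in the text; $a = 6$ is our choice so that the sample variance is itself well behaved. The helper `scaled_inv_chi2_to_ig` (our own name) simply encodes Equation (2.27).

```python
import numpy as np

rng = np.random.default_rng(0)
a, b = 6.0, 5.0
x = rng.gamma(shape=a, scale=1.0 / b, size=1_000_000)  # X ~ Ga(shape=a, rate=b); NumPy uses scale = 1 / rate
y = 1.0 / x                                            # Y = 1/X ~ IG(shape=a, scale=b)
print(y.mean(), b / (a - 1))                           # mean exists since a > 1
print(y.var(), b**2 / ((a - 1) ** 2 * (a - 2)))        # variance exists since a > 2


def scaled_inv_chi2_to_ig(nu, sigma2):
    """chi^{-2}(nu, sigma^2) written as IG(shape, scale), Eq. (2.27)."""
    return nu / 2, nu * sigma2 / 2


print(scaled_inv_chi2_to_ig(4, 2.0))    # (2.0, 4.0)
print(scaled_inv_chi2_to_ig(4, 1 / 4))  # regular inverse chi-squared (nu sigma^2 = 1): scale 1/2
```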


2.2.3.5 Pareto distribution 
The Pareto distribution has the following pdf:

$$\mathrm{Pareto}(x|m, \kappa) = \kappa m^\kappa \frac{1}{x^{(\kappa+1)}} \mathbb{I}(x \geq m) \qquad (2.29)$$

See Figure 2.4(a) for some plots. We see that $x$ must be greater than the minimum value $m$, but
the density then rapidly decays after that. If we plot the distribution on a log-log scale, it forms the straight line
$\log p(x) = -a \log x + \log(c)$, where $a = (\kappa + 1)$ and $c = \kappa m^\kappa$: see Figure 2.4(b) for an illustration.

When $m = 0$, the distribution has the form $p(x) = \kappa x^{-a}$. This is known as a power law. If
$a = 1$, the distribution has the form $p(x) \propto 1/x$; if we interpret $x$ as a frequency, this is called a $1/f$
function.
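The straight-line claim can be verified directly: the slope of $\log p(x)$ against $\log x$ is $-(\kappa + 1)$ and the intercept is $\log(\kappa m^\kappa)$. A minimal sketch in plain Python (`pareto_pdf` is our own name; the parameter values are arbitrary).

```python
from math import log


def pareto_pdf(x, m, kappa):
    """Pareto(x | m, kappa) = kappa m^kappa / x^(kappa + 1) for x >= m, else 0, Eq. (2.29)."""
    return kappa * m**kappa / x ** (kappa + 1) if x >= m else 0.0


m, kappa = 1.0, 2.5
x1, x2 = 2.0, 50.0
slope = (log(pareto_pdf(x2, m, kappa)) - log(pareto_pdf(x1, m, kappa))) / (log(x2) - log(x1))
intercept = log(pareto_pdf(x1, m, kappa)) - slope * log(x1)
print(slope, -(kappa + 1))               # slope -a with a = kappa + 1
print(intercept, log(kappa * m**kappa))  # intercept log(c) with c = kappa m^kappa
print(pareto_pdf(0.5, m, kappa))         # zero density below the minimum value m
```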

The Pareto distribution is useful for modeling the distribution of quantities that exhibit heavy 
tails or long tails, in which most values are small, but there are a few very large values. Many forms 
of data exhibit this property. ([ACL16] argue that this is because many datasets are generated by a
variety of latent factors, which, when mixed together, naturally result in heavy tailed distributions.) 
We give some examples below. 


Modeling wealth distributions 


The Pareto distribution is named after the Italian economist and sociologist Vilfredo Pareto. He
created it in order to model the distribution of wealth across different countries. Indeed, in economics,
the parameter $\kappa$ is called the Pareto index. If we set $\kappa = 1.16$, we recover the 80-20 rule, which
states that 80% of the wealth of a society is held by 20% of the population.¹
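A short worked check of the 80-20 claim. It assumes a standard Pareto result that the text does not state: for $\kappa > 1$, the richest fraction $q$ of the population holds a fraction $q^{1 - 1/\kappa}$ of the total wealth. Under that assumption, $\kappa = \log 5 / \log 4 \approx 1.161$ gives exactly 80% for $q = 0.2$.

```python
from math import log


def top_share(q, kappa):
    """Share of total wealth held by the richest fraction q (assumed Pareto result, kappa > 1)."""
    return q ** (1 - 1 / kappa)


print(f"{top_share(0.2, 1.16):.3f}")  # close to 0.80
kappa_exact = log(5) / log(4)         # index that gives exactly 80-20, about 1.161
print(kappa_exact, top_share(0.2, kappa_exact))
```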


1. In fact, wealth distributions are even more skewed than this. For example, as of 2014, 80 billionaires
have as much wealth as 3.5 billion people! (Source: http://www.pbs.org/newshour/making-sense/wealthiest-getting-wealthier-lobbying-lot.)
Such extreme income inequality exists in many plutocratic countries, including the USA (see e.g., [HP10]).
[Figure 2.5 legend: word counts; linear prediction.]
Figure 2.5: A log-log plot of the frequency vs the rank for the words in H. G. Wells’ The Time Machine.
Generated by zipfs_law_plot.ipynb. Adapted from a figure from [Zha+20a, Sec 8.3].


Zipf’s law 


Zipf’s law says that the most frequent word in a language (such as “the”) occurs approximately 
twice as often as the second most frequent word (“of”), which occurs twice as often as the fourth 
most frequent word, etc. This corresponds to a Pareto distribution of the form 


$$p(x = r) \propto \kappa r^{-a} \qquad (2.30)$$

where $r$ is the rank of word $x$ when sorted by frequency, and $\kappa$ and $a$ are constants. If we set $a = 1$,
we recover Zipf’s law.² Thus Zipf’s law predicts that if we plot the log frequency of words vs their
log rank, we will get a straight line with slope $-1$. This is in fact true, as illustrated in Figure 2.5.³
See [Ada00] for further discussion of Zipf’s law, and Section 2.6.2 for a discussion of language models. 
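The log-log test behind Figure 2.5 can be sketched as a least-squares fit of log frequency on log rank. In this sketch, `zipf_slope` is our own illustrative function, and the corpus is a toy synthetic one (not real text) built so that word $r$ occurs $840/r$ times; a real experiment would pass tokenized text instead. The `skip` argument mimics footnote 3, which drops the most frequent words.

```python
from collections import Counter

import numpy as np


def zipf_slope(tokens, skip=0):
    """Least-squares slope of log frequency vs log rank; Eq. (2.30) with a = 1 predicts -1.

    `skip` drops the most frequent words before fitting, in the spirit of footnote 3.
    """
    freqs = sorted(Counter(tokens).values(), reverse=True)[skip:]
    ranks = np.arange(skip + 1, skip + 1 + len(freqs))
    slope, _intercept = np.polyfit(np.log(ranks), np.log(freqs), deg=1)
    return slope


# Toy synthetic corpus (not real text): word r occurs 840 / r times, i.e., exactly Zipfian.
tokens = [f"w{r}" for r in range(1, 9) for _ in range(840 // r)]
print(zipf_slope(tokens))  # -1 up to floating-point error
```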


2.2.4 Continuous distributions on [0, 1] 


In this section, we discuss some univariate distributions defined on the [0, 1] interval. 


2.2.4.1 Beta distribution 


The beta distribution has support over the interval [0,1] and is defined as follows: 


$$\mathrm{Beta}(x|a, b) = \frac{1}{B(a, b)} x^{a-1} (1-x)^{b-1} \qquad (2.31)$$


We require a,b > 0 to ensure the distribution is integrable (i.e., to ensure B(a,b) exists). If 
a = b = 1, we get the uniform distribution. If a and b are both less than 1, we get a bimodal 
distribution with “spikes” at 0 and 1; if a and b are both greater than 1, the distribution is unimodal. 
See Figure 2.3b. 
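A minimal sketch of Equation (2.31) in plain Python (`beta_pdf` is our own name). It checks the uniform special case, the spikes when $a, b < 1$, and, by the midpoint rule, that the density integrates to one; the comparison value $a/(a+b)$ for the mean is a standard result not stated above.

```python
from math import exp, lgamma, log


def beta_pdf(x, a, b):
    """Beta(x | a, b) = x^(a-1) (1 - x)^(b-1) / B(a, b) for 0 < x < 1, Eq. (2.31)."""
    log_B = lgamma(a) + lgamma(b) - lgamma(a + b)
    return exp((a - 1) * log(x) + (b - 1) * log(1 - x) - log_B)


print(beta_pdf(0.3, 1, 1))                                # a = b = 1: uniform density 1
print(beta_pdf(0.01, 0.5, 0.5), beta_pdf(0.5, 0.5, 0.5))  # a, b < 1: spike near 0

# Midpoint rule on (0, 1): total mass 1 and mean a / (a + b).
n, a, b = 100_000, 2.0, 3.0
xs = [(i + 0.5) / n for i in range(n)]
mass = sum(beta_pdf(x, a, b) for x in xs) / n
mean = sum(x * beta_pdf(x, a, b) for x in xs) / n
print(mass, mean, a / (a + b))
```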


2. For example, $p(x = 2) = \kappa 2^{-1} = 2\kappa 4^{-1} = 2p(x = 4)$.
3. We remove the first 10 words from the plot, since they don’t fit the prediction as well. 
[Figure 2.6 panels: (a) Full, (b) Diagonal, (c) Spherical.]


Figure 2.6: Visualization of a 2d Gaussian density in terms of level sets of constant probability density. (a) A
full covariance matrix has elliptical contours. (b) A diagonal covariance matrix is an axis aligned ellipse. (c)
A spherical covariance matrix has a circular shape. Generated by gauss_plot_2d.ipynb.


2.2.5 The multivariate Gaussian (normal) distribution 

The most widely used joint probability distribution for continuous random variables is the multivariate
Gaussian or multivariate normal (MVN). This is mostly because it is mathematically
convenient, but also because the Gaussian assumption is fairly reasonable in many cases.

2.2.5.1 Definition 

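The MVN density is defined by the following:

$$\mathcal{N}(\boldsymbol{x}|\boldsymbol{\mu}, \boldsymbol{\Sigma}) \triangleq \frac{1}{(2\pi)^{D/2}|\boldsymbol{\Sigma}|^{1/2}} \exp\left[-\frac{1}{2}(\boldsymbol{x}-\boldsymbol{\mu})^\top \boldsymbol{\Sigma}^{-1} (\boldsymbol{x}-\boldsymbol{\mu})\right] \qquad (2.32)$$

where $\boldsymbol{\mu} = \mathbb{E}[\boldsymbol{x}] \in \mathbb{R}^D$ is the mean vector, and $\boldsymbol{\Sigma} = \mathrm{Cov}[\boldsymbol{x}]$ is the $D \times D$ covariance matrix. The
normalization constant $Z = (2\pi)^{D/2}|\boldsymbol{\Sigma}|^{1/2}$ just ensures that the pdf integrates to 1. The expression
inside the exponential (ignoring the factor of $-0.5$) is the squared Mahalanobis distance between
the data vector $\boldsymbol{x}$ and the mean vector $\boldsymbol{\mu}$, given by

$$d_{\boldsymbol{\Sigma}}(\boldsymbol{x}, \boldsymbol{\mu})^2 = (\boldsymbol{x}-\boldsymbol{\mu})^\top \boldsymbol{\Sigma}^{-1} (\boldsymbol{x}-\boldsymbol{\mu}) \qquad (2.33)$$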
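A minimal NumPy sketch of Equations (2.32) and (2.33). It uses a Cholesky factor $\boldsymbol{\Sigma} = \mathbf{L}\mathbf{L}^\top$ (a common numerical choice, not something the text prescribes), so the squared Mahalanobis distance is $\|\mathbf{L}^{-1}(\boldsymbol{x}-\boldsymbol{\mu})\|^2$ and $\log|\boldsymbol{\Sigma}| = 2\sum_i \log L_{ii}$. `mvn_logpdf` is our own name, and the bivariate covariance follows Equation (2.34).

```python
import numpy as np


def mvn_logpdf(x, mu, Sigma):
    """log N(x | mu, Sigma), Eq. (2.32), using a Cholesky factor Sigma = L L^T."""
    x, mu, Sigma = (np.asarray(v, dtype=float) for v in (x, mu, Sigma))
    L = np.linalg.cholesky(Sigma)
    z = np.linalg.solve(L, x - mu)              # z = L^{-1} (x - mu)
    maha2 = z @ z                               # squared Mahalanobis distance, Eq. (2.33)
    log_det = 2.0 * np.sum(np.log(np.diag(L)))  # log |Sigma|
    return -0.5 * (mu.shape[0] * np.log(2 * np.pi) + log_det + maha2)


# Bivariate example with Sigma built from (sigma_1, sigma_2, rho) as in Eq. (2.34).
s1, s2, rho = 1.0, 2.0, 0.5
Sigma = np.array([[s1**2, rho * s1 * s2], [rho * s1 * s2, s2**2]])
x, mu = np.array([0.5, -1.0]), np.zeros(2)
d = x - mu
direct = -0.5 * (2 * np.log(2 * np.pi) + np.log(np.linalg.det(Sigma)) + d @ np.linalg.inv(Sigma) @ d)
print(mvn_logpdf(x, mu, Sigma), direct)  # the two agree
```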


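In 2d, the MVN is known as the bivariate Gaussian distribution. Its pdf can be represented as
$\boldsymbol{x} \sim \mathcal{N}(\boldsymbol{\mu}, \boldsymbol{\Sigma})$, where $\boldsymbol{x} \in \mathbb{R}^2$, $\boldsymbol{\mu} \in \mathbb{R}^2$ and

$$\boldsymbol{\Sigma} = \begin{pmatrix} \sigma_1^2 & \sigma_{12}^2 \\ \sigma_{21}^2 & \sigma_2^2 \end{pmatrix} = \begin{pmatrix} \sigma_1^2 & \rho\sigma_1\sigma_2 \\ \rho\sigma_1\sigma_2 & \sigma_2^2 \end{pmatrix} \qquad (2.34)$$

where the correlation coefficient is given by $\rho \triangleq \frac{\sigma_{12}^2}{\sigma_1\sigma_2}$.

Figure 2.6 plots some MVN densities in 2d for three different kinds of covariance matrices. A full
covariance matrix has $D(D+1)/2$ parameters, where we divide by 2 since $\boldsymbol{\Sigma}$ is symmetric. A
diagonal covariance matrix has $D$ parameters, and has 0s in the off-diagonal terms. A spherical
covariance matrix, also called isotropic covariance matrix, has the form $\boldsymbol{\Sigma} = \sigma^2\mathbf{I}_D$, so it only
has one free parameter, namely $\sigma^2$.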
[Figure 2.7 panels: (a) cartoon of probability density vs distance from mode, marking the highest density point, the typical set and samples; (b) images illustrating the typical set.]


Figure 2.7: (a) Cartoon illustration of why the typical set of a Gaussian is not centered at the mode of the
distribution. (b) Illustration of the typical set of a Gaussian, which is concentrated in a thin annulus of
thickness $\sigma D^{1/4}$ and distance $\sigma D^{1/2}$ from the origin. We also show an image with the highest density (the
all-gray image on the left), as well as some high probability samples (the speckle noise images on the right).
From Figure 1 of [Nal+19a]. Used with kind permission of Eric Nalisnick.


2.2.5.2 Gaussian shells 


Multivariate Gaussians can behave rather counterintuitively in high dimensions. In particular, we
can ask: if we draw samples $\boldsymbol{x} \sim \mathcal{N}(\boldsymbol{0}, \mathbf{I}_D)$, where $D$ is the number of dimensions, where do we
expect most of the $\boldsymbol{x}$ to lie? Since the peak (mode) of the pdf is at the origin, it is natural to expect
most samples to be near the origin. However, in high dimensions, the typical set of a Gaussian is
a thin shell or annulus with a distance from the origin given by $r = \sqrt{D}$ and a thickness of $O(D^{1/4})$.
The intuitive reason for this is as follows: although the density decays as $e^{-r^2/2}$, meaning density
decreases away from the origin, the volume of a sphere grows as $r^D$, meaning volume increases away from the
origin; since mass is density times volume, the majority of points end up in this annulus where
these two terms “balance out”. This is called the “Gaussian soap bubble” phenomenon, and is
illustrated in Figure 2.7.⁴

To see why the typical set for a Gaussian is concentrated in a thin annulus at radius $\sqrt{D}$, consider
the distance of a point $\boldsymbol{x}$ from the origin, $d(\boldsymbol{x}) = \sqrt{\sum_{i=1}^D x_i^2}$, where $x_i \sim \mathcal{N}(0, 1)$. The
expected squared distance is given by $\mathbb{E}[d^2] = \sum_{i=1}^D \mathbb{E}[x_i^2] = D$, and the variance of the squared
distance is given by $\mathbb{V}[d^2] = \sum_{i=1}^D \mathbb{V}[x_i^2] = 2D$, since each $x_i^2 \sim \chi^2_1$ has variance 2. As $D$ grows, the
coefficient of variation (i.e., the SD relative to the mean) goes to zero:

$$\lim_{D \to \infty} \frac{\mathrm{std}[d^2]}{\mathbb{E}[d^2]} = \lim_{D \to \infty} \frac{\sqrt{2D}}{D} = 0 \qquad (2.35)$$

Thus the squared distance concentrates around $D$, so the distance concentrates
around $\sqrt{D}$ (and hence $\mathbb{E}[d(\boldsymbol{x})] \approx \sqrt{D}$). See [Ver18] for a more rigorous proof, and Main Section 5.2.3 for a discussion
of typical sets.
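The concentration argument is easy to see by simulation. This NumPy sketch (sample size and seed are arbitrary choices) draws from $\mathcal{N}(\boldsymbol{0}, \mathbf{I}_D)$ for increasing $D$ and reports $\mathbb{E}[d]/\sqrt{D}$, the coefficient of variation of $d^2$ (which should track $\sqrt{2/D}$, as in Equation (2.35)), and the smallest distance from the mode among the samples.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 2000
stats = {}
for D in [1, 10, 100, 1000]:
    x = rng.normal(size=(n, D))  # n samples from N(0, I_D)
    d2 = np.sum(x**2, axis=1)    # squared distance from the mode at the origin
    d = np.sqrt(d2)
    stats[D] = (d.mean() / np.sqrt(D), d2.std() / d2.mean(), d.min())
    print(f"D={D:5d}  E[d]/sqrt(D)={stats[D][0]:.3f}  CV(d^2)={stats[D][1]:.3f}  "
          f"sqrt(2/D)={np.sqrt(2 / D):.3f}  min d={stats[D][2]:.2f}")
```

For $D = 1000$, no sample lands anywhere near the origin even though that is where the density peaks.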


4. For a more detailed explanation, see this blog post by Ferenc Huszar: https://www.inference.vc/high-dimensional-gaussian-distributions-are-soap-bubble/.
To see what this means in the context of images, in Figure 2.7b, we show some grayscale images
that are sampled from a Gaussian of the form $\mathcal{N}(\boldsymbol{\mu}, \sigma^2\mathbf{I})$, where $\boldsymbol{\mu}$ corresponds to the all-gray image.
However, it is extremely unlikely that randomly sampled images would be close to all-gray, as shown
in the figure.
2.2.5.3 Marginals and conditionals of a MVN 
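Let us partition our vector of random variables $\boldsymbol{x}$ into two parts, $\boldsymbol{x}_1$ and $\boldsymbol{x}_2$, so

$$\boldsymbol{\mu} = \begin{pmatrix} \boldsymbol{\mu}_1 \\ \boldsymbol{\mu}_2 \end{pmatrix}, \qquad \boldsymbol{\Sigma} = \begin{pmatrix} \boldsymbol{\Sigma}_{11} & \boldsymbol{\Sigma}_{12} \\ \boldsymbol{\Sigma}_{21} & \boldsymbol{\Sigma}_{22} \end{pmatrix} \qquad (2.36)$$

The marginals of this distribution are given by the following (see Section 2.2.5.5 for the proof):

$$p(\boldsymbol{x}_1) = \int \mathcal{N}(\boldsymbol{x}|\boldsymbol{\mu}, \boldsymbol{\Sigma})\, d\boldsymbol{x}_2 \triangleq \mathcal{N}(\boldsymbol{x}_1|\boldsymbol{\mu}_1^m, \boldsymbol{\Sigma}_1^m) = \mathcal{N}(\boldsymbol{x}_1|\boldsymbol{\mu}_1, \boldsymbol{\Sigma}_{11}) \qquad (2.37)$$

$$p(\boldsymbol{x}_2) = \int \mathcal{N}(\boldsymbol{x}|\boldsymbol{\mu}, \boldsymbol{\Sigma})\, d\boldsymbol{x}_1 \triangleq \mathcal{N}(\boldsymbol{x}_2|\boldsymbol{\mu}_2^m, \boldsymbol{\Sigma}_2^m) = \mathcal{N}(\boldsymbol{x}_2|\boldsymbol{\mu}_2, \boldsymbol{\Sigma}_{22}) \qquad (2.38)$$

The conditional distributions can be shown to have the following form (see Section 2.2.5.5 for the
proof):

$$p(\boldsymbol{x}_1|\boldsymbol{x}_2) = \mathcal{N}(\boldsymbol{x}_1|\boldsymbol{\mu}_{1|2}^c, \boldsymbol{\Sigma}_{1|2}^c) = \mathcal{N}\left(\boldsymbol{x}_1|\boldsymbol{\mu}_1 + \boldsymbol{\Sigma}_{12}\boldsymbol{\Sigma}_{22}^{-1}(\boldsymbol{x}_2 - \boldsymbol{\mu}_2),\ \boldsymbol{\Sigma}_{11} - \boldsymbol{\Sigma}_{12}\boldsymbol{\Sigma}_{22}^{-1}\boldsymbol{\Sigma}_{21}\right) \qquad (2.39)$$

$$p(\boldsymbol{x}_2|\boldsymbol{x}_1) = \mathcal{N}(\boldsymbol{x}_2|\boldsymbol{\mu}_{2|1}^c, \boldsymbol{\Sigma}_{2|1}^c) = \mathcal{N}\left(\boldsymbol{x}_2|\boldsymbol{\mu}_2 + \boldsymbol{\Sigma}_{21}\boldsymbol{\Sigma}_{11}^{-1}(\boldsymbol{x}_1 - \boldsymbol{\mu}_1),\ \boldsymbol{\Sigma}_{22} - \boldsymbol{\Sigma}_{21}\boldsymbol{\Sigma}_{11}^{-1}\boldsymbol{\Sigma}_{12}\right) \qquad (2.40)$$

Note that the posterior mean of $p(\boldsymbol{x}_1|\boldsymbol{x}_2)$ is a linear function of $\boldsymbol{x}_2$, but the posterior covariance is
independent of $\boldsymbol{x}_2$; this is a peculiar property of Gaussian distributions.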
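A NumPy sketch of Equations (2.37) and (2.39); `mvn_marginal` and `mvn_condition` are our own helper names. It uses `np.linalg.solve` instead of forming $\boldsymbol{\Sigma}_{22}^{-1}$ explicitly. As a check we use the bivariate special case obtained by plugging Equation (2.34) into Equation (2.39): $\mu_{1|2} = \mu_1 + \rho(\sigma_1/\sigma_2)(x_2 - \mu_2)$ and $\Sigma_{1|2} = \sigma_1^2(1 - \rho^2)$.

```python
import numpy as np


def mvn_marginal(mu, Sigma, idx):
    """Eq. (2.37): marginalizing a Gaussian just selects the matching blocks of mu and Sigma."""
    return mu[idx], Sigma[np.ix_(idx, idx)]


def mvn_condition(mu, Sigma, idx1, idx2, x2):
    """Mean and covariance of p(x1 | x2), Eq. (2.39)."""
    S11 = Sigma[np.ix_(idx1, idx1)]
    S12 = Sigma[np.ix_(idx1, idx2)]
    S22 = Sigma[np.ix_(idx2, idx2)]
    K = np.linalg.solve(S22, S12.T).T      # K = S12 S22^{-1}, without forming the inverse
    mean = mu[idx1] + K @ (x2 - mu[idx2])  # linear in x2
    cov = S11 - K @ S12.T                  # S11 - S12 S22^{-1} S21, independent of x2
    return mean, cov


s1, s2, rho = 1.0, 2.0, 0.5
mu = np.array([1.0, -1.0])
Sigma = np.array([[s1**2, rho * s1 * s2], [rho * s1 * s2, s2**2]])
mean, cov = mvn_condition(mu, Sigma, [0], [1], np.array([1.0]))
print(mean, cov)  # mean 1 + 0.5 * (1 / 2) * 2 = 1.5, covariance 1 - 0.25 = 0.75
```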
2.2.5.4 Information (canonical) form 


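It is common to parameterize the MVN in terms of the mean vector $\boldsymbol{\mu}$ and the covariance matrix $\boldsymbol{\Sigma}$.
However, for reasons which are explained in Main Section 2.3.2.5, it is sometimes useful to represent
the Gaussian distribution using canonical parameters or natural parameters, defined as

$$\boldsymbol{\Lambda} \triangleq \boldsymbol{\Sigma}^{-1}, \qquad \boldsymbol{\eta} \triangleq \boldsymbol{\Sigma}^{-1}\boldsymbol{\mu} \qquad (2.41)$$

The matrix $\boldsymbol{\Lambda} = \boldsymbol{\Sigma}^{-1}$ is known as the precision matrix, and the vector $\boldsymbol{\eta}$ is known as the
precision-weighted mean. We can convert back to the more familiar moment parameters using

$$\boldsymbol{\mu} = \boldsymbol{\Lambda}^{-1}\boldsymbol{\eta}, \qquad \boldsymbol{\Sigma} = \boldsymbol{\Lambda}^{-1} \qquad (2.42)$$

Hence we can write the MVN in canonical form (also called information form) as follows:

$$\begin{aligned}
\mathcal{N}_c(\boldsymbol{x}|\boldsymbol{\eta}, \boldsymbol{\Lambda}) &\triangleq c \exp\left(\boldsymbol{x}^\top\boldsymbol{\eta} - \frac{1}{2}\boldsymbol{x}^\top\boldsymbol{\Lambda}\boldsymbol{x}\right) && (2.43)\\
c &\triangleq \frac{\exp\left(-\frac{1}{2}\boldsymbol{\eta}^\top\boldsymbol{\Lambda}^{-1}\boldsymbol{\eta}\right)}{(2\pi)^{D/2}\sqrt{\det(\boldsymbol{\Lambda}^{-1})}} && (2.44)
\end{aligned}$$

where we use the notation $\mathcal{N}_c()$ to distinguish it from the standard parameterization $\mathcal{N}()$. For more
information on moment and natural parameters, see Main Section 2.3.2.5.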
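A NumPy sketch of the conversions in Equations (2.41) and (2.42) and of the canonical log density in Equations (2.43) and (2.44). All function names are our own, and `moment_logpdf` is included only as a reference to show that both parameterizations give the same density.

```python
import numpy as np


def moment_to_natural(mu, Sigma):
    """Eq. (2.41): Lambda = Sigma^{-1}, eta = Sigma^{-1} mu."""
    Lam = np.linalg.inv(Sigma)
    return Lam @ mu, Lam


def natural_to_moment(eta, Lam):
    """Eq. (2.42): Sigma = Lambda^{-1}, mu = Lambda^{-1} eta."""
    Sigma = np.linalg.inv(Lam)
    return Sigma @ eta, Sigma


def canonical_logpdf(x, eta, Lam):
    """log N_c(x | eta, Lambda) = log c + x^T eta - x^T Lambda x / 2, Eqs. (2.43)-(2.44)."""
    Sigma = np.linalg.inv(Lam)
    _, logdet_Sigma = np.linalg.slogdet(Sigma)  # log det(Lambda^{-1})
    log_c = -0.5 * eta @ Sigma @ eta - 0.5 * len(eta) * np.log(2 * np.pi) - 0.5 * logdet_Sigma
    return log_c + x @ eta - 0.5 * x @ Lam @ x


def moment_logpdf(x, mu, Sigma):
    """Reference: log N(x | mu, Sigma) in the usual moment form, Eq. (2.32)."""
    d = x - mu
    _, logdet = np.linalg.slogdet(Sigma)
    return -0.5 * (len(mu) * np.log(2 * np.pi) + logdet + d @ np.linalg.solve(Sigma, d))


mu = np.array([1.0, -2.0])
Sigma = np.array([[2.0, 0.6], [0.6, 1.0]])
eta, Lam = moment_to_natural(mu, Sigma)
x = np.array([0.3, -1.1])
print(canonical_logpdf(x, eta, Lam), moment_logpdf(x, mu, Sigma))  # same density
```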

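It is also possible to derive the marginalization and conditioning formulas in information form (see
Section 2.2.5.6 for the derivation). For the marginals we have

$$p(\boldsymbol{x}_1) = \mathcal{N}_c(\boldsymbol{x}_1|\boldsymbol{\eta}_1^m, \boldsymbol{\Lambda}_1^m) = \mathcal{N}_c\left(\boldsymbol{x}_1|\boldsymbol{\eta}_1 - \boldsymbol{\Lambda}_{12}\boldsymbol{\Lambda}_{22}^{-1}\boldsymbol{\eta}_2,\ \boldsymbol{\Lambda}_{11} - \boldsymbol{\Lambda}_{12}\boldsymbol{\Lambda}_{22}^{-1}\boldsymbol{\Lambda}_{21}\right) \qquad (2.45)$$

$$p(\boldsymbol{x}_2) = \mathcal{N}_c(\boldsymbol{x}_2|\boldsymbol{\eta}_2^m, \boldsymbol{\Lambda}_2^m) = \mathcal{N}_c\left(\boldsymbol{x}_2|\boldsymbol{\eta}_2 - \boldsymbol{\Lambda}_{21}\boldsymbol{\Lambda}_{11}^{-1}\boldsymbol{\eta}_1,\ \boldsymbol{\Lambda}_{22} - \boldsymbol{\Lambda}_{21}\boldsymbol{\Lambda}_{11}^{-1}\boldsymbol{\Lambda}_{12}\right) \qquad (2.46)$$

For the conditionals we have

$$p(\boldsymbol{x}_1|\boldsymbol{x}_2) = \mathcal{N}_c(\boldsymbol{x}_1|\boldsymbol{\eta}_{1|2}^c, \boldsymbol{\Lambda}_{1|2}^c) = \mathcal{N}_c(\boldsymbol{x}_1|\boldsymbol{\eta}_1 - \boldsymbol{\Lambda}_{12}\boldsymbol{x}_2,\ \boldsymbol{\Lambda}_{11}) \qquad (2.47)$$

$$p(\boldsymbol{x}_2|\boldsymbol{x}_1) = \mathcal{N}_c(\boldsymbol{x}_2|\boldsymbol{\eta}_{2|1}^c, \boldsymbol{\Lambda}_{2|1}^c) = \mathcal{N}_c(\boldsymbol{x}_2|\boldsymbol{\eta}_2 - \boldsymbol{\Lambda}_{21}\boldsymbol{x}_1,\ \boldsymbol{\Lambda}_{22}) \qquad (2.48)$$

Thus we see that marginalization is easier in moment form, and conditioning is easier in information
form.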
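To see the trade-off concretely, this NumPy sketch implements Equations (2.45) and (2.47) (our own helper names) and checks them against the moment-form results (2.37) and (2.39) on a randomly generated 4d Gaussian. Conditioning in information form needs no matrix solve at all, while marginalization does.

```python
import numpy as np


def info_marginal(eta, Lam, idx1, idx2):
    """p(x1) in information form, Eq. (2.45)."""
    L12 = Lam[np.ix_(idx1, idx2)]
    L22 = Lam[np.ix_(idx2, idx2)]
    eta_m = eta[idx1] - L12 @ np.linalg.solve(L22, eta[idx2])
    Lam_m = Lam[np.ix_(idx1, idx1)] - L12 @ np.linalg.solve(L22, L12.T)
    return eta_m, Lam_m


def info_condition(eta, Lam, idx1, idx2, x2):
    """p(x1 | x2) in information form, Eq. (2.47): just read off blocks."""
    return eta[idx1] - Lam[np.ix_(idx1, idx2)] @ x2, Lam[np.ix_(idx1, idx1)]


rng = np.random.default_rng(0)
A = rng.normal(size=(4, 4))
Sigma = A @ A.T + 4 * np.eye(4)
mu = rng.normal(size=4)
Lam = np.linalg.inv(Sigma)
eta = Lam @ mu
i1, i2 = [0, 1], [2, 3]
x2 = np.array([0.5, -0.5])

eta_m, Lam_m = info_marginal(eta, Lam, i1, i2)
marg_ok = np.allclose(np.linalg.inv(Lam_m), Sigma[np.ix_(i1, i1)]) and np.allclose(
    np.linalg.solve(Lam_m, eta_m), mu[i1])

eta_c, Lam_c = info_condition(eta, Lam, i1, i2, x2)
S12, S22 = Sigma[np.ix_(i1, i2)], Sigma[np.ix_(i2, i2)]
mean_c = mu[i1] + S12 @ np.linalg.solve(S22, x2 - mu[i2])         # Eq. (2.39), mean
cov_c = Sigma[np.ix_(i1, i1)] - S12 @ np.linalg.solve(S22, S12.T)  # Eq. (2.39), covariance
cond_ok = np.allclose(np.linalg.solve(Lam_c, eta_c), mean_c) and np.allclose(np.linalg.inv(Lam_c), cov_c)
print(marg_ok, cond_ok)
```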


2.2.5.5 Derivation: moment form 


In this section, we derive Equation (2.37) and Equation (2.39) for marginalizing and conditioning an 
MVN in moment form. 

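Before we dive in, we need the following result for the inverse of a partitioned
matrix of the form


$$\boldsymbol{M} = \begin{pmatrix} \boldsymbol{E} & \boldsymbol{F} \\ \boldsymbol{G} & \boldsymbol{H} \end{pmatrix} \tag{2.49}$$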


where we assume E and H are invertible. One can show (see e.g., [Mur22, Sec 7.3.2] for the proof) 
that 


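\begin{align}
\boldsymbol{M}^{-1} &= \begin{pmatrix} (\boldsymbol{M}/\boldsymbol{H})^{-1} & -(\boldsymbol{M}/\boldsymbol{H})^{-1}\boldsymbol{F}\boldsymbol{H}^{-1} \\ -\boldsymbol{H}^{-1}\boldsymbol{G}(\boldsymbol{M}/\boldsymbol{H})^{-1} & \boldsymbol{H}^{-1} + \boldsymbol{H}^{-1}\boldsymbol{G}(\boldsymbol{M}/\boldsymbol{H})^{-1}\boldsymbol{F}\boldsymbol{H}^{-1} \end{pmatrix} \tag{2.50} \\
&= \begin{pmatrix} \boldsymbol{E}^{-1} + \boldsymbol{E}^{-1}\boldsymbol{F}(\boldsymbol{M}/\boldsymbol{E})^{-1}\boldsymbol{G}\boldsymbol{E}^{-1} & -\boldsymbol{E}^{-1}\boldsymbol{F}(\boldsymbol{M}/\boldsymbol{E})^{-1} \\ -(\boldsymbol{M}/\boldsymbol{E})^{-1}\boldsymbol{G}\boldsymbol{E}^{-1} & (\boldsymbol{M}/\boldsymbol{E})^{-1} \end{pmatrix} \tag{2.51}
\end{align}
where
\begin{align}
\boldsymbol{M}/\boldsymbol{H} &\triangleq \boldsymbol{E} - \boldsymbol{F}\boldsymbol{H}^{-1}\boldsymbol{G} \tag{2.52} \\
\boldsymbol{M}/\boldsymbol{E} &\triangleq \boldsymbol{H} - \boldsymbol{G}\boldsymbol{E}^{-1}\boldsymbol{F} \tag{2.53}
\end{align}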


We say that M/H is the Schur complement of M wrt H, and M/E is the Schur complement of 
M wrt E. 

From the above, we also have the following important result, known as the matrix inversion 
lemma or the Sherman-Morrison-Woodbury formula: 


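$$(\boldsymbol{M}/\boldsymbol{H})^{-1} = (\boldsymbol{E} - \boldsymbol{F}\boldsymbol{H}^{-1}\boldsymbol{G})^{-1} = \boldsymbol{E}^{-1} + \boldsymbol{E}^{-1}\boldsymbol{F}(\boldsymbol{H} - \boldsymbol{G}\boldsymbol{E}^{-1}\boldsymbol{F})^{-1}\boldsymbol{G}\boldsymbol{E}^{-1} \tag{2.54}$$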
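The partitioned inverse (2.51) and the matrix inversion lemma (2.54) are easy to check numerically. The following NumPy sketch uses random blocks of arbitrary sizes. The diagonal shift added to $\boldsymbol{E}$ and $\boldsymbol{H}$ is our own device to keep them comfortably invertible.

```python
import numpy as np

rng = np.random.default_rng(1)
n, m = 4, 3                                     # E is n x n, H is m x m
E = rng.normal(size=(n, n)) + 5 * np.eye(n)     # shift keeps E well conditioned
H = rng.normal(size=(m, m)) + 5 * np.eye(m)
F = rng.normal(size=(n, m))
G = rng.normal(size=(m, n))
M = np.block([[E, F], [G, H]])

Einv, Hinv = np.linalg.inv(E), np.linalg.inv(H)
M_over_H = E - F @ Hinv @ G                     # Schur complement, Eq. (2.52)
M_over_E = H - G @ Einv @ F                     # Schur complement, Eq. (2.53)
MEinv = np.linalg.inv(M_over_E)

# Partitioned inverse in terms of M/E, Eq. (2.51)
Minv_blocks = np.block([
    [Einv + Einv @ F @ MEinv @ G @ Einv, -Einv @ F @ MEinv],
    [-MEinv @ G @ Einv, MEinv],
])

# Matrix inversion lemma, Eq. (2.54)
lhs = np.linalg.inv(M_over_H)
rhs = Einv + Einv @ F @ MEinv @ G @ Einv

print(np.allclose(Minv_blocks, np.linalg.inv(M)), np.allclose(lhs, rhs))
```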
Now we return to the derivation of the MVN conditioning equation. Let us factor the joint $p(\boldsymbol{x}_1, \boldsymbol{x}_2)$
as $p(\boldsymbol{x}_2)p(\boldsymbol{x}_1|\boldsymbol{x}_2)$ as follows:


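$$p(\boldsymbol{x}_1, \boldsymbol{x}_2) \propto \exp\left\{-\frac{1}{2}\begin{pmatrix}\boldsymbol{x}_1 - \boldsymbol{\mu}_1 \\ \boldsymbol{x}_2 - \boldsymbol{\mu}_2\end{pmatrix}^\top \begin{pmatrix}\boldsymbol{\Sigma}_{11} & \boldsymbol{\Sigma}_{12} \\ \boldsymbol{\Sigma}_{21} & \boldsymbol{\Sigma}_{22}\end{pmatrix}^{-1} \begin{pmatrix}\boldsymbol{x}_1 - \boldsymbol{\mu}_1 \\ \boldsymbol{x}_2 - \boldsymbol{\mu}_2\end{pmatrix}\right\} \tag{2.55}$$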


Using the equation for the inverse of a block structured matrix, the above exponent becomes 


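\begin{align}
p(\boldsymbol{x}_1, \boldsymbol{x}_2) &\propto \exp\left\{-\frac{1}{2}\begin{pmatrix}\boldsymbol{x}_1 - \boldsymbol{\mu}_1 \\ \boldsymbol{x}_2 - \boldsymbol{\mu}_2\end{pmatrix}^\top \begin{pmatrix}\boldsymbol{I} & \boldsymbol{0} \\ -\boldsymbol{\Sigma}_{22}^{-1}\boldsymbol{\Sigma}_{21} & \boldsymbol{I}\end{pmatrix} \begin{pmatrix}(\boldsymbol{\Sigma}/\boldsymbol{\Sigma}_{22})^{-1} & \boldsymbol{0} \\ \boldsymbol{0} & \boldsymbol{\Sigma}_{22}^{-1}\end{pmatrix}\right. \tag{2.56} \\
&\qquad \left. \times \begin{pmatrix}\boldsymbol{I} & -\boldsymbol{\Sigma}_{12}\boldsymbol{\Sigma}_{22}^{-1} \\ \boldsymbol{0} & \boldsymbol{I}\end{pmatrix} \begin{pmatrix}\boldsymbol{x}_1 - \boldsymbol{\mu}_1 \\ \boldsymbol{x}_2 - \boldsymbol{\mu}_2\end{pmatrix}\right\} \tag{2.57} \\
&= \exp\left\{-\frac{1}{2}\left(\boldsymbol{x}_1 - \boldsymbol{\mu}_1 - \boldsymbol{\Sigma}_{12}\boldsymbol{\Sigma}_{22}^{-1}(\boldsymbol{x}_2 - \boldsymbol{\mu}_2)\right)^\top (\boldsymbol{\Sigma}/\boldsymbol{\Sigma}_{22})^{-1}\right. \tag{2.58} \\
&\qquad \left. \times \left(\boldsymbol{x}_1 - \boldsymbol{\mu}_1 - \boldsymbol{\Sigma}_{12}\boldsymbol{\Sigma}_{22}^{-1}(\boldsymbol{x}_2 - \boldsymbol{\mu}_2)\right)\right\} \times \exp\left\{-\frac{1}{2}(\boldsymbol{x}_2 - \boldsymbol{\mu}_2)^\top\boldsymbol{\Sigma}_{22}^{-1}(\boldsymbol{x}_2 - \boldsymbol{\mu}_2)\right\} \tag{2.59}
\end{align}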


This is of the form
$$\exp(\text{quadratic form in } \boldsymbol{x}_1, \boldsymbol{x}_2) \times \exp(\text{quadratic form in } \boldsymbol{x}_2) \tag{2.60}$$
Hence we have successfully factorized the joint as 


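\begin{align}
p(\boldsymbol{x}_1, \boldsymbol{x}_2) &= p(\boldsymbol{x}_1|\boldsymbol{x}_2)p(\boldsymbol{x}_2) \tag{2.61} \\
&= \mathcal{N}(\boldsymbol{x}_1|\boldsymbol{\mu}_{1|2}, \boldsymbol{\Sigma}_{1|2})\,\mathcal{N}(\boldsymbol{x}_2|\boldsymbol{\mu}_2, \boldsymbol{\Sigma}_{22}) \tag{2.62}
\end{align}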


where 
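\begin{align}
\boldsymbol{\mu}_{1|2} &= \boldsymbol{\mu}_1 + \boldsymbol{\Sigma}_{12}\boldsymbol{\Sigma}_{22}^{-1}(\boldsymbol{x}_2 - \boldsymbol{\mu}_2) \tag{2.63} \\
\boldsymbol{\Sigma}_{1|2} &= \boldsymbol{\Sigma}/\boldsymbol{\Sigma}_{22} \triangleq \boldsymbol{\Sigma}_{11} - \boldsymbol{\Sigma}_{12}\boldsymbol{\Sigma}_{22}^{-1}\boldsymbol{\Sigma}_{21} \tag{2.64}
\end{align}


where $\boldsymbol{\Sigma}/\boldsymbol{\Sigma}_{22}$ is the Schur complement of $\boldsymbol{\Sigma}$ wrt $\boldsymbol{\Sigma}_{22}$.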


2.2.5.6 Derivation: information form 


In this section, we derive Equation (2.46) and Equation (2.47) for marginalizing and conditioning an 
MVN in information form. 

First we derive the conditional formula.$^5$ Let us partition the information form parameters as
follows: 


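$$\boldsymbol{\eta} = \begin{pmatrix}\boldsymbol{\eta}_1 \\ \boldsymbol{\eta}_2\end{pmatrix}, \quad \boldsymbol{\Lambda} = \begin{pmatrix}\boldsymbol{\Lambda}_{11} & \boldsymbol{\Lambda}_{12} \\ \boldsymbol{\Lambda}_{21} & \boldsymbol{\Lambda}_{22}\end{pmatrix} \tag{2.65}$$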


5. This derivation is due to Giles Harper-Donnelly. 


We can now write the joint log probability of $\boldsymbol{x}_1, \boldsymbol{x}_2$ as


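\begin{align}
\ln p(\boldsymbol{x}_1, \boldsymbol{x}_2) &= -\frac{1}{2}\begin{pmatrix}\boldsymbol{x}_1 \\ \boldsymbol{x}_2\end{pmatrix}^\top \begin{pmatrix}\boldsymbol{\Lambda}_{11} & \boldsymbol{\Lambda}_{12} \\ \boldsymbol{\Lambda}_{21} & \boldsymbol{\Lambda}_{22}\end{pmatrix}\begin{pmatrix}\boldsymbol{x}_1 \\ \boldsymbol{x}_2\end{pmatrix} + \begin{pmatrix}\boldsymbol{x}_1 \\ \boldsymbol{x}_2\end{pmatrix}^\top\begin{pmatrix}\boldsymbol{\eta}_1 \\ \boldsymbol{\eta}_2\end{pmatrix} + \text{const.} \tag{2.66} \\
&= -\frac{1}{2}\boldsymbol{x}_1^\top\boldsymbol{\Lambda}_{11}\boldsymbol{x}_1 - \frac{1}{2}\boldsymbol{x}_2^\top\boldsymbol{\Lambda}_{22}\boldsymbol{x}_2 - \frac{1}{2}\boldsymbol{x}_1^\top\boldsymbol{\Lambda}_{12}\boldsymbol{x}_2 - \frac{1}{2}\boldsymbol{x}_2^\top\boldsymbol{\Lambda}_{21}\boldsymbol{x}_1 \nonumber \\
&\qquad + \boldsymbol{x}_1^\top\boldsymbol{\eta}_1 + \boldsymbol{x}_2^\top\boldsymbol{\eta}_2 + \text{const.} \tag{2.67}
\end{align}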


where the constant term does not depend on $\boldsymbol{x}_1$ or $\boldsymbol{x}_2$.

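To calculate the parameters of the conditional distribution $p(\boldsymbol{x}_1|\boldsymbol{x}_2)$, we fix the value of $\boldsymbol{x}_2$ and
collect the terms which are quadratic in $\boldsymbol{x}_1$ for the conditional precision and then linear in $\boldsymbol{x}_1$ for the
conditional precision-weighted mean. The terms which are quadratic in $\boldsymbol{x}_1$ are just $-\frac{1}{2}\boldsymbol{x}_1^\top\boldsymbol{\Lambda}_{11}\boldsymbol{x}_1$, and
hence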


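$$\boldsymbol{\Lambda}_{1|2} = \boldsymbol{\Lambda}_{11} \tag{2.68}$$
The terms which are linear in $\boldsymbol{x}_1$ are
$$-\frac{1}{2}\boldsymbol{x}_1^\top\boldsymbol{\Lambda}_{12}\boldsymbol{x}_2 - \frac{1}{2}\boldsymbol{x}_2^\top\boldsymbol{\Lambda}_{21}\boldsymbol{x}_1 + \boldsymbol{x}_1^\top\boldsymbol{\eta}_1 = \boldsymbol{x}_1^\top(\boldsymbol{\eta}_1 - \boldsymbol{\Lambda}_{12}\boldsymbol{x}_2) \tag{2.69}$$


since $\boldsymbol{\Lambda}_{21}^\top = \boldsymbol{\Lambda}_{12}$. Thus the conditional precision-weighted mean is
$$\boldsymbol{\eta}_{1|2} = \boldsymbol{\eta}_1 - \boldsymbol{\Lambda}_{12}\boldsymbol{x}_2 \tag{2.70}$$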


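We will now derive the results for marginalizing in information form. The marginal, $p(\boldsymbol{x}_2)$, can be
calculated by integrating the joint, $p(\boldsymbol{x}_1, \boldsymbol{x}_2)$, with respect to $\boldsymbol{x}_1$:


$$p(\boldsymbol{x}_2) = \int p(\boldsymbol{x}_1, \boldsymbol{x}_2)\,d\boldsymbol{x}_1 \tag{2.71}$$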


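$$\propto \int \exp\left\{-\frac{1}{2}\boldsymbol{x}_1^\top\boldsymbol{\Lambda}_{11}\boldsymbol{x}_1 - \frac{1}{2}\boldsymbol{x}_2^\top\boldsymbol{\Lambda}_{22}\boldsymbol{x}_2 - \frac{1}{2}\boldsymbol{x}_1^\top\boldsymbol{\Lambda}_{12}\boldsymbol{x}_2 - \frac{1}{2}\boldsymbol{x}_2^\top\boldsymbol{\Lambda}_{21}\boldsymbol{x}_1 + \boldsymbol{x}_1^\top\boldsymbol{\eta}_1 + \boldsymbol{x}_2^\top\boldsymbol{\eta}_2\right\} d\boldsymbol{x}_1 \tag{2.72}$$
where the terms in the exponent have been decomposed into the partitioned structure in Equation (2.65) as in Equation (2.67). Next, collecting all the terms involving $\boldsymbol{x}_1$,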


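$$p(\boldsymbol{x}_2) \propto \exp\left\{-\frac{1}{2}\boldsymbol{x}_2^\top\boldsymbol{\Lambda}_{22}\boldsymbol{x}_2 + \boldsymbol{x}_2^\top\boldsymbol{\eta}_2\right\} \int \exp\left\{-\frac{1}{2}\boldsymbol{x}_1^\top\boldsymbol{\Lambda}_{11}\boldsymbol{x}_1 + \boldsymbol{x}_1^\top(\boldsymbol{\eta}_1 - \boldsymbol{\Lambda}_{12}\boldsymbol{x}_2)\right\} d\boldsymbol{x}_1 \tag{2.73}$$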


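we can recognise the integrand as an exponential quadratic form. Therefore the integral is equal to
the normalising constant of a Gaussian with precision $\boldsymbol{\Lambda}_{11}$ and precision-weighted mean $\boldsymbol{\eta}_1 - \boldsymbol{\Lambda}_{12}\boldsymbol{x}_2$,
which is given by the reciprocal of Equation (2.44). Its determinant factor does not depend on $\boldsymbol{x}_2$, so it can be absorbed into the proportionality constant. Substituting this into our equation, we have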


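\begin{align}
p(\boldsymbol{x}_2) &\propto \exp\left\{-\frac{1}{2}\boldsymbol{x}_2^\top\boldsymbol{\Lambda}_{22}\boldsymbol{x}_2 + \boldsymbol{x}_2^\top\boldsymbol{\eta}_2\right\}\exp\left\{\frac{1}{2}(\boldsymbol{\eta}_1 - \boldsymbol{\Lambda}_{12}\boldsymbol{x}_2)^\top\boldsymbol{\Lambda}_{11}^{-1}(\boldsymbol{\eta}_1 - \boldsymbol{\Lambda}_{12}\boldsymbol{x}_2)\right\} \tag{2.74} \\
&\propto \exp\left\{-\frac{1}{2}\boldsymbol{x}_2^\top\boldsymbol{\Lambda}_{22}\boldsymbol{x}_2 + \boldsymbol{x}_2^\top\boldsymbol{\eta}_2 + \frac{1}{2}\boldsymbol{x}_2^\top\boldsymbol{\Lambda}_{21}\boldsymbol{\Lambda}_{11}^{-1}\boldsymbol{\Lambda}_{12}\boldsymbol{x}_2 - \boldsymbol{x}_2^\top\boldsymbol{\Lambda}_{21}\boldsymbol{\Lambda}_{11}^{-1}\boldsymbol{\eta}_1\right\} \tag{2.75} \\
&= \exp\left\{-\frac{1}{2}\boldsymbol{x}_2^\top(\boldsymbol{\Lambda}_{22} - \boldsymbol{\Lambda}_{21}\boldsymbol{\Lambda}_{11}^{-1}\boldsymbol{\Lambda}_{12})\boldsymbol{x}_2 + \boldsymbol{x}_2^\top(\boldsymbol{\eta}_2 - \boldsymbol{\Lambda}_{21}\boldsymbol{\Lambda}_{11}^{-1}\boldsymbol{\eta}_1)\right\} \tag{2.76}
\end{align}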
which we now recognise as an exponential quadratic form in $\boldsymbol{x}_2$. Extract the quadratic terms to get
the marginal precision, 


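$$\boldsymbol{\Lambda}_2^m = \boldsymbol{\Lambda}_{22} - \boldsymbol{\Lambda}_{21}\boldsymbol{\Lambda}_{11}^{-1}\boldsymbol{\Lambda}_{12} \tag{2.77}$$
and the linear terms to get the marginal precision-weighted mean,


$$\boldsymbol{\eta}_2^m = \boldsymbol{\eta}_2 - \boldsymbol{\Lambda}_{21}\boldsymbol{\Lambda}_{11}^{-1}\boldsymbol{\eta}_1 \tag{2.78}$$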


2.2.6 Linear Gaussian systems 


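Consider two random vectors $\boldsymbol{y} \in \mathbb{R}^D$ and $\boldsymbol{z} \in \mathbb{R}^L$, which are jointly Gaussian with the following
joint distribution:


\begin{align}
p(\boldsymbol{z}) &= \mathcal{N}(\boldsymbol{z}|\breve{\boldsymbol{\mu}}, \breve{\boldsymbol{\Sigma}}) \tag{2.79} \\
p(\boldsymbol{y}|\boldsymbol{z}) &= \mathcal{N}(\boldsymbol{y}|\boldsymbol{W}\boldsymbol{z} + \boldsymbol{b}, \boldsymbol{\Omega}) \tag{2.80}
\end{align}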


where W is a matrix of size D x L. This is an example of a linear Gaussian system. 


2.2.6.1 Joint distribution 


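The corresponding joint distribution, $p(\boldsymbol{z}, \boldsymbol{y}) = p(\boldsymbol{z})p(\boldsymbol{y}|\boldsymbol{z})$, is itself a $D + L$ dimensional Gaussian,
with mean and covariance given by the following (this result can be obtained by moment matching):


\begin{align}
p(\boldsymbol{z}, \boldsymbol{y}) &= \mathcal{N}(\boldsymbol{z}, \boldsymbol{y}|\tilde{\boldsymbol{\mu}}, \tilde{\boldsymbol{\Sigma}}) \tag{2.81a} \\
\tilde{\boldsymbol{\mu}} &\triangleq \begin{pmatrix}\breve{\boldsymbol{\mu}} \\ \boldsymbol{m}\end{pmatrix} \triangleq \begin{pmatrix}\breve{\boldsymbol{\mu}} \\ \boldsymbol{W}\breve{\boldsymbol{\mu}} + \boldsymbol{b}\end{pmatrix} \tag{2.81b} \\
\tilde{\boldsymbol{\Sigma}} &\triangleq \begin{pmatrix}\breve{\boldsymbol{\Sigma}} & \boldsymbol{C} \\ \boldsymbol{C}^\top & \boldsymbol{S}\end{pmatrix} \triangleq \begin{pmatrix}\breve{\boldsymbol{\Sigma}} & \breve{\boldsymbol{\Sigma}}\boldsymbol{W}^\top \\ \boldsymbol{W}\breve{\boldsymbol{\Sigma}} & \boldsymbol{W}\breve{\boldsymbol{\Sigma}}\boldsymbol{W}^\top + \boldsymbol{\Omega}\end{pmatrix} \tag{2.81c}
\end{align}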


See Algorithm 7 for some pseudocode to compute this joint distribution. 


2.2.6.2 Posterior distribution (Bayes rule for Gaussians) 


Now we consider computing the posterior p(z|y) from a linear Gaussian system. Using Equation (2.39) 
for conditioning a joint Gaussian, we find that the posterior is given by 


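\begin{align}
p(\boldsymbol{z}|\boldsymbol{y}) &= \mathcal{N}(\boldsymbol{z}|\hat{\boldsymbol{\mu}}, \hat{\boldsymbol{\Sigma}}) \tag{2.82a} \\
\hat{\boldsymbol{\mu}} &= \breve{\boldsymbol{\mu}} + \breve{\boldsymbol{\Sigma}}\boldsymbol{W}^\top(\boldsymbol{\Omega} + \boldsymbol{W}\breve{\boldsymbol{\Sigma}}\boldsymbol{W}^\top)^{-1}(\boldsymbol{y} - (\boldsymbol{W}\breve{\boldsymbol{\mu}} + \boldsymbol{b})) \tag{2.82b} \\
\hat{\boldsymbol{\Sigma}} &= \breve{\boldsymbol{\Sigma}} - \breve{\boldsymbol{\Sigma}}\boldsymbol{W}^\top(\boldsymbol{\Omega} + \boldsymbol{W}\breve{\boldsymbol{\Sigma}}\boldsymbol{W}^\top)^{-1}\boldsymbol{W}\breve{\boldsymbol{\Sigma}} \tag{2.82c}
\end{align}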


This is known as Bayes’ rule for Gaussians. We see that if the prior p(z) is Gaussian, and the 
likelihood p(y|z) is Gaussian, then the posterior p(z|y) is also Gaussian. We therefore say that the 
Gaussian prior is a conjugate prior for the Gaussian likelihood, since the posterior distribution has 
the same type as the prior. (In other words, Gaussians are closed under Bayesian updating.) 


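We can simplify these equations by defining $\boldsymbol{S} = \boldsymbol{W}\breve{\boldsymbol{\Sigma}}\boldsymbol{W}^\top + \boldsymbol{\Omega}$, $\boldsymbol{C} = \breve{\boldsymbol{\Sigma}}\boldsymbol{W}^\top$, and $\boldsymbol{m} = \boldsymbol{W}\breve{\boldsymbol{\mu}} + \boldsymbol{b}$,
as in Equation (2.81). We also define the Kalman gain matrix:$^6$


$$\boldsymbol{K} = \boldsymbol{C}\boldsymbol{S}^{-1} \tag{2.83}$$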


From this, we get the posterior 


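\begin{align}
\hat{\boldsymbol{\mu}} &= \breve{\boldsymbol{\mu}} + \boldsymbol{K}(\boldsymbol{y} - \boldsymbol{m}) \tag{2.84} \\
\hat{\boldsymbol{\Sigma}} &= \breve{\boldsymbol{\Sigma}} - \boldsymbol{K}\boldsymbol{C}^\top \tag{2.85}
\end{align}
Note that

$$\boldsymbol{K}\boldsymbol{S}\boldsymbol{K}^\top = \boldsymbol{C}\boldsymbol{S}^{-1}\boldsymbol{S}\boldsymbol{S}^{-1}\boldsymbol{C}^\top = \boldsymbol{C}\boldsymbol{S}^{-1}\boldsymbol{C}^\top = \boldsymbol{K}\boldsymbol{C}^\top \tag{2.86}$$


and hence we can also write the posterior covariance as
$$\hat{\boldsymbol{\Sigma}} = \breve{\boldsymbol{\Sigma}} - \boldsymbol{K}\boldsymbol{S}\boldsymbol{K}^\top \tag{2.87}$$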


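Using the matrix inversion lemma from Equation (2.54), we can also rewrite the posterior in the
following form [Bis06, p93], which takes $O(L^3)$ time instead of $O(D^3)$ time:


\begin{align}
\hat{\boldsymbol{\Sigma}} &= (\breve{\boldsymbol{\Sigma}}^{-1} + \boldsymbol{W}^\top\boldsymbol{\Omega}^{-1}\boldsymbol{W})^{-1} \tag{2.88} \\
\hat{\boldsymbol{\mu}} &= \hat{\boldsymbol{\Sigma}}\left[\boldsymbol{W}^\top\boldsymbol{\Omega}^{-1}(\boldsymbol{y} - \boldsymbol{b}) + \breve{\boldsymbol{\Sigma}}^{-1}\breve{\boldsymbol{\mu}}\right] \tag{2.89}
\end{align}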


Finally, note that the corresponding normalization constant for the posterior is just the marginal 
on y evaluated at the observed value: 


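\begin{align}
p(\boldsymbol{y}) &= \int \mathcal{N}(\boldsymbol{z}|\breve{\boldsymbol{\mu}}, \breve{\boldsymbol{\Sigma}})\,\mathcal{N}(\boldsymbol{y}|\boldsymbol{W}\boldsymbol{z} + \boldsymbol{b}, \boldsymbol{\Omega})\,d\boldsymbol{z} \nonumber \\
&= \mathcal{N}(\boldsymbol{y}|\boldsymbol{W}\breve{\boldsymbol{\mu}} + \boldsymbol{b}, \boldsymbol{\Omega} + \boldsymbol{W}\breve{\boldsymbol{\Sigma}}\boldsymbol{W}^\top) = \mathcal{N}(\boldsymbol{y}|\boldsymbol{m}, \boldsymbol{S}) \tag{2.90}
\end{align}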


From this, we can easily compute the log marginal likelihood. We summarize all these equations in 
Algorithm 7. 
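The equations above can be collected into a short routine. The sketch below is our own NumPy version, not the book's Algorithm 7. It computes the Kalman-gain form (2.84) and (2.85) and the log marginal likelihood from (2.90). It then checks the posterior against the information form (2.88) and (2.89). The function name `linear_gaussian_posterior` and the random test problem are hypothetical.

```python
import numpy as np

def linear_gaussian_posterior(mu_z, Sigma_z, W, b, Omega, y):
    """Posterior p(z|y) and log p(y) for p(z) = N(mu_z, Sigma_z), p(y|z) = N(W z + b, Omega)."""
    m = W @ mu_z + b                        # predicted mean of y
    C = Sigma_z @ W.T                       # Cov[z, y], an L x D matrix
    S = W @ Sigma_z @ W.T + Omega           # Cov[y], a D x D matrix
    K = np.linalg.solve(S, C.T).T           # Kalman gain K = C S^{-1} (S is symmetric)
    mu_post = mu_z + K @ (y - m)            # Eq. (2.84)
    Sigma_post = Sigma_z - K @ C.T          # Eq. (2.85)
    r = y - m                               # log N(y | m, S), Eq. (2.90)
    _, logdet_S = np.linalg.slogdet(S)
    log_lik = -0.5 * (r @ np.linalg.solve(S, r) + logdet_S + len(y) * np.log(2 * np.pi))
    return mu_post, Sigma_post, log_lik

rng = np.random.default_rng(2)
L, D = 2, 3                                 # latent and observed dimensions
mu_z = rng.normal(size=L)
Sigma_z = 2.0 * np.eye(L)
W = rng.normal(size=(D, L))
b = rng.normal(size=D)
Omega = 0.5 * np.eye(D)
y = rng.normal(size=D)

mu_post, Sigma_post, log_lik = linear_gaussian_posterior(mu_z, Sigma_z, W, b, Omega, y)

# Information-form alternative, Eqs. (2.88)-(2.89)
Omega_inv = np.linalg.inv(Omega)
Sigma_info = np.linalg.inv(np.linalg.inv(Sigma_z) + W.T @ Omega_inv @ W)
mu_info = Sigma_info @ (W.T @ Omega_inv @ (y - b) + np.linalg.solve(Sigma_z, mu_z))
print(np.allclose(mu_post, mu_info), np.allclose(Sigma_post, Sigma_info), log_lik)
```

In practice one would prefer Cholesky-based solves for $\boldsymbol{S}$. `np.linalg.solve` is used here only to keep the sketch short.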


2.2.6.3 Example: Sensor fusion with known measurement noise 


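Suppose we have an unknown quantity of interest, $\boldsymbol{z} \sim \mathcal{N}(\boldsymbol{\mu}_z, \boldsymbol{\Sigma}_z)$, from which we get two noisy
measurements, $\boldsymbol{x} \sim \mathcal{N}(\boldsymbol{z}, \boldsymbol{\Sigma}_x)$ and $\boldsymbol{y} \sim \mathcal{N}(\boldsymbol{z}, \boldsymbol{\Sigma}_y)$. Pictorially, we can represent this example as
$\boldsymbol{x} \leftarrow \boldsymbol{z} \rightarrow \boldsymbol{y}$. This is an example of a linear Gaussian system. Our goal is to combine the evidence
together, to compute $p(\boldsymbol{z}|\boldsymbol{x}, \boldsymbol{y}; \boldsymbol{\theta})$. This is known as sensor fusion. (In this section, we assume
$\boldsymbol{\theta} = (\boldsymbol{\Sigma}_x, \boldsymbol{\Sigma}_y)$ is known. See Supplementary Section 2.1.2 for the general case.)

We can combine $\boldsymbol{x}$ and $\boldsymbol{y}$ into a single vector $\boldsymbol{v}$, so the model can be represented as $\boldsymbol{z} \rightarrow \boldsymbol{v}$, where
$p(\boldsymbol{v}|\boldsymbol{z}) = \mathcal{N}(\boldsymbol{v}|\boldsymbol{W}\boldsymbol{z}, \boldsymbol{\Sigma}_v)$, where $\boldsymbol{W} = [\boldsymbol{I}; \boldsymbol{I}]$ and $\boldsymbol{\Sigma}_v = [\boldsymbol{\Sigma}_x, \boldsymbol{0}; \boldsymbol{0}, \boldsymbol{\Sigma}_y]$ are block-structured matrices.
We can then apply Bayes’ rule for Gaussians (Section 2.2.6.2) to compute $p(\boldsymbol{z}|\boldsymbol{v})$.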


6. The name comes from the Kalman filter algorithm, which we discuss in Section 8.3.2. 
Figure 2.8: We observe $\boldsymbol{x} = (0, -1)$ (red cross) and $\boldsymbol{y} = (1, 0)$ (green cross) and estimate $\mathbb{E}[\boldsymbol{z}|\boldsymbol{x}, \boldsymbol{y}, \boldsymbol{\theta}]$ (black
cross). (a) Equally reliable sensors, so the posterior mean estimate is in between the two circles. (b) Sensor
2 is more reliable, so the estimate shifts more towards the green circle. (c) Sensor 1 is more reliable in
the vertical direction, Sensor 2 is more reliable in the horizontal direction. The estimate is an appropriate
combination of the two measurements. Generated by sensor_fusion_2d.ipynb.


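Figure 2.8(a) gives a 2d example, where we set $\boldsymbol{\Sigma}_x = \boldsymbol{\Sigma}_y = 0.01\boldsymbol{I}_2$, so both sensors are equally
reliable. In this case, the posterior mean is halfway between the two observations, $\boldsymbol{x}$ and $\boldsymbol{y}$. In
Figure 2.8(b), we set $\boldsymbol{\Sigma}_x = 0.05\boldsymbol{I}_2$ and $\boldsymbol{\Sigma}_y = 0.01\boldsymbol{I}_2$, so sensor 2 is more reliable than sensor 1. In
this case, the posterior mean is closer to $\boldsymbol{y}$. In Figure 2.8(c), we set


$$\boldsymbol{\Sigma}_x = 0.01\begin{pmatrix}10 & 1 \\ 1 & 1\end{pmatrix}, \quad \boldsymbol{\Sigma}_y = 0.01\begin{pmatrix}1 & 1 \\ 1 & 10\end{pmatrix} \tag{2.91}$$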


so sensor 1 is more reliable in the second component (vertical direction), and sensor 2 is more reliable 
in the first component (horizontal direction). In this case, the posterior mean uses x’s vertical 
component and y’s horizontal component. 
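To reproduce the qualitative behaviour of Figure 2.8, one can stack the two measurements as described above and apply the information-form update (2.88) and (2.89) with $\boldsymbol{b} = \boldsymbol{0}$. This is a hedged NumPy sketch, not the sensor_fusion_2d.ipynb notebook. The helper `fuse` is hypothetical, and the nearly flat prior $\mathcal{N}(\boldsymbol{0}, 10^{10}\boldsymbol{I})$ is our own assumption, chosen so that the prior has negligible influence.

```python
import numpy as np

def fuse(x, y, Sigma_x, Sigma_y, prior_var=1e10):
    """Posterior mean/cov of z given x ~ N(z, Sigma_x), y ~ N(z, Sigma_y), vague prior N(0, prior_var I)."""
    d = len(x)
    W = np.vstack([np.eye(d), np.eye(d)])              # v = [x; y] = W z + noise
    Sigma_v = np.block([[Sigma_x, np.zeros((d, d))],
                        [np.zeros((d, d)), Sigma_y]])
    v = np.concatenate([x, y])
    prior_prec = np.eye(d) / prior_var
    V_inv = np.linalg.inv(Sigma_v)
    Sigma_post = np.linalg.inv(prior_prec + W.T @ V_inv @ W)   # Eq. (2.88)
    mu_post = Sigma_post @ (W.T @ V_inv @ v)                   # Eq. (2.89) with b = 0, prior mean 0
    return mu_post, Sigma_post

x_obs = np.array([0.0, -1.0])
y_obs = np.array([1.0, 0.0])
I2 = np.eye(2)
mu_a, _ = fuse(x_obs, y_obs, 0.01 * I2, 0.01 * I2)     # setting of Figure 2.8(a)
mu_b, _ = fuse(x_obs, y_obs, 0.05 * I2, 0.01 * I2)     # setting of Figure 2.8(b)
Sx = 0.01 * np.array([[10.0, 1.0], [1.0, 1.0]])        # Eq. (2.91)
Sy = 0.01 * np.array([[1.0, 1.0], [1.0, 10.0]])
mu_c, _ = fuse(x_obs, y_obs, Sx, Sy)                   # setting of Figure 2.8(c)
print(mu_a, mu_b, mu_c)
```

With these settings the three posterior means are approximately $(0.5, -0.5)$, $(0.83, -0.17)$ and $(0.85, -0.85)$. The last is exactly $\pm 11/13$, pulled towards $\boldsymbol{y}$'s horizontal coordinate and $\boldsymbol{x}$'s vertical coordinate, as described for panel (c).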


2.2.7 A general calculus for linear Gaussian systems 


In this section, we discuss a general method for performing inference in linear Gaussian systems. 
The key is to define joint distributions over the relevant variables in terms of a potential function, 
represented in information form. We can then easily derive rules for marginalizing potentials, 
multiplying and dividing potentials, and conditioning them on observations. Once we have defined 
these operations, we can use them inside of the belief propagation algorithm (Section 9.2) or junction 
tree algorithm (Supplementary Section 9.2) to compute quantities of interest. We give the details on 
how to perform these operations below; our presentation is based on [Lau92; Mur02]. 


2.2.7.1 Moment and canonical parameterization 


We can represent a Gaussian distribution in moment form or in canonical (information) form. In 
moment form we have 


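$$\phi(\boldsymbol{x}; p, \boldsymbol{\mu}, \boldsymbol{\Sigma}) = p \times \exp\left(-\frac{1}{2}(\boldsymbol{x} - \boldsymbol{\mu})^\top\boldsymbol{\Sigma}^{-1}(\boldsymbol{x} - \boldsymbol{\mu})\right) \tag{2.92}$$


where $p = (2\pi)^{-n/2}|\boldsymbol{\Sigma}|^{-1/2}$ is the normalizing constant that ensures $\int_{\boldsymbol{x}} \phi(\boldsymbol{x}; p, \boldsymbol{\mu}, \boldsymbol{\Sigma})\,d\boldsymbol{x} = 1$. ($n$ is the
dimensionality of $\boldsymbol{x}$.) Expanding out the quadratic form and collecting terms we get the canonical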


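form:
$$\phi(\boldsymbol{x}; g, \boldsymbol{h}, \boldsymbol{K}) = \exp\left(g + \boldsymbol{x}^\top\boldsymbol{h} - \frac{1}{2}\boldsymbol{x}^\top\boldsymbol{K}\boldsymbol{x}\right) = \exp\left(g + \sum_i h_i x_i - \frac{1}{2}\sum_i\sum_j K_{ij}x_i x_j\right) \tag{2.93}$$
where
\begin{align}
\boldsymbol{K} &= \boldsymbol{\Sigma}^{-1} \tag{2.94} \\
\boldsymbol{h} &= \boldsymbol{\Sigma}^{-1}\boldsymbol{\mu} \tag{2.95} \\
g &= \log p - \frac{1}{2}\boldsymbol{\mu}^\top\boldsymbol{K}\boldsymbol{\mu} \tag{2.96}
\end{align}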


K is often called the precision matrix. 
Note that potentials need not be probability distributions, and need not be normalizable (integrate 
to 1). We keep track of the constant terms (p or g) so we can compute the likelihood of the evidence. 


2.2.7.2 Multiplication and division 


We can define multiplication and division in the Gaussian case by using canonical forms, as follows. 
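To multiply $\phi_1(x_1, \ldots, x_k; g_1, \boldsymbol{h}_1, \boldsymbol{K}_1)$ by $\phi_2(x_{k+1}, \ldots, x_n; g_2, \boldsymbol{h}_2, \boldsymbol{K}_2)$, we extend them both to the
same domain $x_1, \ldots, x_n$ by adding zeros to the appropriate dimensions, and then computing


$$(g_1, \boldsymbol{h}_1, \boldsymbol{K}_1) * (g_2, \boldsymbol{h}_2, \boldsymbol{K}_2) = (g_1 + g_2, \boldsymbol{h}_1 + \boldsymbol{h}_2, \boldsymbol{K}_1 + \boldsymbol{K}_2) \tag{2.97}$$
Division is defined as follows:
$$(g_1, \boldsymbol{h}_1, \boldsymbol{K}_1) / (g_2, \boldsymbol{h}_2, \boldsymbol{K}_2) = (g_1 - g_2, \boldsymbol{h}_1 - \boldsymbol{h}_2, \boldsymbol{K}_1 - \boldsymbol{K}_2) \tag{2.98}$$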
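A potential in canonical form can be stored as a tuple $(g, \boldsymbol{h}, \boldsymbol{K})$ plus an ordered list of the variables it covers. The following NumPy sketch implements the zero-padding extension and Equations (2.97) and (2.98). It assumes every variable is a scalar identified by a string name, which is a hypothetical convention of ours and not taken from [Lau92] or [Mur02].

```python
import numpy as np

def extend(pot, domain):
    """Zero-pad a canonical potential (g, h, K, vars) onto an ordered superset of its variables."""
    g, h, K, vars_ = pot
    idx = [domain.index(v) for v in vars_]
    h_big = np.zeros(len(domain))
    K_big = np.zeros((len(domain), len(domain)))
    h_big[idx] = h
    K_big[np.ix_(idx, idx)] = K
    return g, h_big, K_big, tuple(domain)

def combine(p1, p2, sign=1.0):
    """Multiply (sign=+1, Eq. 2.97) or divide (sign=-1, Eq. 2.98) two canonical potentials."""
    domain = list(p1[3]) + [v for v in p2[3] if v not in p1[3]]
    g1, h1, K1, _ = extend(p1, domain)
    g2, h2, K2, _ = extend(p2, domain)
    return g1 + sign * g2, h1 + sign * h2, K1 + sign * K2, tuple(domain)

# Two hypothetical potentials over scalar variables (a, b) and (b, c)
phi1 = (0.5, np.array([1.0, 2.0]), np.array([[2.0, 0.3], [0.3, 1.0]]), ("a", "b"))
phi2 = (-0.2, np.array([0.5, -1.0]), np.array([[1.5, 0.2], [0.2, 3.0]]), ("b", "c"))

prod = combine(phi1, phi2)                  # potential over (a, b, c)
back = combine(prod, phi2, sign=-1.0)       # dividing by phi2 should undo the product
g_e, h_e, K_e, _ = extend(phi1, list(prod[3]))
print(prod[3], np.isclose(back[0], g_e), np.allclose(back[1], h_e), np.allclose(back[2], K_e))
```

Dividing the product by $\phi_2$ returns $\phi_1$ padded with zeros for `c`, which is what the last line checks.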


2.2.7.3 Marginalization 


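Let $\phi_W$ be a potential over a set $W$ of variables. We can compute the potential over a subset $V \subset W$
of variables by marginalizing, denoted $\phi_V = \int_{W \setminus V} \phi_W$. Let


$$\boldsymbol{x} = \begin{pmatrix}\boldsymbol{x}_1 \\ \boldsymbol{x}_2\end{pmatrix}, \quad \boldsymbol{h} = \begin{pmatrix}\boldsymbol{h}_1 \\ \boldsymbol{h}_2\end{pmatrix}, \quad \boldsymbol{K} = \begin{pmatrix}\boldsymbol{K}_{11} & \boldsymbol{K}_{12} \\ \boldsymbol{K}_{21} & \boldsymbol{K}_{22}\end{pmatrix} \tag{2.99}$$


with $\boldsymbol{x}_1$ having dimension $n_1$ and $\boldsymbol{x}_2$ having dimension $n_2$. It can be shown that


$$\int_{\boldsymbol{x}_1} \phi(\boldsymbol{x}_1, \boldsymbol{x}_2; g, \boldsymbol{h}, \boldsymbol{K}) = \phi(\boldsymbol{x}_2; \hat{g}, \hat{\boldsymbol{h}}, \hat{\boldsymbol{K}}) \tag{2.100}$$


where
\begin{align}
\hat{g} &= g + \frac{1}{2}\left(n_1\log(2\pi) - \log|\boldsymbol{K}_{11}| + \boldsymbol{h}_1^\top\boldsymbol{K}_{11}^{-1}\boldsymbol{h}_1\right) \tag{2.101} \\
\hat{\boldsymbol{h}} &= \boldsymbol{h}_2 - \boldsymbol{K}_{21}\boldsymbol{K}_{11}^{-1}\boldsymbol{h}_1 \tag{2.102} \\
\hat{\boldsymbol{K}} &= \boldsymbol{K}_{22} - \boldsymbol{K}_{21}\boldsymbol{K}_{11}^{-1}\boldsymbol{K}_{12} \tag{2.103}
\end{align}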
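Equations (2.101) to (2.103) can be implemented in a few lines. In this NumPy sketch the variables to be integrated out are assumed to come first. As a check, we start from a normalized 3-dimensional Gaussian, converted with Equations (2.94) to (2.96), and compare with the canonical parameters of its exact marginal. That marginal is again a normalized Gaussian, so even $\hat{g}$ should match. The helper names are our own.

```python
import numpy as np

def marginalize_first(g, h, K, n1):
    """Integrate out the first n1 coordinates of exp(g + x'h - x'Kx/2), Eqs. (2.101)-(2.103)."""
    h1, h2 = h[:n1], h[n1:]
    K11, K12 = K[:n1, :n1], K[:n1, n1:]
    K21, K22 = K[n1:, :n1], K[n1:, n1:]
    K11_inv_h1 = np.linalg.solve(K11, h1)
    _, logdet_K11 = np.linalg.slogdet(K11)
    g_hat = g + 0.5 * (n1 * np.log(2 * np.pi) - logdet_K11 + h1 @ K11_inv_h1)
    h_hat = h2 - K21 @ K11_inv_h1
    K_hat = K22 - K21 @ np.linalg.solve(K11, K12)
    return g_hat, h_hat, K_hat

def canonical_from_moment(mu, Sigma):
    """Canonical parameters of a normalized Gaussian, Eqs. (2.94)-(2.96)."""
    K = np.linalg.inv(Sigma)
    h = K @ mu
    _, logdet = np.linalg.slogdet(Sigma)
    log_p = -0.5 * (len(mu) * np.log(2 * np.pi) + logdet)
    g = log_p - 0.5 * mu @ K @ mu
    return g, h, K

mu = np.array([1.0, -2.0, 0.5])
Sigma = np.array([[2.0, 0.6, 0.2], [0.6, 1.0, 0.3], [0.2, 0.3, 1.5]])
g_hat, h_hat, K_hat = marginalize_first(*canonical_from_moment(mu, Sigma), n1=1)
g_ref, h_ref, K_ref = canonical_from_moment(mu[1:], Sigma[1:, 1:])
print(np.isclose(g_hat, g_ref), np.allclose(h_hat, h_ref), np.allclose(K_hat, K_ref))
```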
2.2.7.4 Conditioning on evidence


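Consider a potential defined on $(\boldsymbol{x}, \boldsymbol{y})$. Suppose we observe the value $\boldsymbol{y}$. The new potential is given
by the following reduced dimensionality object:


\begin{align}
\phi^*(\boldsymbol{x}) &= \exp\left[g + \begin{pmatrix}\boldsymbol{x}^\top & \boldsymbol{y}^\top\end{pmatrix}\begin{pmatrix}\boldsymbol{h}_X \\ \boldsymbol{h}_Y\end{pmatrix} - \frac{1}{2}\begin{pmatrix}\boldsymbol{x}^\top & \boldsymbol{y}^\top\end{pmatrix}\begin{pmatrix}\boldsymbol{K}_{XX} & \boldsymbol{K}_{XY} \\ \boldsymbol{K}_{YX} & \boldsymbol{K}_{YY}\end{pmatrix}\begin{pmatrix}\boldsymbol{x} \\ \boldsymbol{y}\end{pmatrix}\right] \tag{2.104} \\
&= \exp\left[\left(g + \boldsymbol{h}_Y^\top\boldsymbol{y} - \frac{1}{2}\boldsymbol{y}^\top\boldsymbol{K}_{YY}\boldsymbol{y}\right) + \boldsymbol{x}^\top(\boldsymbol{h}_X - \boldsymbol{K}_{XY}\boldsymbol{y}) - \frac{1}{2}\boldsymbol{x}^\top\boldsymbol{K}_{XX}\boldsymbol{x}\right] \tag{2.105}
\end{align}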
This generalizes the corresponding equation in [Lau92] to the vector-valued case. 
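Conditioning on evidence, Equation (2.105), only slices and shifts the parameters. The following NumPy sketch assumes the observed block $\boldsymbol{y}$ comes last in the variable ordering. It checks that the reduced potential, evaluated at $\boldsymbol{x}$, agrees with the original potential evaluated at $(\boldsymbol{x}, \boldsymbol{y})$. The function names and the random test potential are hypothetical.

```python
import numpy as np

def condition(g, h, K, nx, y):
    """Clamp the trailing block y of a canonical potential over (x, y), Eq. (2.105)."""
    hx, hy = h[:nx], h[nx:]
    Kxx, Kxy, Kyy = K[:nx, :nx], K[:nx, nx:], K[nx:, nx:]
    g_new = g + hy @ y - 0.5 * y @ Kyy @ y
    h_new = hx - Kxy @ y
    return g_new, h_new, Kxx

def log_potential(g, h, K, z):
    """log phi(z; g, h, K) = g + z'h - z'Kz/2, Eq. (2.93)."""
    return g + z @ h - 0.5 * z @ K @ z

rng = np.random.default_rng(3)
R = rng.normal(size=(3, 3))
K = R @ R.T + np.eye(3)                 # arbitrary SPD matrix over (x1, x2, y)
h = rng.normal(size=3)
g = 0.7
y = np.array([0.4])
x = np.array([-1.0, 2.0])

g_c, h_c, K_c = condition(g, h, K, nx=2, y=y)
print(np.isclose(log_potential(g_c, h_c, K_c, x),
                 log_potential(g, h, K, np.concatenate([x, y]))))
```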


2.2.7.5 Converting a linear-Gaussian CPD to a canonical potential 


Finally we discuss how to create the initial potentials, assuming we start with a directed Gaussian 
graphical model. In particular, consider a node with a linear-Gaussian CPD: 


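\begin{align}
p(\boldsymbol{x}|\boldsymbol{u}) &= c\exp\left[-\frac{1}{2}(\boldsymbol{x} - \boldsymbol{\mu} - \boldsymbol{B}^\top\boldsymbol{u})^\top\boldsymbol{\Sigma}^{-1}(\boldsymbol{x} - \boldsymbol{\mu} - \boldsymbol{B}^\top\boldsymbol{u})\right] \tag{2.106} \\
&= \exp\left[-\frac{1}{2}\begin{pmatrix}\boldsymbol{x}^\top & \boldsymbol{u}^\top\end{pmatrix}\begin{pmatrix}\boldsymbol{\Sigma}^{-1} & -\boldsymbol{\Sigma}^{-1}\boldsymbol{B}^\top \\ -\boldsymbol{B}\boldsymbol{\Sigma}^{-1} & \boldsymbol{B}\boldsymbol{\Sigma}^{-1}\boldsymbol{B}^\top\end{pmatrix}\begin{pmatrix}\boldsymbol{x} \\ \boldsymbol{u}\end{pmatrix}\right. \tag{2.107} \\
&\qquad \left. + \begin{pmatrix}\boldsymbol{x}^\top & \boldsymbol{u}^\top\end{pmatrix}\begin{pmatrix}\boldsymbol{\Sigma}^{-1}\boldsymbol{\mu} \\ -\boldsymbol{B}\boldsymbol{\Sigma}^{-1}\boldsymbol{\mu}\end{pmatrix} - \frac{1}{2}\boldsymbol{\mu}^\top\boldsymbol{\Sigma}^{-1}\boldsymbol{\mu} + \log c\right] \tag{2.108}
\end{align}


where $c = (2\pi)^{-n/2}|\boldsymbol{\Sigma}|^{-1/2}$. Hence we set the canonical parameters to


\begin{align}
g &= -\frac{1}{2}\boldsymbol{\mu}^\top\boldsymbol{\Sigma}^{-1}\boldsymbol{\mu} - \frac{n}{2}\log(2\pi) - \frac{1}{2}\log|\boldsymbol{\Sigma}| \tag{2.109} \\
\boldsymbol{h} &= \begin{pmatrix}\boldsymbol{\Sigma}^{-1}\boldsymbol{\mu} \\ -\boldsymbol{B}\boldsymbol{\Sigma}^{-1}\boldsymbol{\mu}\end{pmatrix} \tag{2.110} \\
\boldsymbol{K} &= \begin{pmatrix}\boldsymbol{\Sigma}^{-1} & -\boldsymbol{\Sigma}^{-1}\boldsymbol{B}^\top \\ -\boldsymbol{B}\boldsymbol{\Sigma}^{-1} & \boldsymbol{B}\boldsymbol{\Sigma}^{-1}\boldsymbol{B}^\top\end{pmatrix} = \begin{pmatrix}\boldsymbol{I} \\ -\boldsymbol{B}\end{pmatrix}\boldsymbol{\Sigma}^{-1}\begin{pmatrix}\boldsymbol{I} & -\boldsymbol{B}^\top\end{pmatrix} \tag{2.111}
\end{align}


In the special case that $x$ is a scalar, the corresponding result can be found in [Lau92]. In particular
we have $\boldsymbol{\Sigma}^{-1} = 1/\sigma^2$, $\boldsymbol{B} = \boldsymbol{b}$ and $n = 1$, so the above becomes


\begin{align}
g &= \frac{-\mu^2}{2\sigma^2} - \frac{1}{2}\log(2\pi\sigma^2) \tag{2.112} \\
\boldsymbol{h} &= \frac{\mu}{\sigma^2}\begin{pmatrix}1 \\ -\boldsymbol{b}\end{pmatrix} \tag{2.113} \\
\boldsymbol{K} &= \frac{1}{\sigma^2}\begin{pmatrix}1 & -\boldsymbol{b}^\top \\ -\boldsymbol{b} & \boldsymbol{b}\boldsymbol{b}^\top\end{pmatrix} \tag{2.114}
\end{align}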
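The conversion in Equations (2.109) to (2.111) can be verified in two steps. First, evaluate the resulting quadratic form at a random $(\boldsymbol{x}, \boldsymbol{u})$. Then compare it with the log-density of $\mathcal{N}(\boldsymbol{x}|\boldsymbol{\mu} + \boldsymbol{B}^\top\boldsymbol{u}, \boldsymbol{\Sigma})$ computed directly. This NumPy sketch assumes $\boldsymbol{B}$ has shape $m \times n$, so that $\boldsymbol{B}^\top\boldsymbol{u}$ has the dimension $n$ of $\boldsymbol{x}$. The function name `cpd_to_canonical` is our own.

```python
import numpy as np

def cpd_to_canonical(mu, B, Sigma):
    """Canonical (g, h, K) over (x, u) for p(x|u) = N(x | mu + B^T u, Sigma), Eqs. (2.109)-(2.111)."""
    n = len(mu)
    Sinv = np.linalg.inv(Sigma)
    _, logdet = np.linalg.slogdet(Sigma)
    g = -0.5 * mu @ Sinv @ mu - 0.5 * n * np.log(2 * np.pi) - 0.5 * logdet
    h = np.concatenate([Sinv @ mu, -B @ Sinv @ mu])
    IB = np.vstack([np.eye(n), -B])       # stacks I (n x n) on top of -B (m x n)
    K = IB @ Sinv @ IB.T                  # (I; -B) Sigma^{-1} (I, -B^T)
    return g, h, K

rng = np.random.default_rng(4)
n, m = 2, 3
mu = rng.normal(size=n)
B = rng.normal(size=(m, n))
Sigma = np.array([[1.0, 0.3], [0.3, 0.5]])
x, u = rng.normal(size=n), rng.normal(size=m)

g, h, K = cpd_to_canonical(mu, B, Sigma)
z = np.concatenate([x, u])
log_pot = g + z @ h - 0.5 * z @ K @ z

r = x - mu - B.T @ u                      # direct evaluation of log N(x | mu + B^T u, Sigma)
_, logdet = np.linalg.slogdet(Sigma)
log_direct = -0.5 * (r @ np.linalg.solve(Sigma, r) + logdet + n * np.log(2 * np.pi))
print(np.isclose(log_pot, log_direct))
```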


2.2.7.6 Example: Product of Gaussians 


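As an application of the above results, we can derive the (unnormalized) product of two Gaussians,
as follows (see also [Kaa12, Sec 8.1.8]):


$$\mathcal{N}(\boldsymbol{x}|\boldsymbol{\mu}_1, \boldsymbol{\Sigma}_1) \times \mathcal{N}(\boldsymbol{x}|\boldsymbol{\mu}_2, \boldsymbol{\Sigma}_2) \propto \mathcal{N}(\boldsymbol{x}|\boldsymbol{\mu}_3, \boldsymbol{\Sigma}_3) \tag{2.115}$$
where

\begin{align}
\boldsymbol{\Sigma}_3 &= (\boldsymbol{\Sigma}_1^{-1} + \boldsymbol{\Sigma}_2^{-1})^{-1} \tag{2.116} \\
\boldsymbol{\mu}_3 &= \boldsymbol{\Sigma}_3(\boldsymbol{\Sigma}_1^{-1}\boldsymbol{\mu}_1 + \boldsymbol{\Sigma}_2^{-1}\boldsymbol{\mu}_2) \tag{2.117}
\end{align}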


We see that the posterior precision is a sum of the individual precisions, and the posterior mean 
is a precision-weighted combination of the individual means. We can also rewrite the result in the 
following way, which only requires one matrix inversion: 
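\begin{align}
\boldsymbol{\Sigma}_3 &= \boldsymbol{\Sigma}_1(\boldsymbol{\Sigma}_1 + \boldsymbol{\Sigma}_2)^{-1}\boldsymbol{\Sigma}_2 \tag{2.118} \\
\boldsymbol{\mu}_3 &= \boldsymbol{\Sigma}_2(\boldsymbol{\Sigma}_1 + \boldsymbol{\Sigma}_2)^{-1}\boldsymbol{\mu}_1 + \boldsymbol{\Sigma}_1(\boldsymbol{\Sigma}_1 + \boldsymbol{\Sigma}_2)^{-1}\boldsymbol{\mu}_2 \tag{2.119}
\end{align}


In the scalar case, this becomes


$$\mathcal{N}(z|\mu_1, \sigma_1^2)\,\mathcal{N}(z|\mu_2, \sigma_2^2) \propto \mathcal{N}\left(z \,\middle|\, \frac{\mu_1\sigma_2^2 + \mu_2\sigma_1^2}{\sigma_1^2 + \sigma_2^2}, \frac{\sigma_1^2\sigma_2^2}{\sigma_1^2 + \sigma_2^2}\right) \tag{2.120}$$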
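The two forms of the product rule are easy to compare numerically. The NumPy sketch below implements the single-inversion form (2.118) and (2.119) and checks it against the precision form (2.116) and (2.117). The test means and covariances are arbitrary.

```python
import numpy as np

def gauss_product(mu1, S1, mu2, S2):
    """Mean and covariance of N(x|mu1,S1) N(x|mu2,S2) up to normalization, Eqs. (2.118)-(2.119)."""
    T = np.linalg.inv(S1 + S2)              # the single matrix inversion
    S3 = S1 @ T @ S2
    mu3 = S2 @ T @ mu1 + S1 @ T @ mu2
    return mu3, S3

mu1, S1 = np.array([0.0, 1.0]), np.array([[1.0, 0.2], [0.2, 2.0]])
mu2, S2 = np.array([2.0, -1.0]), np.array([[0.5, 0.0], [0.0, 0.5]])
mu3, S3 = gauss_product(mu1, S1, mu2, S2)

# Precision form, Eqs. (2.116)-(2.117)
P1, P2 = np.linalg.inv(S1), np.linalg.inv(S2)
S3_ref = np.linalg.inv(P1 + P2)
mu3_ref = S3_ref @ (P1 @ mu1 + P2 @ mu2)
print(np.allclose(S3, S3_ref), np.allclose(mu3, mu3_ref))

# Scalar case, Eq. (2.120): mu1=1, sigma1^2=4, mu2=3, sigma2^2=1
m_s, v_s = gauss_product(np.array([1.0]), np.array([[4.0]]), np.array([3.0]), np.array([[1.0]]))
print(m_s, v_s)
```

For the scalar example, Equation (2.120) gives mean $(1 \cdot 1 + 3 \cdot 4)/5 = 2.6$ and variance $4 \cdot 1/5 = 0.8$, which is what the last line prints.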


2.2.8 Some other multivariate continuous distributions 


In this section, we summarize some other widely used multivariate continuous distributions. 


2.2.8.1 Dirichlet distribution 


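A multivariate generalization of the beta distribution is the Dirichlet$^7$ distribution, which has
support over the probability simplex, defined by


$$S_K = \left\{\boldsymbol{x} : 0 \le x_k \le 1, \sum_{k=1}^K x_k = 1\right\} \tag{2.121}$$


The pdf is defined as follows:


$$\text{Dir}(\boldsymbol{x}|\boldsymbol{\alpha}) \triangleq \frac{1}{B(\boldsymbol{\alpha})}\prod_{k=1}^K x_k^{\alpha_k - 1}\,\mathbb{I}(\boldsymbol{x} \in S_K) \tag{2.122}$$


where $B(\boldsymbol{\alpha})$ is the multivariate beta function,


$$B(\boldsymbol{\alpha}) \triangleq \frac{\prod_{k=1}^K \Gamma(\alpha_k)}{\Gamma\left(\sum_{k=1}^K \alpha_k\right)} \tag{2.123}$$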
Figure 2.9: (a) The Dirichlet distribution when $K = 3$ defines a distribution over the simplex, which can be
represented by the triangular surface. Points on this surface satisfy $0 \le \theta_k \le 1$ and $\sum_{k=1}^3 \theta_k = 1$. Generated
by dirichlet_3d_triangle_plot.ipynb. (b) Plot of the Dirichlet density for $\boldsymbol{\alpha} = (20, 20, 20)$. (c) Plot of the
Dirichlet density for $\boldsymbol{\alpha} = (3, 3, 20)$. (d) Plot of the Dirichlet density for $\boldsymbol{\alpha} = (0.1, 0.1, 0.1)$. Generated by
dirichlet_3d_spiky_plot.ipynb.


Figure 2.9 shows some plots of the Dirichlet when $K = 3$. We see that $\alpha_0 = \sum_k \alpha_k$ controls the
strength of the distribution (how peaked it is), and the $\alpha_k$ control where the peak occurs. For example,
Dir(1, 1, 1) is a uniform distribution, Dir(2, 2, 2) is a broad distribution centered at (1/3, 1/3, 1/3),
and Dir(20, 20, 20) is a narrow distribution centered at (1/3, 1/3, 1/3). Dir(3, 3, 20) is an asymmetric
distribution that puts more density in one of the corners. If $\alpha_k < 1$ for all $k$, we get “spikes” at
the corners of the simplex. Samples from the distribution when $\alpha_k < 1$ will be sparse, as shown in
Figure 2.10.


7. Johann Dirichlet was a German mathematician, 1805-1859. 
(Figure 2.10 image: panel (a) titled "Samples from Dir (alpha=0.1)" and panel (b) titled "Samples from Dir (alpha=1.0)"; each panel shows bar plots of sampled probability vectors over categories 1 to 5.)

Figure 2.10: Samples from a 5-dimensional symmetric Dirichlet distribution for different parameter values.
(a) $\boldsymbol{\alpha} = (0.1, \ldots, 0.1)$. This results in very sparse distributions, with many 0s. (b) $\boldsymbol{\alpha} = (1, \ldots, 1)$. This
results in more uniform (and dense) distributions. Generated by dirichlet_samples_plot.ipynb.
For future reference, here are some useful properties of the Dirichlet distribution:

$$\mathbb{E}[x_k] = \frac{\alpha_k}{\alpha_0}, \quad \text{mode}[x_k] = \frac{\alpha_k - 1}{\alpha_0 - K}, \quad \mathbb{V}[x_k] = \frac{\alpha_k(\alpha_0 - \alpha_k)}{\alpha_0^2(\alpha_0 + 1)} \tag{2.124}$$

where $\alpha_0 = \sum_k \alpha_k$.
Often we use a symmetric Dirichlet prior of the form $\alpha_k = \alpha/K$. In this case, we have $\mathbb{E}[x_k] = 1/K$,
and $\mathbb{V}[x_k] = \frac{K - 1}{K^2(\alpha + 1)}$. So we see that increasing $\alpha$ increases the precision (decreases the variance) of
the distribution.
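The moment formulas (2.124) can be checked by Monte Carlo using NumPy's `Generator.dirichlet` sampler. The sample size and the choice $\boldsymbol{\alpha} = (3, 3, 20)$, one of the settings in Figure 2.9, are arbitrary. The last lines confirm the symmetric-prior variance $\frac{K-1}{K^2(\alpha+1)}$.

```python
import numpy as np

def dirichlet_moments(alpha):
    """Componentwise mean and variance of Dir(alpha), Eq. (2.124)."""
    alpha = np.asarray(alpha, dtype=float)
    a0 = alpha.sum()
    mean = alpha / a0
    var = alpha * (a0 - alpha) / (a0**2 * (a0 + 1))
    return mean, var

rng = np.random.default_rng(0)
alpha = np.array([3.0, 3.0, 20.0])
samples = rng.dirichlet(alpha, size=200_000)   # each row lies on the simplex
mean, var = dirichlet_moments(alpha)
print(mean, samples.mean(axis=0))
print(var, samples.var(axis=0))

# Symmetric prior alpha_k = alpha / K
K, a = 5, 2.0
_, var_sym = dirichlet_moments(np.full(K, a / K))
print(var_sym, (K - 1) / (K**2 * (a + 1)))
```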
2.2.8.2 Multivariate Student distribution

One problem with Gaussians is that they are sensitive to outliers. Fortunately, we can easily extend
the Student distribution, discussed in Main Section 2.2.2.3, to $D$ dimensions. In particular, the pdf
of the multivariate Student distribution is given by

\begin{align}
\mathcal{T}(\boldsymbol{x}|\boldsymbol{\mu}, \boldsymbol{\Sigma}, \nu) &= \frac{1}{Z}\left[1 + \frac{1}{\nu}(\boldsymbol{x} - \boldsymbol{\mu})^\top\boldsymbol{\Sigma}^{-1}(\boldsymbol{x} - \boldsymbol{\mu})\right]^{-(\nu + D)/2} \tag{2.125} \\
Z &= \frac{\Gamma(\nu/2)\,\nu^{D/2}\pi^{D/2}}{\Gamma(\nu/2 + D/2)}|\boldsymbol{\Sigma}|^{1/2} \tag{2.126}
\end{align}

where $\boldsymbol{\Sigma}$ is called the scale matrix.
The Student has fatter tails than a Gaussian. The smaller $\nu$ is, the fatter the tails. As $\nu \rightarrow \infty$,
the distribution tends towards a Gaussian. The distribution has these properties:

$$\text{mean} = \boldsymbol{\mu}, \quad \text{mode} = \boldsymbol{\mu}, \quad \text{cov} = \frac{\nu}{\nu - 2}\boldsymbol{\Sigma} \tag{2.127}$$

The mean is only well defined (finite) if $\nu > 1$. Similarly, the covariance is only well defined if $\nu > 2$.
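A direct implementation of Equations (2.125) and (2.126) in log space is sketched below using NumPy and the standard library's `math.lgamma`. It is checked against two known facts. With $D = 1$, $\nu = 1$ and $\boldsymbol{\Sigma} = 1$, the Student reduces to a standard Cauchy, whose density at the origin is $1/\pi$. For very large $\nu$, it approaches the Gaussian. The helper names are our own.

```python
import math
import numpy as np

def mvt_logpdf(x, mu, Sigma, nu):
    """log T(x | mu, Sigma, nu), Eqs. (2.125)-(2.126)."""
    D = len(mu)
    r = x - mu
    maha = r @ np.linalg.solve(Sigma, r)        # (x - mu)^T Sigma^{-1} (x - mu)
    _, logdet = np.linalg.slogdet(Sigma)
    log_Z = (math.lgamma(nu / 2) + 0.5 * D * math.log(nu) + 0.5 * D * math.log(math.pi)
             - math.lgamma(nu / 2 + D / 2) + 0.5 * logdet)
    return -log_Z - 0.5 * (nu + D) * math.log1p(maha / nu)

def mvn_logpdf(x, mu, Sigma):
    """Gaussian log-density, for comparison in the nu -> infinity limit."""
    r = x - mu
    _, logdet = np.linalg.slogdet(Sigma)
    return -0.5 * (r @ np.linalg.solve(Sigma, r) + logdet + len(mu) * math.log(2 * math.pi))

x = np.array([0.5, -1.0])
mu = np.zeros(2)
Sigma = np.array([[1.0, 0.4], [0.4, 2.0]])
for nu in [1.0, 10.0, 1e6]:
    print(nu, mvt_logpdf(x, mu, Sigma, nu), mvn_logpdf(x, mu, Sigma))

cauchy_at_0 = mvt_logpdf(np.array([0.0]), np.zeros(1), np.eye(1), 1.0)   # should be -log(pi)
print(cauchy_at_0, -math.log(math.pi))
```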
2.2.8.3 Circular normal (von Mises Fisher) distribution


Sometimes data lives on the unit sphere, rather than being any point in Euclidean space. For example,
any $D$ dimensional vector that is $\ell_2$-normalized lives on the unit $(D - 1)$ sphere embedded in $\mathbb{R}^D$.
There is an extension of the Gaussian distribution that is suitable for such angular data, known as
the von Mises-Fisher distribution, or the circular normal distribution. It has the following pdf:


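\begin{align}
\text{vMF}(\boldsymbol{x}|\boldsymbol{\mu}, \kappa) &\triangleq \frac{1}{Z}\exp(\kappa\boldsymbol{\mu}^\top\boldsymbol{x}) \tag{2.128} \\
Z &= \frac{(2\pi)^{D/2} I_{D/2-1}(\kappa)}{\kappa^{D/2-1}} \tag{2.129}
\end{align}
where $\boldsymbol{\mu}$ is the mean (with $||\boldsymbol{\mu}|| = 1$), $\kappa \ge 0$ is the concentration or precision parameter (analogous
to $1/\sigma$ for a standard Gaussian), and $Z$ is the normalization constant, with $I_r(\cdot)$ being the modified
Bessel function of the first kind and order $r$. The vMF is like a spherical multivariate Gaussian,
parameterized by cosine distance instead of Euclidean distance.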

The vMF distribution can be used inside of a mixture model to cluster $\ell_2$-normalized vectors, as
an alternative to using a Gaussian mixture model [Ban+05]. If $\kappa \rightarrow \infty$, this reduces to the spherical
K-means algorithm. It can also be used inside of an admixture model (Main Section 28.4.2); this is
called the spherical topic model [Rei+10].

If $D = 2$, an alternative is to use the von Mises distribution on the unit circle, which has the form


\begin{align}
\text{vMF}(x|\mu, \kappa) &= \frac{1}{Z}\exp(\kappa\cos(x - \mu)) \tag{2.130} \\
Z &= 2\pi I_0(\kappa) \tag{2.131}
\end{align}


2.2.8.4 Matrix-normal distribution (MN) 
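The matrix normal distribution is defined by the following probability density function over matrices
$\boldsymbol{X} \in \mathbb{R}^{n \times p}$:
$$\mathcal{MN}(\boldsymbol{X}|\boldsymbol{M}, \boldsymbol{U}, \boldsymbol{V}) \triangleq \frac{|\boldsymbol{V}|^{n/2}}{(2\pi)^{np/2}|\boldsymbol{U}|^{p/2}}\exp\left\{-\frac{1}{2}\text{tr}\left[(\boldsymbol{X} - \boldsymbol{M})^\top\boldsymbol{U}^{-1}(\boldsymbol{X} - \boldsymbol{M})\boldsymbol{V}\right]\right\} \tag{2.132}$$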


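where $\boldsymbol{M} \in \mathbb{R}^{n \times p}$ is the mean value of $\boldsymbol{X}$, $\boldsymbol{U} \in \mathbb{S}_{++}^{n \times n}$ is the covariance among rows, and $\boldsymbol{V} \in \mathbb{S}_{++}^{p \times p}$ is
the precision among columns. It can be seen that


$$\text{vec}(\boldsymbol{X}) \sim \mathcal{N}(\text{vec}(\boldsymbol{M}), \boldsymbol{V}^{-1} \otimes \boldsymbol{U}) \tag{2.133}$$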


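Note that there is another version of the definition of the matrix normal distribution using the
column-covariance matrix $\tilde{\boldsymbol{V}} = \boldsymbol{V}^{-1}$ instead of $\boldsymbol{V}$, which leads to the density


$$\frac{1}{(2\pi)^{np/2}|\boldsymbol{U}|^{p/2}|\tilde{\boldsymbol{V}}|^{n/2}}\exp\left\{-\frac{1}{2}\text{tr}\left[(\boldsymbol{X} - \boldsymbol{M})^\top\boldsymbol{U}^{-1}(\boldsymbol{X} - \boldsymbol{M})\tilde{\boldsymbol{V}}^{-1}\right]\right\} \tag{2.134}$$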


These two versions of the definition are obviously equivalent, but we will see that the definition we
adopt in Equation (2.132) leads to a neat update of the posterior distribution (just as the
precision matrix is more convenient to use than the covariance matrix in analyzing the posterior of
the multivariate normal distribution with a conjugate prior).
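The equivalence (2.133) can be confirmed numerically. Evaluate Equation (2.132) directly, then compare it with a multivariate normal log-density for the column-stacked vector $\text{vec}(\boldsymbol{X})$ with covariance $\boldsymbol{V}^{-1} \otimes \boldsymbol{U}$. In this NumPy sketch, `vec` stacks columns (Fortran order), which is the convention under which the Kronecker factor order in (2.133) holds. The dimensions and random matrices are arbitrary.

```python
import numpy as np

def mn_logpdf(X, M, U, V):
    """log MN(X | M, U, V), U = row covariance, V = column precision, Eq. (2.132)."""
    n, p = X.shape
    R = X - M
    _, logdet_U = np.linalg.slogdet(U)
    _, logdet_V = np.linalg.slogdet(V)
    quad = np.trace(R.T @ np.linalg.solve(U, R) @ V)
    return 0.5 * n * logdet_V - 0.5 * n * p * np.log(2 * np.pi) - 0.5 * p * logdet_U - 0.5 * quad

def mvn_logpdf(x, mu, Sigma):
    r = x - mu
    _, logdet = np.linalg.slogdet(Sigma)
    return -0.5 * (r @ np.linalg.solve(Sigma, r) + logdet + len(x) * np.log(2 * np.pi))

rng = np.random.default_rng(5)
n, p = 3, 2
A, B = rng.normal(size=(n, n)), rng.normal(size=(p, p))
U = A @ A.T + np.eye(n)                      # SPD row covariance
V = B @ B.T + np.eye(p)                      # SPD column precision
M, X = rng.normal(size=(n, p)), rng.normal(size=(n, p))

def vec(Z):
    return Z.reshape(-1, order="F")          # stack columns

lhs = mn_logpdf(X, M, U, V)
rhs = mvn_logpdf(vec(X), vec(M), np.kron(np.linalg.inv(V), U))   # Eq. (2.133)
print(lhs, rhs, np.isclose(lhs, rhs))
```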


2.2.8.5 Wishart distribution 


The Wishart distribution is the generalization of the Gamma distribution to positive definite matrices. 
Press [Pre05, p107] has said “The Wishart distribution ranks next to the Normal distribution in 
order of importance and usefulness in multivariate statistics”. We will mostly use it to model our 
uncertainty when estimating covariance matrices (see Section 3.3). 


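The pdf of the Wishart is defined as follows:
\begin{align}
\text{Wi}(\boldsymbol{\Sigma}|\boldsymbol{S}, \nu) &\triangleq \frac{1}{Z}|\boldsymbol{\Sigma}|^{(\nu - D - 1)/2}\exp\left(-\frac{1}{2}\text{tr}(\boldsymbol{S}^{-1}\boldsymbol{\Sigma})\right) \tag{2.135} \\
Z &\triangleq |\boldsymbol{S}|^{\nu/2}\,2^{\nu D/2}\,\Gamma_D(\nu/2) \tag{2.136}
\end{align}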


Here $\nu$ is called the “degrees of freedom”, $\boldsymbol{S}$ is the “scale matrix”, and $\Gamma_D(\cdot)$ is the multivariate gamma function. (We shall get more intuition for
these parameters shortly.) The normalization constant only exists (and hence the pdf is only well
defined) if $\nu > D - 1$.

The distribution has these properties: 


$$\text{mean} = \nu\boldsymbol{S}, \quad \text{mode} = (\nu - D - 1)\boldsymbol{S} \tag{2.137}$$


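Note that the mode only exists if $\nu > D + 1$.
If $D = 1$, the Wishart reduces to the Gamma distribution:


$$\text{Wi}(\lambda|s^{-1}, \nu) = \text{Ga}\left(\lambda \,\middle|\, \text{shape} = \frac{\nu}{2}, \text{rate} = \frac{s}{2}\right) \tag{2.138}$$
If $s = 1$, this reduces to the chi-squared distribution with $\nu$ degrees of freedom.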
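There is an interesting connection between the Wishart distribution and the Gaussian. In particular,
let $\boldsymbol{x}_n \sim \mathcal{N}(\boldsymbol{0}, \boldsymbol{\Sigma})$. One can show that the scatter matrix, $\boldsymbol{S} = \sum_{n=1}^N \boldsymbol{x}_n\boldsymbol{x}_n^\top$, has a Wishart distribution:
$\boldsymbol{S} \sim \text{Wi}(\boldsymbol{\Sigma}, N)$.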
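The scatter-matrix result can be illustrated by simulation. Averaging many scatter matrices should approach the Wishart mean $\nu\boldsymbol{S}$ from Equation (2.137), which here is $N\boldsymbol{\Sigma}$. This NumPy sketch uses arbitrary values of $D$, $N$, $\boldsymbol{\Sigma}$ and the number of trials.

```python
import numpy as np

rng = np.random.default_rng(6)
D, N, trials = 2, 5, 50_000
Sigma = np.array([[1.0, 0.5], [0.5, 2.0]])
L = np.linalg.cholesky(Sigma)                # Sigma = L L^T

# For each trial draw N vectors x_n ~ N(0, Sigma) and form S = sum_n x_n x_n^T
Z = rng.normal(size=(trials, N, D))
X = Z @ L.T                                  # each row x = L z has covariance Sigma
S = np.einsum("tni,tnj->tij", X, X)          # one D x D scatter matrix per trial

# If S ~ Wi(Sigma, N), its mean is N * Sigma by Eq. (2.137)
print(S.mean(axis=0))
print(N * Sigma)
```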


2.2.8.6 Inverse Wishart distribution 


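If $\lambda \sim \text{Ga}(a, b)$, then $\frac{1}{\lambda} \sim \text{IG}(a, b)$. Similarly, if $\boldsymbol{\Sigma}^{-1} \sim \text{Wi}(\boldsymbol{S}^{-1}, \nu)$ then $\boldsymbol{\Sigma} \sim \text{IW}(\boldsymbol{S}, \nu + D + 1)$,
where IW is the inverse Wishart, the multidimensional generalization of the inverse Gamma. It is
defined as follows, for $\nu > D - 1$ and $\boldsymbol{S} \succ 0$: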


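\begin{align}
\text{IW}(\boldsymbol{\Sigma}|\boldsymbol{S}, \nu) &= \frac{1}{Z_{IW}}|\boldsymbol{\Sigma}|^{-(\nu + D + 1)/2}\exp\left(-\frac{1}{2}\text{tr}(\boldsymbol{S}\boldsymbol{\Sigma}^{-1})\right) \tag{2.139} \\
Z_{IW} &= |\boldsymbol{S}|^{-\nu/2}\,2^{\nu D/2}\,\Gamma_D(\nu/2) \tag{2.140}
\end{align}
One can show that the distribution has these properties:
$$\text{mean} = \frac{\boldsymbol{S}}{\nu - D - 1}, \quad \text{mode} = \frac{\boldsymbol{S}}{\nu + D + 1} \tag{2.141}$$
If $D = 1$, this reduces to the inverse Gamma:
$$\text{IW}(\sigma^2|s^{-1}, \nu) = \text{IG}(\sigma^2|\nu/2, s/2) \tag{2.142}$$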


If s = 1, this reduces to the inverse chi-squared distribution. 
2.3 The exponential family


In this section, we define the exponential family, which includes many common probability 
distributions. The exponential family plays a crucial role in statistics and machine learning, for 
various reasons, including the following: 


• The exponential family is the unique family of distributions that has maximum entropy (and
hence makes the fewest assumptions) subject to some user-chosen constraints, as discussed in
Section 2.3.7.
• The exponential family is at the core of GLMs, as discussed in Section 15.1.
• The exponential family is at the core of variational inference, as discussed in Chapter 10.
• Under certain regularity conditions, the exponential family is the only family of distributions with
finite-sized sufficient statistics, as discussed in Section 2.3.5.
• All members of the exponential family have a conjugate prior [DY79], which simplifies Bayesian
inference of the parameters, as discussed in Section 3.2.


2.3.1 Definition 


Consider a family of probability distributions parameterized by $\boldsymbol{\eta} \in \mathbb{R}^K$ with fixed support over
$\mathcal{X}^D \subseteq \mathbb{R}^D$. We say that the distribution $p(\boldsymbol{x}|\boldsymbol{\eta})$ is in the exponential family if its density can be
written in the following way:


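$$p(\boldsymbol{x}|\boldsymbol{\eta}) \triangleq \frac{1}{Z(\boldsymbol{\eta})}h(\boldsymbol{x})\exp[\boldsymbol{\eta}^\top\mathcal{T}(\boldsymbol{x})] = h(\boldsymbol{x})\exp[\boldsymbol{\eta}^\top\mathcal{T}(\boldsymbol{x}) - A(\boldsymbol{\eta})] \tag{2.143}$$


where $h(\boldsymbol{x})$ is a scaling constant (also known as the base measure, often 1), $\mathcal{T}(\boldsymbol{x}) \in \mathbb{R}^K$ are
the sufficient statistics, $\boldsymbol{\eta}$ are the natural parameters or canonical parameters, $Z(\boldsymbol{\eta})$ is
a normalization constant known as the partition function, and $A(\boldsymbol{\eta}) = \log Z(\boldsymbol{\eta})$ is the log
partition function. In Section 2.3.3, we show that $A$ is a convex function over the convex set
$\Omega \triangleq \{\boldsymbol{\eta} \in \mathbb{R}^K : A(\boldsymbol{\eta}) < \infty\}$.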

It is convenient if the natural parameters are independent of each other. Formally, we say that an
exponential family is minimal if there is no $\boldsymbol{\eta} \in \mathbb{R}^K \setminus \{\boldsymbol{0}\}$ such that $\boldsymbol{\eta}^\top\mathcal{T}(\boldsymbol{x})$ is constant for all $\boldsymbol{x}$. This last condition
can be violated in the case of multinomial distributions, because of the sum to one constraint on
the parameters; however, it is easy to reparameterize the distribution using $K - 1$ independent
parameters, as we show below.

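Equation (2.143) can be generalized by defining $\boldsymbol{\eta} = f(\boldsymbol{\phi})$, where $\boldsymbol{\phi}$ is some other, possibly smaller,
set of parameters. In this case, the distribution has the form


$$p(\boldsymbol{x}|\boldsymbol{\phi}) = h(\boldsymbol{x})\exp[f(\boldsymbol{\phi})^\top\mathcal{T}(\boldsymbol{x}) - A(f(\boldsymbol{\phi}))] \tag{2.144}$$


If the mapping from $\boldsymbol{\phi}$ to $\boldsymbol{\eta}$ is nonlinear, we call this a curved exponential family. If $\boldsymbol{\eta} = f(\boldsymbol{\phi}) = \boldsymbol{\phi}$,
the model is said to be in canonical form. If, in addition, $\mathcal{T}(\boldsymbol{x}) = \boldsymbol{x}$, we say this is a natural
exponential family or NEF. In this case, it can be written as


$$p(\boldsymbol{x}|\boldsymbol{\eta}) = h(\boldsymbol{x})\exp[\boldsymbol{\eta}^\top\boldsymbol{x} - A(\boldsymbol{\eta})] \tag{2.145}$$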
2.3. THE EXPONENTIAL FAMILY


We define the moment parameters as the mean of the sufficient statistics vector:


$$\boldsymbol{m} = \mathbb{E}[\mathcal{T}(\boldsymbol{x})] \tag{2.146}$$


We will see some examples below. 


2.3.2 Examples 


In this section, we consider some common examples of distributions in the exponential family. Each
corresponds to a different way of defining $h(\boldsymbol{x})$ and $\mathcal{T}(\boldsymbol{x})$ (since $Z$ and hence $A$ is derived from
knowing $h$ and $\mathcal{T}$).

2.3.2.1 Bernoulli distribution 


The Bernoulli distribution can be written in exponential family form as follows: 


\begin{align}
\text{Ber}(x|\mu) &= \mu^x(1 - \mu)^{1 - x} \tag{2.147} \\
&= \exp[x\log(\mu) + (1 - x)\log(1 - \mu)] \tag{2.148} \\
&= \exp[\mathcal{T}(x)^\top\boldsymbol{\eta}] \tag{2.149}
\end{align}


where $\mathcal{T}(x) = [\mathbb{I}(x = 1), \mathbb{I}(x = 0)]$, $\boldsymbol{\eta} = [\log(\mu), \log(1 - \mu)]$, and $\mu$ is the mean parameter. However,
this is an over-complete representation since there is a linear dependence between the features.
We can see this as follows:


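$$\boldsymbol{1}^\top\mathcal{T}(x) = \mathbb{I}(x = 0) + \mathbb{I}(x = 1) = 1 \tag{2.150}$$
If the representation is overcomplete, $\boldsymbol{\eta}$ is not uniquely identifiable. It is common to use a minimal
representation, which means there is a unique $\boldsymbol{\eta}$ associated with the distribution. In this case, we
can just define


$$\text{Ber}(x|\mu) = \exp\left[x\log\left(\frac{\mu}{1 - \mu}\right) + \log(1 - \mu)\right] \tag{2.151}$$


We can put this into exponential family form by defining


\begin{align}
\eta &= \log\left(\frac{\mu}{1 - \mu}\right) \tag{2.152} \\
\mathcal{T}(x) &= x \tag{2.153} \\
A(\eta) &= -\log(1 - \mu) = \log(1 + e^\eta) \tag{2.154} \\
h(x) &= 1 \tag{2.155}
\end{align}


We can recover the mean parameter $\mu$ from the canonical parameter $\eta$ using


$$\mu = \sigma(\eta) = \frac{1}{1 + e^{-\eta}} \tag{2.156}$$


which we recognize as the logistic (sigmoid) function.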
2.3.2.2 Categorical distribution


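We can represent the discrete distribution with $K$ categories as follows (where $x_k = \mathbb{I}(x = k)$):


\begin{align}
\text{Cat}(x|\boldsymbol{\mu}) &= \prod_{k=1}^K \mu_k^{x_k} = \exp\left[\sum_{k=1}^K x_k\log\mu_k\right] \tag{2.157} \\
&= \exp\left[\sum_{k=1}^{K-1} x_k\log\mu_k + \left(1 - \sum_{k=1}^{K-1} x_k\right)\log\left(1 - \sum_{k=1}^{K-1}\mu_k\right)\right] \tag{2.158} \\
&= \exp\left[\sum_{k=1}^{K-1} x_k\log\left(\frac{\mu_k}{1 - \sum_{j=1}^{K-1}\mu_j}\right) + \log\left(1 - \sum_{k=1}^{K-1}\mu_k\right)\right] \tag{2.159} \\
&= \exp\left[\sum_{k=1}^{K-1} x_k\log\left(\frac{\mu_k}{\mu_K}\right) + \log\mu_K\right] \tag{2.160}
\end{align}
where $\mu_K = 1 - \sum_{k=1}^{K-1}\mu_k$. We can write this in exponential family form as follows:
\begin{align}
\text{Cat}(x|\boldsymbol{\eta}) &= \exp(\boldsymbol{\eta}^\top\mathcal{T}(x) - A(\boldsymbol{\eta})) \tag{2.161} \\
\boldsymbol{\eta} &= \left[\log\frac{\mu_1}{\mu_K}, \ldots, \log\frac{\mu_{K-1}}{\mu_K}\right] \tag{2.162} \\
A(\boldsymbol{\eta}) &= -\log(\mu_K) \tag{2.163} \\
\mathcal{T}(x) &= [\mathbb{I}(x = 1), \ldots, \mathbb{I}(x = K - 1)] \tag{2.164} \\
h(x) &= 1 \tag{2.165}
\end{align}
We can recover the mean parameters from the canonical parameters using
$$\mu_k = \frac{e^{\eta_k}}{1 + \sum_{j=1}^{K-1} e^{\eta_j}} \tag{2.166}$$
If we define $\eta_K = 0$, we can rewrite this as follows:
$$\mu_k = \frac{e^{\eta_k}}{\sum_{j=1}^K e^{\eta_j}} \tag{2.167}$$


for $k = 1:K$. Hence $\boldsymbol{\mu} = \text{softmax}(\boldsymbol{\eta})$, where softmax is the softmax or multinomial logit function in
Equation (15.131). From this, we find


$$\mu_K = 1 - \frac{\sum_{k=1}^{K-1} e^{\eta_k}}{1 + \sum_{j=1}^{K-1} e^{\eta_j}} = \frac{1}{1 + \sum_{j=1}^{K-1} e^{\eta_j}} \tag{2.168}$$
and hence
$$A(\boldsymbol{\eta}) = -\log(\mu_K) = \log\left(\sum_{k=1}^K e^{\eta_k}\right) \tag{2.169}$$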
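The mapping between mean and natural parameters for the categorical distribution, Equations (2.162), (2.167) and (2.169), can be sketched in NumPy as follows. Subtracting the maximum before exponentiating is a standard numerical-stability trick we have added; it does not change the result. The function names are our own.

```python
import numpy as np

def cat_natural_from_mean(mu):
    """eta_k = log(mu_k / mu_K) for k = 1..K-1, Eq. (2.162)."""
    mu = np.asarray(mu, dtype=float)
    return np.log(mu[:-1] / mu[-1])

def cat_mean_from_natural(eta):
    """Append eta_K = 0 and apply the softmax, Eq. (2.167)."""
    z = np.append(eta, 0.0)
    z = z - z.max()                          # stability shift; cancels in the ratio
    e = np.exp(z)
    return e / e.sum()

def cat_log_partition(eta):
    """A(eta) = log sum_k exp(eta_k) with eta_K = 0, Eq. (2.169)."""
    z = np.append(eta, 0.0)
    c = z.max()
    return c + np.log(np.exp(z - c).sum())

mu = np.array([0.2, 0.5, 0.3])
eta = cat_natural_from_mean(mu)
print(cat_mean_from_natural(eta))            # recovers mu
print(cat_log_partition(eta), -np.log(mu[-1]))   # both equal A(eta), Eqs. (2.163) and (2.169)
```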


2.3.2.3 Univariate Gaussian 


The univariate Gaussian is usually written as follows: 


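\begin{align}
\mathcal{N}(x|\mu, \sigma^2) &= \frac{1}{(2\pi\sigma^2)^{1/2}}\exp\left[-\frac{1}{2\sigma^2}(x - \mu)^2\right] \tag{2.170} \\
&= \frac{1}{(2\pi)^{1/2}}\exp\left[\frac{\mu}{\sigma^2}x - \frac{1}{2\sigma^2}x^2 - \frac{1}{2\sigma^2}\mu^2 - \log\sigma\right] \tag{2.171}
\end{align}


We can put this in exponential family form by defining


\begin{align}
\boldsymbol{\eta} &= \begin{pmatrix}\mu/\sigma^2 \\ -\frac{1}{2\sigma^2}\end{pmatrix} \tag{2.172} \\
\mathcal{T}(x) &= \begin{pmatrix}x \\ x^2\end{pmatrix} \tag{2.173} \\
A(\boldsymbol{\eta}) &= \frac{\mu^2}{2\sigma^2} + \log\sigma = \frac{-\eta_1^2}{4\eta_2} - \frac{1}{2}\log(-2\eta_2) \tag{2.174} \\
h(x) &= \frac{1}{\sqrt{2\pi}} \tag{2.175}
\end{align}


The moment parameters are


$$\boldsymbol{m} = [\mu, \mu^2 + \sigma^2] \tag{2.176}$$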
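The univariate Gaussian conversions (2.172) and (2.174) can be checked by evaluating the exponential-family form $h(x)\exp[\boldsymbol{\eta}^\top\mathcal{T}(x) - A(\boldsymbol{\eta})]$ and comparing it with the usual density. The function names and test values in this NumPy sketch are our own.

```python
import numpy as np

def gauss_natural(mu, sigma2):
    """eta = (mu / sigma^2, -1 / (2 sigma^2)), Eq. (2.172)."""
    return np.array([mu / sigma2, -0.5 / sigma2])

def gauss_moment_from_natural(eta):
    """Invert Eq. (2.172): sigma^2 = -1 / (2 eta_2), mu = eta_1 * sigma^2."""
    sigma2 = -0.5 / eta[1]
    return eta[0] * sigma2, sigma2

def gauss_log_partition(eta):
    """A(eta) = -eta_1^2 / (4 eta_2) - 0.5 log(-2 eta_2), Eq. (2.174)."""
    return -eta[0] ** 2 / (4 * eta[1]) - 0.5 * np.log(-2 * eta[1])

mu, sigma2 = 1.5, 0.25
eta = gauss_natural(mu, sigma2)
x = 0.7
# h(x) exp(eta^T T(x) - A(eta)) with h(x) = 1/sqrt(2 pi) and T(x) = (x, x^2)
p_expfam = np.exp(eta @ np.array([x, x**2]) - gauss_log_partition(eta)) / np.sqrt(2 * np.pi)
p_direct = np.exp(-0.5 * (x - mu) ** 2 / sigma2) / np.sqrt(2 * np.pi * sigma2)
print(gauss_moment_from_natural(eta), p_expfam, p_direct)
```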


2.3.2.4 Univariate Gaussian with fixed variance 


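If we fix $\sigma^2 = 1$, we can write the Gaussian as a natural exponential family, by defining


\begin{align}
\eta &= \mu \tag{2.177} \\
\mathcal{T}(x) &= x \tag{2.178} \\
A(\mu) &= \frac{\mu^2}{2\sigma^2} + \log\sigma = \frac{\mu^2}{2} \tag{2.179} \\
h(x) &= \frac{1}{\sqrt{2\pi}}\exp\left[-\frac{x^2}{2}\right] = \mathcal{N}(x|0, 1) \tag{2.180}
\end{align}


Note that in this example, the base measure $h(x)$ is not constant.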
2.3.2.5 Multivariate Gaussian


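It is common to parameterize the multivariate normal (MVN) in terms of the mean vector $\boldsymbol{\mu}$ and the
covariance matrix $\boldsymbol{\Sigma}$. The corresponding pdf is given by


\begin{align}
\mathcal{N}(\boldsymbol{x}|\boldsymbol{\mu}, \boldsymbol{\Sigma}) &= \frac{1}{(2\pi)^{D/2}\sqrt{\det(\boldsymbol{\Sigma})}}\exp\left(-\frac{1}{2}\boldsymbol{x}^\top\boldsymbol{\Sigma}^{-1}\boldsymbol{x} + \boldsymbol{x}^\top\boldsymbol{\Sigma}^{-1}\boldsymbol{\mu} - \frac{1}{2}\boldsymbol{\mu}^\top\boldsymbol{\Sigma}^{-1}\boldsymbol{\mu}\right) \tag{2.181} \\
&= c\exp\left(\boldsymbol{x}^\top\boldsymbol{\Sigma}^{-1}\boldsymbol{\mu} - \frac{1}{2}\boldsymbol{x}^\top\boldsymbol{\Sigma}^{-1}\boldsymbol{x}\right) \tag{2.182} \\
c &\triangleq \frac{\exp\left(-\frac{1}{2}\boldsymbol{\mu}^\top\boldsymbol{\Sigma}^{-1}\boldsymbol{\mu}\right)}{(2\pi)^{D/2}\sqrt{\det(\boldsymbol{\Sigma})}} \tag{2.183}
\end{align}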


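However, we can also represent the Gaussian using canonical parameters or natural parameters, also called the information form:


\begin{align}
\boldsymbol{\Lambda} &= \boldsymbol{\Sigma}^{-1} \tag{2.184} \\
\boldsymbol{\xi} &= \boldsymbol{\Sigma}^{-1}\boldsymbol{\mu} \tag{2.185} \\
\mathcal{N}_c(\boldsymbol{x}|\boldsymbol{\xi}, \boldsymbol{\Lambda}) &\triangleq c'\exp\left(\boldsymbol{x}^\top\boldsymbol{\xi} - \frac{1}{2}\boldsymbol{x}^\top\boldsymbol{\Lambda}\boldsymbol{x}\right) \tag{2.186} \\
c' &= \frac{\exp\left(-\frac{1}{2}\boldsymbol{\xi}^\top\boldsymbol{\Lambda}^{-1}\boldsymbol{\xi}\right)}{(2\pi)^{D/2}\sqrt{\det(\boldsymbol{\Lambda}^{-1})}} \tag{2.187}
\end{align}


where we use the notation $\mathcal{N}_c()$ to distinguish it from the standard parameterization $\mathcal{N}()$. Here $\boldsymbol{\Lambda}$ is
called the precision matrix and $\boldsymbol{\xi}$ is the precision-weighted mean vector.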
We can convert this to exponential family notation as follows: 


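\begin{align}
\mathcal{N}_c(\boldsymbol{x}|\boldsymbol{\xi}, \boldsymbol{\Lambda}) &= \underbrace{(2\pi)^{-D/2}}_{h(\boldsymbol{x})}\underbrace{\exp\left[\frac{1}{2}\log|\boldsymbol{\Lambda}| - \frac{1}{2}\boldsymbol{\xi}^\top\boldsymbol{\Lambda}^{-1}\boldsymbol{\xi}\right]}_{g(\boldsymbol{\eta})}\exp\left[-\frac{1}{2}\boldsymbol{x}^\top\boldsymbol{\Lambda}\boldsymbol{x} + \boldsymbol{x}^\top\boldsymbol{\xi}\right] \tag{2.188} \\
&= h(\boldsymbol{x})g(\boldsymbol{\eta})\exp\left[-\frac{1}{2}\boldsymbol{x}^\top\boldsymbol{\Lambda}\boldsymbol{x} + \boldsymbol{x}^\top\boldsymbol{\xi}\right] \tag{2.189} \\
&= h(\boldsymbol{x})g(\boldsymbol{\eta})\exp\left[-\frac{1}{2}\left(\sum_{ij} x_i x_j\Lambda_{ij}\right) + \boldsymbol{x}^\top\boldsymbol{\xi}\right] \tag{2.190} \\
&= h(\boldsymbol{x})g(\boldsymbol{\eta})\exp\left[-\frac{1}{2}\text{vec}(\boldsymbol{\Lambda})^\top\text{vec}(\boldsymbol{x}\boldsymbol{x}^\top) + \boldsymbol{x}^\top\boldsymbol{\xi}\right] \tag{2.191} \\
&= h(\boldsymbol{x})\exp\left[\boldsymbol{\eta}^\top\mathcal{T}(\boldsymbol{x}) - A(\boldsymbol{\eta})\right] \tag{2.192}
\end{align}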


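where
\begin{align}
h(\boldsymbol{x}) &= (2\pi)^{-D/2} \tag{2.193} \\
\boldsymbol{\eta} &= \left[\boldsymbol{\xi}; -\frac{1}{2}\text{vec}(\boldsymbol{\Lambda})\right] = \left[\boldsymbol{\Sigma}^{-1}\boldsymbol{\mu}; -\frac{1}{2}\text{vec}(\boldsymbol{\Sigma}^{-1})\right] \tag{2.194} \\
\mathcal{T}(\boldsymbol{x}) &= [\boldsymbol{x}; \text{vec}(\boldsymbol{x}\boldsymbol{x}^\top)] \tag{2.195} \\
A(\boldsymbol{\eta}) &= -\log g(\boldsymbol{\eta}) = -\frac{1}{2}\log|\boldsymbol{\Lambda}| + \frac{1}{2}\boldsymbol{\xi}^\top\boldsymbol{\Lambda}^{-1}\boldsymbol{\xi} \tag{2.196}
\end{align}


From this, we see that the mean (moment) parameters are given by


$$\boldsymbol{m} = \mathbb{E}[\mathcal{T}(\boldsymbol{x})] = [\boldsymbol{\mu}; \boldsymbol{\mu}\boldsymbol{\mu}^\top + \boldsymbol{\Sigma}] \tag{2.197}$$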


(Note that the above is not a minimal representation, since A is a symmetric matrix. We can convert 
to minimal form by working with the upper or lower half of each matrix.) 


2.3.2.6 Non-examples 


Not all distributions of interest belong to the exponential family. For example, the Student distribution
(Section 2.2.2.3) does not belong, since its pdf (Equation (2.13)) does not have the required form.
(However, there is a generalization, known as the $\phi$-exponential family [Nau04; Tsa88], which does
include the Student distribution.)

As a more subtle example, consider the uniform distribution, $Y \sim \text{Unif}(\theta_1, \theta_2)$. The pdf has the
form


$$p(y|\boldsymbol{\theta}) = \frac{1}{\theta_2 - \theta_1}\mathbb{I}(\theta_1 \le y \le \theta_2) \tag{2.198}$$


It is tempting to think this is in the exponential family, with $h(y) = 1$, $\mathcal{T}(y) = 0$, and $Z(\boldsymbol{\theta}) = \theta_2 - \theta_1$.
However, the support of this distribution (i.e., the set of values $\mathcal{Y} = \{y : p(y) > 0\}$) depends on the
parameters $\boldsymbol{\theta}$, which violates an assumption of the exponential family.


2.3.3 Log partition function is cumulant generating function 


The first and second cumulants of a distribution are its mean $\mathbb{E}[X]$ and variance $\mathbb{V}[X]$, whereas the
first and second moments are $\mathbb{E}[X]$ and $\mathbb{E}[X^2]$. We can also compute higher order cumulants (and
moments). An important property of the exponential family is that derivatives of the log partition
function can be used to generate all the cumulants of the sufficient statistics. In particular, the first
and second cumulants are given by


\begin{align}
\nabla_{\boldsymbol{\eta}} A(\boldsymbol{\eta}) &= \mathbb{E}[\mathcal{T}(\boldsymbol{x})] \tag{2.199} \\
\nabla_{\boldsymbol{\eta}}^2 A(\boldsymbol{\eta}) &= \text{Cov}[\mathcal{T}(\boldsymbol{x})] \tag{2.200}
\end{align}


We prove this result below. 
2.3.3.1 Derivation of the mean


For simplicity, we focus on the 1d case. For the first derivative we have 


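\begin{align}
\frac{dA}{d\eta} &= \frac{d}{d\eta}\left(\log\int\exp(\eta\mathcal{T}(x))h(x)\,dx\right) \tag{2.201} \\
&= \frac{\frac{d}{d\eta}\int\exp(\eta\mathcal{T}(x))h(x)\,dx}{\int\exp(\eta\mathcal{T}(x))h(x)\,dx} \tag{2.202} \\
&= \frac{\int\mathcal{T}(x)\exp(\eta\mathcal{T}(x))h(x)\,dx}{\exp(A(\eta))} \tag{2.203} \\
&= \int\mathcal{T}(x)\exp(\eta\mathcal{T}(x) - A(\eta))h(x)\,dx \tag{2.204} \\
&= \int\mathcal{T}(x)p(x)\,dx = \mathbb{E}[\mathcal{T}(x)] \tag{2.205}
\end{align}


For example, consider the Bernoulli distribution. We have $A(\eta) = \log(1 + e^\eta)$, so the mean is given
by


$$\frac{dA}{d\eta} = \frac{e^\eta}{1 + e^\eta} = \frac{1}{1 + e^{-\eta}} = \sigma(\eta) = \mu \tag{2.206}$$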


2.3.3.2 Derivation of the variance 


For simplicity, we focus on the 1d case. For the second derivative we have 


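\begin{align}
\frac{d^2A}{d\eta^2} &= \frac{d}{d\eta}\int\mathcal{T}(x)\exp(\eta\mathcal{T}(x) - A(\eta))h(x)\,dx \tag{2.207} \\
&= \int\mathcal{T}(x)\exp(\eta\mathcal{T}(x) - A(\eta))h(x)\left(\mathcal{T}(x) - A'(\eta)\right)dx \tag{2.208} \\
&= \int\mathcal{T}(x)p(x)\left(\mathcal{T}(x) - A'(\eta)\right)dx \tag{2.209} \\
&= \int\mathcal{T}^2(x)p(x)\,dx - A'(\eta)\int\mathcal{T}(x)p(x)\,dx \tag{2.210} \\
&= \mathbb{E}[\mathcal{T}^2(X)] - \mathbb{E}[\mathcal{T}(x)]^2 = \mathbb{V}[\mathcal{T}(x)] \tag{2.211}
\end{align}


where we used the fact that $A'(\eta) = \frac{dA}{d\eta} = \mathbb{E}[\mathcal{T}(x)]$. For example, for the Bernoulli distribution we
have


\begin{align}
\frac{d^2A}{d\eta^2} &= \frac{d}{d\eta}(1 + e^{-\eta})^{-1} = (1 + e^{-\eta})^{-2}e^{-\eta} \tag{2.212} \\
&= \frac{e^{-\eta}}{1 + e^{-\eta}}\frac{1}{1 + e^{-\eta}} = \frac{1}{e^\eta + 1}\frac{1}{1 + e^{-\eta}} = (1 - \mu)\mu \tag{2.213}
\end{align}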
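The cumulant property can also be checked numerically without automatic differentiation, by applying central finite differences to the Bernoulli log partition function $A(\eta) = \log(1 + e^\eta)$. The step size and test point in this NumPy sketch are arbitrary choices. `np.logaddexp(0, eta)` computes $\log(1 + e^\eta)$ stably.

```python
import numpy as np

def A_bern(eta):
    """Bernoulli log partition function A(eta) = log(1 + e^eta), Eq. (2.154)."""
    return np.logaddexp(0.0, eta)

eta, eps = 0.8, 1e-4
dA = (A_bern(eta + eps) - A_bern(eta - eps)) / (2 * eps)                  # approximates A'(eta)
d2A = (A_bern(eta + eps) - 2 * A_bern(eta) + A_bern(eta - eps)) / eps**2  # approximates A''(eta)
mu = 1.0 / (1.0 + np.exp(-eta))                                           # sigma(eta)

print(dA, mu)              # first cumulant: E[T(x)] = mu, Eq. (2.206)
print(d2A, mu * (1 - mu))  # second cumulant: V[T(x)] = mu (1 - mu), Eq. (2.213)
```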


2.3.3.3 Connection with the Fisher information matrix 


In Section 2.4, we show that, under some regularity conditions, the Fisher information matrix is 
given by 


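F(\eta) \triangleq \mathbb{E}_{p(x|\eta)}\left[\nabla \log p(x|\eta) \nabla \log p(x|\eta)^\top\right] = -\mathbb{E}_{p(x|\eta)}\left[\nabla_\eta^2 \log p(x|\eta)\right]   (2.214)
Hence for an exponential family model we have
F(\eta) = -\mathbb{E}_{p(x|\eta)}\left[\nabla_\eta^2 (\eta^\top T(x) - A(\eta))\right] = \nabla_\eta^2 A(\eta) = \mathrm{Cov}[T(x)]   (2.215)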


Thus the Hessian of the log partition function is the same as the FIM, which is the same as the 
covariance of the sufficient statistics. See Section 2.4.5 for details. 


2.3.4 Canonical (natural) vs mean (moment) parameters 
Let Ω be the set of normalizable natural parameters:
\Omega \triangleq \{\eta \in \mathbb{R}^K : Z(\eta) < \infty\}   (2.216)


We say that an exponential family is regular if Ω is an open set. It can be shown that Ω is a convex set, and A(η) is a convex function defined over this set.

In Section 2.3.3, we prove that the derivative of the log partition function is equal to the mean of 
the sufficient statistics, i.e., 


m = \mathbb{E}[T(x)] = \nabla_\eta A(\eta)   (2.217)


The set of valid moment parameters is given by 


\mathcal{M} = \{m \in \mathbb{R}^K : \mathbb{E}_p[T(x)] = m\}   (2.218)


for some distribution p.
We have seen that we can convert from the natural parameters to the moment parameters using 


m = \nabla_\eta A(\eta)   (2.219)
If the family is minimal, one can show that

\eta = \nabla_m A^*(m)   (2.220)
where A*(m) is the convex conjugate of A:


A^*(m) \triangleq \sup_{\eta \in \Omega} m^\top \eta - A(\eta)   (2.221)


Thus the pair of operators (∇A, ∇A*) lets us go back and forth between the natural parameters η ∈ Ω and the mean parameters m ∈ M.
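To make the two maps concrete, the following sketch assumes the Bernoulli family and uses the standard closed forms ∇A(η) = σ(η) and A*(m) = m log m + (1 − m) log(1 − m), whose gradient is the logit. It checks the round trip, and compares A*(m) with a brute-force grid approximation of the supremum in Equation (2.221). The grid range and spacing are arbitrary choices.

```python
import math


def A(eta):
    """Log partition function of the Bernoulli in natural parameters."""
    return math.log1p(math.exp(eta))


def grad_A(eta):
    """Natural -> mean parameters: the sigmoid."""
    return 1.0 / (1.0 + math.exp(-eta))


def A_star(m):
    """Convex conjugate of A: the negative entropy of Bern(m), for 0 < m < 1."""
    return m * math.log(m) + (1 - m) * math.log(1 - m)


def grad_A_star(m):
    """Mean -> natural parameters: the logit."""
    return math.log(m / (1 - m))


eta = -1.3
m = grad_A(eta)
print(m, grad_A_star(m))  # the round trip recovers eta = -1.3

# Brute-force sup over a grid of eta values in [-10, 10] with spacing 0.001.
grid = [i / 1000 for i in range(-10_000, 10_001)]
sup_on_grid = max(m * e - A(e) for e in grid)
print(sup_on_grid, A_star(m))
```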

For future reference, note that the Bregman divergences (Section 5.1.8) associated with A and A* 
are as follows: 


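B_A(\lambda_1 \| \lambda_2) = A(\lambda_1) - A(\lambda_2) - (\lambda_1 - \lambda_2)^\top \nabla_\lambda A(\lambda_2)   (2.222)
B_{A^*}(\mu_1 \| \mu_2) = A^*(\mu_1) - A^*(\mu_2) - (\mu_1 - \mu_2)^\top \nabla_\mu A^*(\mu_2)   (2.223)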
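One reason these divergences are worth remembering: for an exponential family, the Bregman divergence of A equals a KL divergence with the arguments swapped, B_A(λ_1 ‖ λ_2) = D_KL(p_{λ_2} ‖ p_{λ_1}). The sketch below checks this numerically for the Bernoulli family; the parameter values are arbitrary.

```python
import math


def A(eta):
    return math.log1p(math.exp(eta))


def grad_A(eta):
    return 1.0 / (1.0 + math.exp(-eta))


def bregman_A(lam1, lam2):
    """B_A(lam1 || lam2) = A(lam1) - A(lam2) - (lam1 - lam2) A'(lam2), as in Eq. (2.222)."""
    return A(lam1) - A(lam2) - (lam1 - lam2) * grad_A(lam2)


def kl_bernoulli(p, q):
    """D_KL(Bern(p) || Bern(q))."""
    return p * math.log(p / q) + (1 - p) * math.log((1 - p) / (1 - q))


lam1, lam2 = 0.5, -1.0
print(bregman_A(lam1, lam2), kl_bernoulli(grad_A(lam2), grad_A(lam1)))
```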
2.3.5 MLE for the exponential family


The likelihood of an exponential family model has the form 


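p(D|\eta) = \left[\prod_{n=1}^{N_D} h(x_n)\right] \exp\left(\eta^\top \left[\sum_{n=1}^{N_D} T(x_n)\right] - N_D A(\eta)\right) \propto \exp\left[\eta^\top T(D) - N_D A(\eta)\right]   (2.225)


where T(D) are the sufficient statistics:


T(D) = \left[\sum_{n=1}^{N_D} T_1(x_n), \ldots, \sum_{n=1}^{N_D} T_K(x_n)\right]   (2.226)


For example, for the Bernoulli model we have T(D) = [∑_n I(x_n = 1)], and for the univariate Gaussian, we have T(D) = [∑_n x_n, ∑_n x_n^2].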
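The practical point of Equation (2.225) is that the data enter the likelihood only through N_D and T(D). A small sketch for the univariate Gaussian, with arbitrary toy data, computes the log likelihood both from the raw data and from T(D) alone:

```python
import math


def gaussian_suff_stats(data):
    """T(D) = [sum_n x_n, sum_n x_n^2] for the univariate Gaussian."""
    return sum(data), sum(x * x for x in data)


def loglik_direct(data, mu, var):
    """Gaussian log likelihood computed from the raw data."""
    return sum(-0.5 * math.log(2 * math.pi * var) - (x - mu) ** 2 / (2 * var) for x in data)


def loglik_from_stats(n, s1, s2, mu, var):
    """The same quantity from N_D and T(D) only, using
    sum_n (x_n - mu)^2 = s2 - 2 mu s1 + N mu^2."""
    return -0.5 * n * math.log(2 * math.pi * var) - (s2 - 2 * mu * s1 + n * mu * mu) / (2 * var)


data = [1.2, -0.4, 2.5, 0.3, 1.9]  # arbitrary toy data
s1, s2 = gaussian_suff_stats(data)
print(loglik_direct(data, 0.5, 2.0), loglik_from_stats(len(data), s1, s2, 0.5, 2.0))
```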

The Pitman-Koopman-Darmois theorem states that, under certain regularity conditions, the 
exponential family is the only family of distributions with finite sufficient statistics. (Here, finite 
means a size independent of the size of the data set.) In other words, for an exponential family with 
natural parameters η, we have


p(D|\eta) = p(T(D)|\eta)   (2.227)


We now show how to use this result to compute the MLE. The log likelihood is given by

\log p(D|\eta) = \eta^\top T(D) - N_D A(\eta) + \text{const}   (2.228)


Since −A(η) is concave in η, and η^T T(D) is linear in η, we see that the log likelihood is concave, so any stationary point is a global maximum (and it is unique if the family is minimal, since A is then strictly convex). To derive this maximum, we use the fact (shown in Section 2.3.3) that the derivative of the log partition function yields the expected value of the sufficient statistic vector:


\nabla_\eta \log p(D|\eta) = \nabla_\eta \eta^\top T(D) - N_D \nabla_\eta A(\eta) = T(D) - N_D \mathbb{E}[T(x)]   (2.229)


For a single data case, this becomes 


\nabla_\eta \log p(x|\eta) = T(x) - \mathbb{E}[T(x)]   (2.230)


Setting the gradient in Equation (2.229) to zero, we see that at the MLE, the empirical average of the sufficient statistics must equal the model's theoretical expected sufficient statistics, i.e., η̂ must satisfy


\mathbb{E}[T(x)] = \frac{1}{N_D} \sum_{n=1}^{N_D} T(x_n)   (2.231)


This is called moment matching. For example, in the Bernoulli distribution, we have T(x) = I(X = 1), so the MLE satisfies


\mathbb{E}[T(x)] = p(X = 1) = \hat{\mu} = \frac{1}{N_D} \sum_{n=1}^{N_D} \mathbb{I}(x_n = 1)   (2.232)
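Moment matching can also be reached iteratively. As an illustrative sketch (the toy data, step size, and iteration count are arbitrary choices of ours), gradient ascent on the average Bernoulli log likelihood in the natural parameter uses the gradient of Equation (2.229) divided by N_D, and converges to the η̂ whose mean σ(η̂) equals the empirical frequency:

```python
import math


def sigmoid(eta):
    return 1.0 / (1.0 + math.exp(-eta))


data = [1, 0, 1, 1, 0, 1, 1, 0]  # toy Bernoulli observations
empirical_mean = sum(data) / len(data)  # (1/N_D) sum_n T(x_n), with T(x) = I(x = 1)

eta, step_size = 0.0, 1.0
for _ in range(500):
    # Gradient of the average log likelihood: Eq. (2.229) divided by N_D.
    eta += step_size * (empirical_mean - sigmoid(eta))

print(sigmoid(eta), empirical_mean)  # moment matching: E[T(x)] equals the empirical mean
```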


2.3.6 Exponential dispersion family 


In this section, we consider a slight extension of the natural exponential family known as the 
exponential dispersion family. This will be useful when we discuss GLMs in Section 15.1. For a 
scalar variable, this has the form 


p(x|\eta, \sigma^2) = h(x, \sigma^2) \exp\left[\frac{\eta x - A(\eta)}{\sigma^2}\right]   (2.233)


Here σ^2 is called the dispersion parameter. For fixed σ^2, this is a natural exponential family.


2.3.7 Maximum entropy derivation of the exponential family 


Suppose we want to find a distribution p(x) to describe some data, where all we know are the expected values (F_k) of certain features or functions f_k(x):


\int dx \, p(x) f_k(x) = F_k   (2.234)


For example, f_1 might compute x, f_2 might compute x^2, making F_1 the empirical mean and F_2 the empirical second moment. Our prior belief in the distribution is q(x).

We would like the distribution that makes the least number of assumptions beyond these constraints. To formalize what we mean by "least number of assumptions", we will search for the distribution that is as close as possible to our prior q(x), in the sense of KL divergence (Section 5.1), while satisfying our constraints.

If we use a uniform prior, q(x) ∝ 1, minimizing the KL divergence is equivalent to maximizing the entropy (Section 5.2). The result is called a maximum entropy model.

To minimize KL subject to the constraints in Equation (2.234), and the constraints that p(x) ≥ 0 and ∑_x p(x) = 1, we need to use Lagrange multipliers. The Lagrangian is given by


J(p, \lambda) = -\sum_x p(x) \log\frac{p(x)}{q(x)} + \lambda_0\left(1 - \sum_x p(x)\right) + \sum_k \lambda_k \left(F_k - \sum_x p(x) f_k(x)\right)   (2.235)


We can use the calculus of variations to take derivatives wrt the function p, but we will adopt a simpler approach and treat p as a fixed length vector (since we are assuming that x is discrete). Then we have


\frac{\partial J}{\partial p_c} = -1 - \log\frac{p(x = c)}{q(x = c)} - \lambda_0 - \sum_k \lambda_k f_k(x = c)   (2.236)


Setting ∂J/∂p_c = 0 for each c yields


p(x) = \frac{q(x)}{Z} \exp\left(-\sum_k \lambda_k f_k(x)\right)   (2.237)


where we have defined Z ≜ e^{1+λ_0}. Using the sum-to-one constraint, we have


1 = \sum_x p(x) = \frac{1}{Z} \sum_x q(x) \exp\left(-\sum_k \lambda_k f_k(x)\right)   (2.238)
Hence the normalization constant is given by


Z = \sum_x q(x) \exp\left(-\sum_k \lambda_k f_k(x)\right)   (2.239)


This has exactly the form of the exponential family, where f(x) is the vector of sufficient statistics, −λ are the natural parameters, and q(x) is our base measure.

For example, if the features are f_1(x) = x and f_2(x) = x^2, and we want to match the first and second moments, we get the Gaussian distribution.
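As a worked example, here is a sketch of the classic die problem, assuming a six-sided die with faces 1, ..., 6, a uniform base measure q(x), and a single feature f_1(x) = x constrained to have mean 4.5 (these choices are ours, for illustration). The multiplier λ_1 has no closed form, so we find it by bisection, using the fact that the mean under p decreases monotonically as λ_1 grows (its derivative is minus the variance):

```python
import math


def maxent_die(values, target_mean, lo=-50.0, hi=50.0, iters=200):
    """Max-ent p(x) = q(x) exp(-lam * x) / Z with uniform q and a single feature
    f_1(x) = x, choosing lam by bisection so that E_p[x] = target_mean."""

    def dist(lam):
        weights = [math.exp(-lam * x) for x in values]
        z = sum(weights)  # normalization constant, cf. Eq. (2.239)
        return [w / z for w in weights]

    def mean(lam):
        return sum(p * x for p, x in zip(dist(lam), values))

    # E_p[x] decreases as lam increases, so bisection brackets the root.
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if mean(mid) > target_mean:
            lo = mid
        else:
            hi = mid
    return dist(0.5 * (lo + hi))


faces = [1, 2, 3, 4, 5, 6]
p = maxent_die(faces, target_mean=4.5)
print([round(pi, 4) for pi in p])
```

Since the target mean exceeds 3.5, the fitted λ_1 is negative and the probabilities increase with the face value.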


2.4 Fisher information matrix (FIM) 


In this section, we discuss an important quantity called the Fisher information matrix, which 
is related to the curvature of the log likelihood function. This has many applications, such as 
characterizing the asymptotic sampling distribution of the MLE, deriving Jeffreys' uninformative priors (Section 3.6.2), and natural gradient descent.
2.4.1 Definition 
The score function is defined to be the gradient of the log likelihood: 

s(\theta) \triangleq \nabla_\theta \log p(x|\theta)   (2.240)


The Fisher information matrix (FIM) is defined to be the covariance of the score function: 


F(\theta) \triangleq \mathbb{E}_{x \sim p(x|\theta)}\left[\nabla \log p(x|\theta) \nabla \log p(x|\theta)^\top\right]   (2.241)


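so the (i, j)'th entry has the form


F_{ij} = \mathbb{E}_{x \sim \theta}\left[\left(\frac{\partial}{\partial \theta_i} \log p(x|\theta)\right)\left(\frac{\partial}{\partial \theta_j} \log p(x|\theta)\right)\right]   (2.242)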


We give an interpretation of this quantity below. 


2.4.2 Equivalence between the FIM and the Hessian of the NLL 


In this section, we prove that the Fisher information matrix equals the expected Hessian of the negative log likelihood (NLL):


\mathrm{NLL}(\theta) = -\log p(D|\theta)   (2.243)


Since the Hessian measures the curvature of the likelihood, we see that the FIM tells us how well 
the likelihood function can identify the best set of parameters. (If a likelihood function is flat, we 
cannot infer anything about the parameters, but if it is a delta function at a single point, the best 
parameter vector will be uniquely determined.) Thus the FIM is intimately related to the frequentist 
notion of uncertainty of the MLE, which is captured by the variance we expect to see in the MLE if 
we were to compute it on multiple different datasets drawn from our model. 

More precisely, we have the following theorem. 
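Theorem 2.4.1. If log p(x|θ) is twice differentiable, and under certain regularity conditions, the FIM is equal to the expected Hessian of the NLL, i.e.,

F(\theta)_{ij} \triangleq \mathbb{E}_{x \sim \theta}\left[\left(\frac{\partial}{\partial \theta_i} \log p(x|\theta)\right)\left(\frac{\partial}{\partial \theta_j} \log p(x|\theta)\right)\right] = \mathbb{E}_{x \sim \theta}\left[-\frac{\partial^2}{\partial \theta_i \partial \theta_j} \log p(x|\theta)\right]   (2.244)

Before we prove this result, we need the following lemma.

Lemma 2.4.1. The expected value of the score function is zero, i.e.,

\mathbb{E}_{p(x|\theta)}\left[\nabla \log p(x|\theta)\right] = 0   (2.245)

We will prove this lemma in the scalar case. First, note that since ∫ p(x|θ) dx = 1, we have

\frac{\partial}{\partial \theta} \int p(x|\theta) dx = 0   (2.246)

Combining this with the identity

\frac{\partial}{\partial \theta} p(x|\theta) = \left[\frac{\partial}{\partial \theta} \log p(x|\theta)\right] p(x|\theta)   (2.247)

we have

0 = \int \frac{\partial}{\partial \theta} p(x|\theta) dx = \int \left[\frac{\partial}{\partial \theta} \log p(x|\theta)\right] p(x|\theta) dx = \mathbb{E}[s(\theta)]   (2.248)

Now we return to the proof of our main theorem. For simplicity, we will focus on the scalar case, following the presentation of [Ric95, p263].

Proof. Taking derivatives of Equation (2.248), we have

0 = \frac{\partial}{\partial \theta} \int \left[\frac{\partial}{\partial \theta} \log p(x|\theta)\right] p(x|\theta) dx   (2.249)
= \int \left[\frac{\partial^2}{\partial \theta^2} \log p(x|\theta)\right] p(x|\theta) dx + \int \left[\frac{\partial}{\partial \theta} \log p(x|\theta)\right] \frac{\partial}{\partial \theta} p(x|\theta) dx   (2.250)
= \int \left[\frac{\partial^2}{\partial \theta^2} \log p(x|\theta)\right] p(x|\theta) dx + \int \left[\frac{\partial}{\partial \theta} \log p(x|\theta)\right]^2 p(x|\theta) dx   (2.251)

and hence

-\mathbb{E}_{x \sim \theta}\left[\frac{\partial^2}{\partial \theta^2} \log p(x|\theta)\right] = \mathbb{E}_{x \sim \theta}\left[\left(\frac{\partial}{\partial \theta} \log p(x|\theta)\right)^2\right]   (2.252)

as claimed.

Now consider the Hessian of the NLL given N iid samples D = {x_n : n = 1 : N}:

H_{ij} \triangleq -\frac{\partial^2}{\partial \theta_i \partial \theta_j} \log p(D|\theta) = -\sum_{n=1}^N \frac{\partial^2}{\partial \theta_i \partial \theta_j} \log p(x_n|\theta)   (2.253)

From the above theorem, we have

\mathbb{E}_{p(D|\theta)}\left[H(D)|_{\theta}\right] = N F(\theta)   (2.254)

We will use this result later in this chapter.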
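Equation (2.254) can be checked by simulation. This sketch rests on assumptions of our own: a Bernoulli model with θ = 0.3 and N = 20 observations per dataset. It averages the Hessian of the NLL over many simulated datasets and compares the result with N F(θ), where F(θ) = 1/(θ(1 − θ)) for a single Bernoulli draw.

```python
import random


def nll_hessian(data, theta):
    """d^2/dtheta^2 of -sum_n log Bern(x_n | theta), cf. Eq. (2.253)."""
    return sum(x / theta**2 + (1 - x) / (1 - theta) ** 2 for x in data)


rng = random.Random(0)
theta, n, num_datasets = 0.3, 20, 20_000
fisher = 1.0 / (theta * (1 - theta))  # F(theta) for a single Bernoulli observation

avg_hessian = sum(
    nll_hessian([1 if rng.random() < theta else 0 for _ in range(n)], theta)
    for _ in range(num_datasets)
) / num_datasets
print(avg_hessian, n * fisher)  # Eq. (2.254): these should be close
```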
2.4.3 Examples


In this section, we give some simple examples of how to compute the FIM. 


2.4.3.1 FIM for the Binomial 
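Suppose x ∼ Bin(n, θ). The log likelihood for a single sample is
\ell(\theta|x) = x \log\theta + (n - x)\log(1 - \theta)   (2.255)


The score function is just the gradient of the log-likelihood:


s(\theta|x) \triangleq \frac{d}{d\theta} \ell(\theta|x) = \frac{x}{\theta} - \frac{n - x}{1 - \theta}   (2.256)
The gradient of the score function is
s'(\theta|x) = -\frac{x}{\theta^2} - \frac{n - x}{(1 - \theta)^2}   (2.257)
Hence the Fisher information is given by
F(\theta) = \mathbb{E}_{x \sim \theta}[-s'(\theta|x)] = \frac{n\theta}{\theta^2} + \frac{n - n\theta}{(1 - \theta)^2} = \frac{n}{\theta} + \frac{n}{1 - \theta} = \frac{n}{\theta(1 - \theta)}   (2.258)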
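Because x takes only n + 1 values, the expectations in this example can be computed exactly by summing over the pmf. The sketch below (with illustrative values of n and θ) checks Lemma 2.4.1, both forms of the Fisher information in Theorem 2.4.1, and the closed form of Equation (2.258):

```python
import math


def binom_pmf(x, n, theta):
    return math.comb(n, x) * theta**x * (1 - theta) ** (n - x)


def score(x, n, theta):  # Eq. (2.256)
    return x / theta - (n - x) / (1 - theta)


def score_grad(x, n, theta):  # Eq. (2.257)
    return -x / theta**2 - (n - x) / (1 - theta) ** 2


n, theta = 10, 0.35
support = range(n + 1)
mean_score = sum(binom_pmf(x, n, theta) * score(x, n, theta) for x in support)
fisher_outer = sum(binom_pmf(x, n, theta) * score(x, n, theta) ** 2 for x in support)
fisher_hessian = sum(binom_pmf(x, n, theta) * -score_grad(x, n, theta) for x in support)
print(mean_score, fisher_outer, fisher_hessian, n / (theta * (1 - theta)))
```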
2.4.3.2 FIM for the Gaussian 
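Consider a univariate Gaussian p(x|θ) = N(x|μ, v). We have

\ell(\theta) = \log p(x|\theta) = -\frac{1}{2v}(x - \mu)^2 - \frac{1}{2}\log(v) - \frac{1}{2}\log(2\pi)   (2.259)

The partial derivatives are given by

\frac{\partial \ell}{\partial \mu} = (x - \mu) v^{-1}, \quad \frac{\partial^2 \ell}{\partial \mu^2} = -v^{-1}   (2.260)
\frac{\partial \ell}{\partial v} = \frac{1}{2} v^{-2} (x - \mu)^2 - \frac{1}{2} v^{-1}, \quad \frac{\partial^2 \ell}{\partial v^2} = -v^{-3}(x - \mu)^2 + \frac{1}{2} v^{-2}   (2.261)
\frac{\partial^2 \ell}{\partial \mu \partial v} = -v^{-2}(x - \mu)   (2.262)

and hence


F(\theta) = \begin{pmatrix} \mathbb{E}[v^{-1}] & \mathbb{E}[v^{-2}(x - \mu)] \\ \mathbb{E}[v^{-2}(x - \mu)] & \mathbb{E}[v^{-3}(x - \mu)^2 - \frac{1}{2} v^{-2}] \end{pmatrix} = \begin{pmatrix} \frac{1}{v} & 0 \\ 0 & \frac{1}{2v^2} \end{pmatrix}   (2.263)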
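The same matrix can be estimated by Monte Carlo, averaging the outer product of the score over samples, as in definition (2.241). This is a sketch with arbitrary μ, v and sample size; note that Python's random.gauss takes a standard deviation, not a variance.

```python
import math
import random


def score(x, mu, v):
    """Gradient of log N(x | mu, v) wrt (mu, v), from Eqs. (2.260)-(2.261)."""
    d = x - mu
    return (d / v, 0.5 * d * d / v**2 - 0.5 / v)


rng = random.Random(1)
mu, v, num_samples = 1.0, 2.0, 200_000
F = [[0.0, 0.0], [0.0, 0.0]]
for _ in range(num_samples):
    s = score(rng.gauss(mu, math.sqrt(v)), mu, v)  # random.gauss takes a standard deviation
    for i in range(2):
        for j in range(2):
            F[i][j] += s[i] * s[j] / num_samples

print(F)  # should be close to [[1/v, 0], [0, 1/(2 v^2)]] = [[0.5, 0], [0, 0.125]]
```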


2.4.3.3 FIM for logistic regression 


Consider ℓ2-regularized binary logistic regression. The negative log joint has the following form:


E(w) = -\log[p(y|X, w) p(w|\lambda)] = -w^\top X^\top y + \sum_{n=1}^N \log(1 + e^{w^\top x_n}) + \frac{\lambda}{2} w^\top w   (2.264)


The derivative has the form
\nabla_w E(w) = -X^\top y + X^\top s + \lambda w   (2.265)


where s_n = σ(w^T x_n). The FIM is given by


F(w) = \mathbb{E}_{p(y|X, w, \lambda)}\left[\nabla^2 E(w)\right] = X^\top \Lambda X + \lambda I   (2.266)
where Λ is the N × N diagonal matrix with entries


\Lambda_{nn} = \sigma(w^\top x_n)(1 - \sigma(w^\top x_n))   (2.267)
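A NumPy sketch of Equation (2.266) on synthetic data of our own choosing. Because the Hessian of E(w) does not depend on y, the FIM should match a finite-difference Hessian of E(w) up to numerical error. The helper functions are hypothetical names introduced for this illustration.

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def objective(w, X, y, lam):
    """E(w) from Eq. (2.264): logistic-regression NLL plus an l2 penalty."""
    a = X @ w
    return -w @ X.T @ y + np.sum(np.log1p(np.exp(a))) + 0.5 * lam * w @ w


def fisher(w, X, lam):
    """F(w) = X^T Lambda X + lam I, with Lambda_nn = sigma(w^T x_n)(1 - sigma(w^T x_n))."""
    s = sigmoid(X @ w)
    return X.T @ (X * (s * (1 - s))[:, None]) + lam * np.eye(X.shape[1])


def numerical_hessian(f, w, h=1e-4):
    """Central finite-difference Hessian of a scalar function f at w."""
    d = len(w)
    H = np.zeros((d, d))
    for i in range(d):
        for j in range(d):
            ei, ej = np.eye(d)[i] * h, np.eye(d)[j] * h
            H[i, j] = (f(w + ei + ej) - f(w + ei - ej) - f(w - ei + ej) + f(w - ei - ej)) / (4 * h * h)
    return H


rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))
y = (rng.random(50) < 0.5).astype(float)
w = rng.normal(size=3)
lam = 0.1
F = fisher(w, X, lam)
H = numerical_hessian(lambda u: objective(u, X, y, lam), w)
print(np.max(np.abs(F - H)))  # small: the Hessian of E(w) does not depend on y
```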


2.4.4 Approximating KL divergence using FIM 


Mahalanobis distance based on the Fisher information can be viewed as an approximation to the KL 
divergence between two distributions, as we now show. 

Let p_θ(x) and p_θ'(x) be two distributions, where θ' = θ + δ. We can measure how close the second distribution is to the first in terms of their predictive distributions (as opposed to comparing θ and θ' in parameter space) as follows:


D_{KL}(p_\theta \| p_{\theta'}) = \mathbb{E}_{p_\theta(x)}\left[\log p_\theta(x) - \log p_{\theta'}(x)\right]   (2.268)


Let us approximate this with a second order Taylor series expansion:


D_{KL}(p_\theta \| p_{\theta'}) \approx -\delta^\top \mathbb{E}\left[\nabla \log p_\theta(x)\right] - \frac{1}{2} \delta^\top \mathbb{E}\left[\nabla^2 \log p_\theta(x)\right] \delta   (2.269)
Since the expected score function is zero (from Equation (2.245)), the first term vanishes, so we have
D_{KL}(p_\theta \| p_{\theta'}) \approx \frac{1}{2} \delta^\top F(\theta) \delta   (2.270)
where F is the FIM


F = -\mathbb{E}\left[\nabla^2 \log p_\theta(x)\right] = \mathbb{E}\left[(\nabla \log p_\theta(x))(\nabla \log p_\theta(x))^\top\right]   (2.271)


This result is the basis of the natural gradient method discussed in Section 6.4. 
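To see the quality of the approximation in Equation (2.270), the sketch below compares the exact KL divergence between two Bernoulli distributions with ½ δ² F(θ), where F(θ) = 1/(θ(1 − θ)) is the n = 1 case of Equation (2.258). The values of θ and δ are arbitrary; the ratio should approach 1 as δ shrinks.

```python
import math


def kl_bernoulli(p, q):
    """D_KL(Bern(p) || Bern(q))."""
    return p * math.log(p / q) + (1 - p) * math.log((1 - p) / (1 - q))


theta = 0.3
fisher = 1.0 / (theta * (1 - theta))
for delta in [0.1, 0.01, 0.001]:
    exact = kl_bernoulli(theta, theta + delta)
    approx = 0.5 * delta**2 * fisher
    print(f"delta={delta:<6} exact={exact:.3e} approx={approx:.3e} ratio={exact / approx:.4f}")
```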


2.4.5 Fisher information matrix for exponential family 


In this section, we discuss how to derive the FIM for an exponential family distribution with natural parameters η. Recall from Equation (2.199) that the gradient of the log partition function is the expected sufficient statistics


\nabla_\eta A(\eta) = \mathbb{E}[T(x)] = m   (2.272)


and from Equation (2.230) that the gradient of the log likelihood is the statistics minus their expected value:


\nabla_\eta \log p(x|\eta) = T(x) - \mathbb{E}[T(x)]   (2.273)
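Hence the FIM wrt the natural parameters F_η is given by

(F_\eta)_{ij} = \mathbb{E}_{p(x|\eta)}\left[\frac{\partial \log p(x|\eta)}{\partial \eta_i} \frac{\partial \log p(x|\eta)}{\partial \eta_j}\right]   (2.274)
= \mathbb{E}_{p(x|\eta)}\left[(T(x)_i - m_i)(T(x)_j - m_j)\right]   (2.275)
= \mathrm{Cov}[T(x)_i, T(x)_j]   (2.276)

or, in short,

F_\eta = \mathrm{Cov}[T(x)]   (2.277)

Sometimes we need to compute the Fisher wrt the moment parameters m:

(F_m)_{ij} = \mathbb{E}_{p(x|m)}\left[\frac{\partial \log p(x|\eta)}{\partial m_i} \frac{\partial \log p(x|\eta)}{\partial m_j}\right]   (2.278)

From the chain rule we have

\frac{\partial \log p(x)}{\partial \alpha} = \frac{\partial \log p(x)}{\partial \beta} \frac{\partial \beta}{\partial \alpha}   (2.279)

and hence

F_\alpha = \frac{\partial \beta}{\partial \alpha}^\top F_\beta \frac{\partial \beta}{\partial \alpha}   (2.280)

Using the log trick

\nabla \mathbb{E}_{p(x)}[f(x)] = \mathbb{E}_{p(x)}[f(x) \nabla \log p(x)]   (2.281)

and Equation (2.273) we have

\frac{\partial m_i}{\partial \eta_j} = \frac{\partial \mathbb{E}[T(x)_i]}{\partial \eta_j} = \mathbb{E}\left[T(x)_i \frac{\partial \log p(x|\eta)}{\partial \eta_j}\right] = \mathbb{E}\left[T(x)_i (T(x)_j - m_j)\right]   (2.282)
= \mathbb{E}[T(x)_i T(x)_j] - \mathbb{E}[T(x)_i] m_j = \mathrm{Cov}[T(x)_i, T(x)_j] = (F_\eta)_{ij}   (2.283)

and hence

\frac{\partial \eta}{\partial m} = F_\eta^{-1}   (2.284)

so

F_m = \frac{\partial \eta}{\partial m}^\top F_\eta \frac{\partial \eta}{\partial m} = F_\eta^{-1} F_\eta F_\eta^{-1} = F_\eta^{-1} = \mathrm{Cov}[T(x)]^{-1}   (2.285)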
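A minimal check of Equations (2.277) and (2.285), assuming the Bernoulli family for illustration, where T(x) = x and the mean parameter is μ. The Fisher information wrt η should be the variance μ(1 − μ), and the Fisher information wrt μ, computed directly from the score in μ, should be its reciprocal.

```python
mu = 0.27
pmf = {0: 1 - mu, 1: mu}  # Bernoulli with T(x) = x

# Natural parameter eta = logit(mu): the score is T(x) - E[T(x)] = x - mu, Eq. (2.273).
F_eta = sum(p * (x - mu) ** 2 for x, p in pmf.items())

# Mean parameter mu: the score is d/dmu [x log mu + (1 - x) log(1 - mu)].
F_m = sum(p * (x / mu - (1 - x) / (1 - mu)) ** 2 for x, p in pmf.items())

print(F_eta, F_m, F_eta * F_m)  # F_eta = mu (1 - mu) and F_m = 1 / F_eta
```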
2.5 Transformations of random variables

Suppose x ∼ p_x(x) is some random variable, and y = f(x) is some deterministic transformation of it. In this section, we discuss how to compute p_y(y).
2.5. TRANSFORMATIONS OF RANDOM VARIABLES


[Figure 2.11 panels: columns Surjective / Non-surjective; rows Injective / Non-injective.]


Figure 2.11: Illustration of injective and surjective functions. 


2.5.1 Invertible transformations (bijections) 


Let f be a bijection that maps R^n to R^n. (A bijection is a function that is both injective (one-to-one) and surjective (onto), as illustrated in Figure 2.11; this means that the function has a well-defined inverse.) Suppose we want to compute the pdf of y = f(x). The change of variables formula tells us that


p_y(y) = p_x(f^{-1}(y)) \left|\det\left[J_{f^{-1}}(y)\right]\right|   (2.286)


where J_{f^{-1}}(y) is the Jacobian of the inverse mapping f^{-1} evaluated at y, and |det J| is the absolute value of the determinant of J. In other words,


J_{f^{-1}}(y) = \begin{pmatrix} \frac{\partial x_1}{\partial y_1} & \cdots & \frac{\partial x_1}{\partial y_n} \\ \vdots & \ddots & \vdots \\ \frac{\partial x_n}{\partial y_1} & \cdots & \frac{\partial x_n}{\partial y_n} \end{pmatrix}   (2.287)


If the Jacobian matrix is triangular, the determinant reduces to a product of the terms on the main diagonal:


\det(J) = \prod_{i=1}^n \frac{\partial x_i}{\partial y_i}   (2.288)
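A one-dimensional sketch of Equation (2.286), assuming the hypothetical example y = exp(x) with x ∼ N(0, 1). Here f^{-1}(y) = log y, and the Jacobian reduces to the scalar 1/y. The check compares the probability of an interval under p_y, obtained by numerical integration, with the same probability computed in x-space.

```python
import math


def normal_pdf(x, mu=0.0, sigma=1.0):
    return math.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi))


def normal_cdf(x):
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


def p_y(y):
    """Change of variables for y = exp(x), x ~ N(0, 1):
    f^{-1}(y) = log(y) and |d f^{-1} / dy| = 1 / y, so p_y(y) = N(log y | 0, 1) / y."""
    return normal_pdf(math.log(y)) / y


def integrate(f, a, b, n=100_000):
    """Composite trapezoidal rule on [a, b]."""
    h = (b - a) / n
    return h * (0.5 * f(a) + 0.5 * f(b) + sum(f(a + i * h) for i in range(1, n)))


a, b = 0.5, 3.0
print(integrate(p_y, a, b), normal_cdf(math.log(b)) - normal_cdf(math.log(a)))
```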
2.5.2 Monte Carlo approximation 


Sometimes it is difficult to compute the Jacobian. In this case, we can make a Monte Carlo approximation, by drawing S samples x^s ∼ p(x), computing y^s = f(x^s), and then constructing the empirical pdf


p_D(y) = \frac{1}{S} \sum_{s=1}^S \delta(y - y^s)   (2.289)


For example, let x ∼ N(6, 1) and y = f(x), where f(x) = 1/(1 + exp(−x + 5)). We can approximate p(y) using Monte Carlo, as shown in Figure 2.12.
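A plain-Python sketch of this Monte Carlo approach (the sample size, bin count, and seed are arbitrary). A normalized histogram plays the role of a smoothed version of Equation (2.289), and the exact density from Equation (2.286), with f^{-1}(y) = log(y/(1 − y)) + 5, serves as a reference. The estimated mode of y lands near 0.84, noticeably away from f(6) ≈ 0.73, which is the effect that Figure 2.12 illustrates.

```python
import math
import random


def f(x):
    return 1.0 / (1.0 + math.exp(-x + 5))


def p_y_exact(y):
    """Eq. (2.286) with f^{-1}(y) = log(y / (1 - y)) + 5 and |d f^{-1} / dy| = 1 / (y (1 - y))."""
    x = math.log(y / (1 - y)) + 5
    return math.exp(-0.5 * (x - 6) ** 2) / math.sqrt(2 * math.pi) / (y * (1 - y))


rng = random.Random(0)
num_samples, num_bins = 200_000, 50
ys = [f(rng.gauss(6.0, 1.0)) for _ in range(num_samples)]  # y^s = f(x^s), x^s ~ N(6, 1)

# Histogram on (0, 1): a binned, normalized version of the empirical pdf in Eq. (2.289).
width = 1.0 / num_bins
counts = [0] * num_bins
for y in ys:
    counts[min(int(y / width), num_bins - 1)] += 1
density = [c / (num_samples * width) for c in counts]
centers = [(i + 0.5) * width for i in range(num_bins)]

best = max(range(num_bins), key=lambda i: density[i])
mc_mode = centers[best]
print("f(mode of x) =", round(f(6.0), 3))
print("MC mode of y =", round(mc_mode, 3), "density", round(density[best], 3),
      "exact", round(p_y_exact(mc_mode), 3))
```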
Figure 2.12: Example of the transformation of a density under a nonlinear transform. Note how the mode of the transformed distribution is not the transform of the original mode. Adapted from Exercise 1.4 of [Bis06]. Generated by bayes_change_of_var.ipynb.


2.5.3 Probability integral transform 


Suppose that X is a random variable with cdf P_X. Let Y(X) = P_X(X) be a transformation of X. We now show that Y has a uniform distribution, a result known as the probability integral transform (PIT):


P_Y(y) = \Pr(Y \le y) = \Pr(P_X(X) \le y)   (2.290)
= \Pr(X \le P_X^{-1}(y)) = P_X(P_X^{-1}(y)) = y   (2.291)


For example, in Figure 2.13, we show various distributions with pdf's p_X in the left column. We sample from these to get x_n ∼ p_X. Next we compute the empirical cdf of Y = P_X(X), by computing y_n = P_X(x_n) and then sorting the values; the results, shown in the middle column, show that this distribution is uniform. We can also approximate the pdf of Y by using kernel density estimation; this is shown in the right column, and we see that it is (approximately) flat.

We can use the PIT to test if a set of samples come from a given distribution using the Kolmogorov-Smirnov test. To do this, we plot the empirical cdf of the samples and the theoretical cdf of the distribution, and compute the maximum distance between these two curves, as illustrated in Figure 2.14. Formally, the KS statistic is defined as


D_n = \max_x |P_n(x) - P(x)|   (2.292)


where n is the sample size, P_n is the empirical cdf, and P is the theoretical cdf. The value D_n should approach 0 (as n → ∞) if the samples are drawn from P.
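A sketch of the KS statistic in Equation (2.292). Since the empirical cdf is a step function, the maximum is attained at a sample point, just before or just after its jump, so it suffices to check those 2n values. The test distributions and sample sizes are illustrative choices, and we compute only the statistic, not a p-value.

```python
import math
import random


def ks_statistic(samples, cdf):
    """D_n = max_x |P_n(x) - P(x)|, evaluated on both sides of each jump of P_n."""
    xs = sorted(samples)
    n = len(xs)
    return max(max((i + 1) / n - cdf(x), cdf(x) - i / n) for i, x in enumerate(xs))


def std_normal_cdf(x):
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


rng = random.Random(0)
good = [rng.gauss(0.0, 1.0) for _ in range(2000)]  # drawn from the model
bad = [rng.gauss(0.5, 1.0) for _ in range(2000)]  # drawn from a shifted distribution
print(ks_statistic(good, std_normal_cdf), ks_statistic(bad, std_normal_cdf))
```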

Another application of the PIT is to generate samples from a distribution: if we have a way to sample from a uniform distribution, u_n ∼ Unif(0, 1), we can convert this to samples from any other distribution with cdf P_X by setting x_n = P_X^{-1}(u_n).
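For example, assuming an exponential distribution with rate λ (our choice of example), the cdf P_X(x) = 1 − e^{−λx} can be inverted in closed form, which gives the following sketch of inverse-cdf sampling. The checks compare the sample mean with 1/λ and the fraction of samples below the true median log(2)/λ with 0.5.

```python
import math
import random


def sample_exponential(rate, num_samples, rng):
    """Inverse-cdf sampling: P(x) = 1 - exp(-rate x), so P^{-1}(u) = -log(1 - u) / rate."""
    return [-math.log(1.0 - rng.random()) / rate for _ in range(num_samples)]


rng = random.Random(0)
rate = 2.0
xs = sample_exponential(rate, num_samples=100_000, rng=rng)
print(sum(xs) / len(xs))  # close to the true mean 1 / rate = 0.5

median = math.log(2.0) / rate
print(sum(x < median for x in xs) / len(xs))  # close to 0.5
```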


2.6 Markov chains 


Suppose that x_t captures all the relevant information about the state of the system. This means it is a sufficient statistic for predicting the future given the past, i.e.,


p(x_{t+\tau}|x_t, x_{1:t-1}) = p(x_{t+\tau}|x_t)   (2.293)
2.6. MARKOV CHAINS


[Figure 2.13 panels: rows x_1, x_2, x_3; columns pdf(X), cdf(Y), pdf(Y).]


Figure 2.13: Illustration of the probability integral transform. Left column: 3 different pdf's for p(X) from which we sample x_n ∼ p(x). Middle column: empirical cdf of y_n = P_X(x_n). Right column: empirical pdf of p(y_n) using a kernel density estimate. Adapted from Figure 11.17 of [MKL11]. Generated by ecdf_sample.ipynb.


for any τ > 0. This is called the Markov assumption. In this case, we can write the joint distribution for any finite length sequence as follows:


p(x_{1:T}) = p(x_1) p(x_2|x_1) p(x_3|x_2) p(x_4|x_3) \cdots = p(x_1) \prod_{t=2}^T p(x_t|x_{t-1})   (2.294)
This is called a Markov chain or Markov model. Below we cover some of the basics of this topic; more details on the theory can be found in [Kun20].


2.6.1 Parameterization 


In this section, we discuss how to represent a Markov model parametrically. 


2.6.1.1 Markov transition kernels 


The conditional distribution p(x_t|x_{t−1}) is called the transition function, transition kernel or Markov kernel. This is just a conditional distribution over the states at time t given the state at time t − 1, and hence it satisfies the conditions p(x_t|x_{t−1}) ≥ 0 and ∫_{x∈X} dx p(x_t = x|x_{t−1}) = 1.

If we assume the transition function p(x_t|x_{1:t−1}) is independent of time, then the model is said to be homogeneous, stationary, or time-invariant. This is an example of parameter tying, since the same parameter is shared by multiple variables. This assumption allows us to model an arbitrary number of variables using a fixed number of parameters. We will make the time-invariant assumption throughout the rest of this section.


Figure 2.14: Illustration of the Kolmogorov-Smirnov statistic. The red line is a model CDF, the blue line is an empirical CDF, and the black arrow is the K-S statistic. From https://en.wikipedia.org/wiki/Kolmogorov_Smirnov_test. Used with kind permission of Wikipedia author Bscan.

Figure 2.15: State transition diagrams for some simple Markov chains. Left: a 2-state chain. Right: a 3-state left-to-right chain.


2.6.1.2 Markov transition matrices 


In this section, we assume that the variables are discrete, so X_t ∈ {1, ..., K}. This is called a finite-state Markov chain. In this case, the conditional distribution p(X_t|X_{t−1}) can be written as a K × K matrix A, known as the transition matrix, where A_ij = p(X_t = j|X_{t−1} = i) is the probability of going from state i to state j. Each row of the matrix sums to one, ∑_j A_ij = 1, so this is called a stochastic matrix.

A stationary, finite-state Markov chain is equivalent to a stochastic automaton. It is common 
to visualize such automata by drawing a directed graph, where nodes represent states and arrows 
represent legal transitions, i.e., non-zero elements of A. This is known as a state transition 
diagram. The weights associated with the arcs are the probabilities. For example, the following 
2-state chain 


A = \begin{pmatrix} 1 - \alpha & \alpha \\ \beta & 1 - \beta \end{pmatrix}   (2.295)


is illustrated in Figure 2.15(a). The following 3-state chain 


A = \begin{pmatrix} A_{11} & A_{12} & 0 \\ 0 & A_{22} & A_{23} \\ 0 & 0 & 1 \end{pmatrix}   (2.296)


is illustrated in Figure 2.15(b). This is called a left-to-right transition matrix.
The A_ij element of the transition matrix specifies the probability of getting from i to j in one step. The n-step transition matrix A(n) is defined as


A_{ij}(n) \triangleq p(X_{t+n} = j | X_t = i)   (2.297)
which is the probability of getting from i to j in exactly n steps. Obviously A(1) = A. The Chapman-Kolmogorov equations state that


A_{ij}(m + n) = \sum_{k=1}^K A_{ik}(m) A_{kj}(n)   (2.298)


In words, the probability of getting from i to j in m + n steps is just the probability of getting from i to k in m steps, and then from k to j in n steps, summed up over all k. We can write the above as a matrix multiplication


A(m + n) = A(m) A(n)   (2.299)
Hence
A(n) = A A(n-1) = A A A(n-2) = \cdots = A^n   (2.300)


Thus we can simulate multiple steps of a Markov chain by “powering up” the transition matrix. 
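A NumPy sketch for the 2-state chain of Equation (2.295), with arbitrary α and β. It checks Equation (2.299) via matrix powers, and compares row 0 of A(m + n) with empirical state frequencies from simulating many independent chains. States are 0-indexed in the code, so code state 0 is state 1 in the text.

```python
import numpy as np

alpha, beta = 0.3, 0.1
A = np.array([[1 - alpha, alpha],
              [beta, 1 - beta]])  # the 2-state chain of Eq. (2.295)

m, n = 3, 5
A_mn = np.linalg.matrix_power(A, m + n)  # A(m + n) = A^(m + n), Eq. (2.300)
A_m_times_A_n = np.linalg.matrix_power(A, m) @ np.linalg.matrix_power(A, n)
print(np.allclose(A_mn, A_m_times_A_n))  # Eq. (2.299)

# Simulate many chains started in state 0 and compare with row 0 of A(m + n).
rng = np.random.default_rng(0)
num_chains = 20_000
state = np.zeros(num_chains, dtype=int)
for _ in range(m + n):
    # With 2 states, the next state is 1 with probability A[state, 1].
    state = (rng.random(num_chains) < A[state, 1]).astype(int)
freq = np.bincount(state, minlength=2) / num_chains
print(freq, A_mn[0])
```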


2.6.1.3 Higher-order Markov models 


The first-order Markov assumption is rather strong. Fortunately, we can easily generalize first-order models to depend on the last n observations, thus creating a model of order (memory length) n:


p(x_{1:T}) = p(x_{1:n}) \prod_{t=n+1}^T p(x_t | x_{t-n:t-1})   (2.301)
This is called a Markov model of order n. If n = 1, this is called a bigram model, since we need to represent pairs of characters, p(x_t|x_{t−1}). If n = 2, this is called a trigram model, since we need to represent triples of characters, p(x_t|x_{t−1}, x_{t−2}). In general, this is called an (n + 1)-gram model.
Note, however, we can always convert a higher order Markov model to a first order one by defining an augmented state space that contains the past n observations. For example, if n = 2, we define x̃_t = (x_{t−1}, x_t) and use


p(\tilde{x}_{2:T}) = p(\tilde{x}_2) \prod_{t=3}^T p(\tilde{x}_t | \tilde{x}_{t-1}) = p(x_1, x_2) \prod_{t=3}^T p(x_t | x_{t-1}, x_{t-2})   (2.302)
= t=3 


Therefore we will just focus on first-order models throughout the rest of this section. 
christians first inhabit wherein thou hast forgive if a man childless and of laying of core these
are the heavens shall reel to and fro to seek god they set their horses and children of israel


Figure 2.16: Example output from a 10-gram character-level Markov model trained on the King James Bible. The prefix "christians" is given to the model. Generated by ngram_character_demo.ipynb.


[Figure 2.17 panels (a) and (b); axis labels: characters ordered etainoshrdlmucfwgypbvkxzjq.]


Figure 2.17: (a) Hinton diagram showing character bigram counts as estimated from H. G. Wells' book The Time Machine. Characters are sorted in decreasing unigram frequency; the first one is a space character. The most frequent bigram is 'e-', where - represents space. (b) Same as (a) but each row is normalized across the columns. Generated by bigram_hinton_diagram.ipynb.


2.6.2 Application: Language modeling 


One important application of Markov models is to create language models (LMs), which are models that can generate (or score) a sequence of words. When we use a finite-state Markov model with a memory of length m = n − 1, it is called an n-gram model. For example, if m = 0, we get a unigram model (no dependence on previous words); if m = 1, we get a bigram model (depends on previous word); if m = 2, we get a trigram model (depends on previous two words); etc. See Figure 2.16 for some generated text.

These days, most LMs are built using recurrent neural nets (see Section 16.3.4), which have 
unbounded memory. However, simple n-gram models can still do quite well when trained with enough 
data [Che17]. 

Language models have various applications, such as priors for spelling correction (see Section 29.3.3) 
or automatic speech recognition. In addition, conditional language models can be used to generate 
sequences given inputs, such as mapping one language to another, or an image to a sequence, etc. 


2.6.3 Parameter estimation 


In this section, we discuss how to estimate the parameters of a Markov model. 


2.6.3.1 Maximum likelihood estimation 


The probability of any particular sequence of length T is given by 


p(x_{1:T}|\theta) = \pi(x_1) A(x_1, x_2) \cdots A(x_{T-1}, x_T)   (2.303)

= \prod_{j=1}^K (\pi_j)^{\mathbb{I}(x_1 = j)} \prod_{t=2}^T \prod_{j=1}^K \prod_{k=1}^K (A_{jk})^{\mathbb{I}(x_t = k, x_{t-1} = j)}   (2.304)

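Hence the log-likelihood of a set of sequences D = (x_1, ..., x_N), where x_i = (x_{i1}, ..., x_{i,T_i}) is a sequence of length T_i, is given by


\log p(D|\theta) = \sum_{i=1}^N \log p(x_i|\theta) = \sum_j N_j^1 \log \pi_j + \sum_j \sum_k N_{jk} \log A_{jk}   (2.305)

where we define the following counts:
N_j^1 \triangleq \sum_{i=1}^N \mathbb{I}(x_{i1} = j), \quad N_{jk} \triangleq \sum_{i=1}^N \sum_{t=1}^{T_i - 1} \mathbb{I}(x_{i,t} = j, x_{i,t+1} = k), \quad N_j = \sum_k N_{jk}   (2.306)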


By adding Lagrange multipliers to enforce the sum-to-one constraints, one can show (see e.g., [Mur22, Sec 4.2.4]) that the MLE is given by the normalized counts:


\hat{\pi}_j = \frac{N_j^1}{\sum_{j'} N_{j'}^1}, \quad \hat{A}_{jk} = \frac{N_{jk}}{N_j}   (2.307)


We often replace N_j^1, which is how often symbol j is seen at the start of a sequence, by N_j, which is how often symbol j is seen anywhere in a sequence. This lets us estimate parameters from a single sequence.
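The count-and-normalize recipe of Equation (2.307) takes only a few lines. This sketch fits a character-level transition matrix from a single toy string (our own example). Here N_j is taken as the number of times j appears as the source of a transition, so each row of Â sums to one, and the initial-state estimate uses counts from the whole sequence, as described above.

```python
from collections import Counter


def fit_bigram(text):
    """MLE of a first-order Markov chain over characters, Eq. (2.307):
    A_hat[j, k] = N_jk / N_j, where N_j counts j as the source of a transition."""
    pair_counts = Counter(zip(text[:-1], text[1:]))  # N_jk
    source_counts = Counter(text[:-1])  # N_j = sum_k N_jk
    return {(j, k): c / source_counts[j] for (j, k), c in pair_counts.items()}


def fit_unigram(text):
    """Initial-state estimate using counts from the whole sequence in place of N_j^1."""
    counts = Counter(text)
    return {j: c / len(text) for j, c in counts.items()}


text = "the cat sat on the mat"  # toy training sequence
A_hat = fit_bigram(text)
pi_hat = fit_unigram(text)
print(A_hat[("t", "h")], A_hat[("a", "t")], pi_hat["t"])
```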

The counts N_j are known as unigram statistics, and N_jk are known as bigram statistics. For example, Figure 2.17 shows some 2-gram counts for the characters {a, ..., z, -} (where - represents space) as estimated from H. G. Wells' book The Time Machine.


2.6.3.2 Sparse data problem 


When we try to fit n-gram models for large n, we quickly encounter problems with overfitting due to data sparsity. To see that, note that many of the estimated counts N_jk will be 0, since now j indexes over discrete contexts of size K^{n−1}, which will become increasingly rare. Even for bigram models (n = 2), problems can arise if K is large. For example, if we have K ∼ 50,000 words in our vocabulary, then a bigram model will have about 2.5 billion free parameters, corresponding to all possible word pairs. It is very unlikely we will see all of these in our training data. However, we do not want to predict that a particular word string is totally impossible just because we happen not to have seen it in our training text; that would be a severe form of overfitting.⁸


8. A famous example of an improbable, but syntactically valid, English word string, due to Noam Chomsky [Cho57], is "colourless green ideas sleep furiously". We would not want our model to predict that this string is impossible. Even ungrammatical constructs should be allowed by our model with a certain probability, since people frequently violate grammatical rules, especially in spoken language.
Figure 2.18: Some Markov chains. (a) A 3-state aperiodic chain. (b) A reducible 4-state chain.


A "brute force" solution to this problem is to gather lots and lots of data. For example, Google has fit n-gram models (for n = 1 : 5) based on one trillion words extracted from the web. Their data, which is over 100GB when uncompressed, is publicly available.⁹ Although such an approach can be surprisingly successful (as discussed in [HNP09]), it is rather unsatisfying, since humans are able to learn language from much less data (see e.g., [TX00]).


2.6.3.3 MAP estimation 


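A simple solution to the sparse data problem is to use MAP estimation with a uniform Dirichlet prior, A_j: ∼ Dir(α1). In this case, the MAP estimate becomes


\hat{A}_{jk} = \frac{N_{jk} + \alpha}{N_j + K\alpha}   (2.308)
If α = 1, this is called add-one smoothing. (Strictly speaking, this expression is the posterior mean under a Dir(α1) prior; it coincides with the MAP estimate under a Dir((α + 1)1) prior.)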
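A sketch of add-α smoothing, Equation (2.308), on a toy string of our own. The vocabulary is assumed to be the set of distinct characters in the string (10 symbols here). Unseen transitions, such as 't' followed by 'c', now get a small nonzero probability.

```python
from collections import Counter


def fit_bigram_smoothed(text, vocab, alpha=1.0):
    """A_hat[j, k] = (N_jk + alpha) / (N_j + K alpha), Eq. (2.308), over a fixed vocabulary."""
    pair_counts = Counter(zip(text[:-1], text[1:]))
    source_counts = Counter(text[:-1])
    K = len(vocab)
    return {
        (j, k): (pair_counts[(j, k)] + alpha) / (source_counts[j] + K * alpha)
        for j in vocab
        for k in vocab
    }


text = "the cat sat on the mat"
vocab = sorted(set(text))  # assumed vocabulary: the distinct characters seen
A_smooth = fit_bigram_smoothed(text, vocab, alpha=1.0)
print(A_smooth[("t", "h")], A_smooth[("t", "c")])  # a seen and an unseen transition
```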

The main problem with add-one smoothing is that it assumes that all n-grams are equally likely, 
which is not very realistic. We discuss a more sophisticated approach, based on hierarchical Bayes, 
in Section 3.8.3. 


2.6.4 Stationary distribution of a Markov chain 


Suppose we continually draw consecutive samples from a Markov chain. In the case of a finite state 
space, we can think of this as “hopping” from one state to another. We will tend to spend more 
time in some states than others, depending on the transition graph. The long term distribution over 
states is known as the stationary distribution of the chain. In this section, we discuss some of the 
relevant theory. In Chapter 12, we discuss an important application, known as MCMC, which is a way 
to generate samples from hard-to-normalize probability distributions. In Supplementary Section 2.2 
we consider Google’s PageRank algorithm for ranking web pages, which also leverages the concept of 
stationary distributions. 


2.6.4.1 What is a stationary distribution? 


Let A_ij = p(X_t = j|X_{t−1} = i) be the one-step transition matrix, and let π_t(j) = p(X_t = j) be the probability of being in state j at time t.


9. See http://googleresearch.blogspot.com/2006/08/all-our-n-gram-are-belong-to-you.html for details.
If we have an initial distribution over states of π_0, then at time 1 we have


\pi_1(j) = \sum_i \pi_0(i) A_{ij}   (2.309)


or, in matrix notation, π_1 = π_0 A, where we have followed the standard convention of assuming π is a row vector, so we post-multiply by the transition matrix.

Now imagine iterating these equations. If we ever reach a stage where π = πA, then we say we have reached the stationary distribution (also called the invariant distribution or equilibrium distribution). Once we enter the stationary distribution, we will never leave.

For example, consider the chain in Figure 2.18(a). To find its stationary distribution, we write


(\pi_1\ \pi_2\ \pi_3) = (\pi_1\ \pi_2\ \pi_3) \begin{pmatrix} 1 - A_{12} - A_{13} & A_{12} & A_{13} \\ A_{21} & 1 - A_{21} - A_{23} & A_{23} \\ A_{31} & A_{32} & 1 - A_{31} - A_{32} \end{pmatrix}   (2.310)


Hence π_1(A_12 + A_13) = π_2 A_21 + π_3 A_31. In general, we have


\pi_i \sum_{j \ne i} A_{ij} = \sum_{j \ne i} \pi_j A_{ji}   (2.311)
In other words, the probability of being in state i times the net flow out of state i must equal the probability of being in each other state j times the net flow from that state into i. These are called the global balance equations. We can then solve these equations, subject to the constraint that ∑_j π_j = 1, to find the stationary distribution, as we discuss below.
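A sketch of the iteration π_t = π_{t−1}A in NumPy, using a hypothetical 3-state matrix with all entries positive (not the chain of Figure 2.18(a), whose entries are not given here). After enough steps the distribution stops changing, and the global balance equations (2.311) hold.

```python
import numpy as np

# A hypothetical 3-state transition matrix; each row sums to one.
A = np.array([[0.5, 0.3, 0.2],
              [0.2, 0.6, 0.2],
              [0.3, 0.3, 0.4]])

pi = np.array([1.0, 0.0, 0.0])  # start in the first state with certainty
for _ in range(200):
    pi = pi @ A  # pi_t = pi_{t-1} A (row-vector convention)

print(pi, np.allclose(pi, pi @ A))

# Global balance, Eq. (2.311): pi_i * sum_{j != i} A_ij == sum_{j != i} pi_j A_ji.
off_diag = A - np.diag(np.diag(A))
print(np.allclose(pi * off_diag.sum(axis=1), pi @ off_diag))
```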


2.6.4.2 Computing the stationary distribution 


To find the stationary distribution, we can just solve the eigenvector equation A'v = v, and then 
to set m = v', where v is an eigenvector with eigenvalue 1. (We can be sure such an eigenvector 
exists, since A is a row-stochastic matrix, so A1 = 1; also recall that the eigenvalues of A and AT 
are the same.) Of course, since eigenvectors are unique only up to constants of proportionality, we 
must normalize v at the end to ensure it sums to one. 

Note, however, that the eigenvectors are only guaranteed to be real-valued if all entries in the 
matrix are strictly positive, A;; > 0 (and hence A;; < 1, due to the sum-to-one constraint). A more 
general approach, which can handle chains where some transition probabilities are 0 or 1 (such as 
Figure 2.18(a)), is as follows. We have K constraints from m(I — A) = 0x x, and 1 constraint from 
w1K x1 = 1. Hence we have to solve mM = r, where M = [I — A, 1] is a K x (K +1) matrix, and 
r = (0,0,...,0,1] isa 1 x (K +1) vector. However, this is overconstrained, so we will drop the last 
column of I — A in our definition of M, and drop the last 0 from r. For example, for a 3 state chain 
we have to solve this linear system: 


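$$
(\pi_1\ \pi_2\ \pi_3)
\begin{pmatrix}
1 - A_{11} & -A_{12} & 1 \\
-A_{21} & 1 - A_{22} & 1 \\
-A_{31} & -A_{32} & 1
\end{pmatrix}
= (0\ 0\ 1) \qquad (2.312)
$$

For the chain in Figure 2.18(a) we find $\pi = [0.4, 0.4, 0.2]$. We can easily verify this is correct, since $\pi = \pi A$.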
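The following sketch implements the linear-system approach, assuming NumPy. The 3-state matrix is hypothetical: it is not claimed to be the matrix of Figure 2.18(a), but it also contains transition probabilities of 0 and 1 and has stationary distribution $[0.4, 0.4, 0.2]$. Overwriting the last column of $I - A$ with ones is equivalent to dropping that column and appending $\mathbf{1}$.

```python
import numpy as np

def stationary_via_linear_system(A):
    """Solve pi M = r, with M = [I - A without its last column, 1] and r = (0, ..., 0, 1)."""
    K = A.shape[0]
    M = np.eye(K) - A
    M[:, -1] = 1.0                   # swap the dropped column of I - A for the sum-to-one column
    r = np.zeros(K)
    r[-1] = 1.0
    return np.linalg.solve(M.T, r)   # pi M = r  is equivalent to  M^T pi^T = r^T

# Hypothetical 3-state chain with some transition probabilities equal to 0 or 1.
A = np.array([[0.0, 1.0, 0.0],
              [0.5, 0.0, 0.5],
              [1.0, 0.0, 0.0]])
pi = stationary_via_linear_system(A)
print(pi)                        # approximately [0.4 0.4 0.2]
print(np.allclose(pi, pi @ A))   # True
```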
Unfortunately, not all chains have a stationary distribution, as we explain below. 
2.6.4.3 When does a stationary distribution exist?


Consider the 4-state chain in Figure 2.18(b). If we start in state 4, we will stay there forever, since 4 is an absorbing state. Thus $\pi = (0, 0, 0, 1)$ is one possible stationary distribution. However, if we start in 1 or 2, we will oscillate between those two states forever. So $\pi = (0.5, 0.5, 0, 0)$ is another possible stationary distribution. If we start in state 3, we could end up in either of the above stationary distributions with equal probability. The corresponding transition graph has two disjoint connected components.

We see from this example that a necessary condition to have a unique stationary distribution is that the state transition diagram form a single connected component, i.e., we can get from any state to any other state. Such chains are called irreducible.

Now consider the 2-state chain in Figure 2.15(a). This is irreducible provided $\alpha, \beta > 0$. Suppose $\alpha = \beta = 0.9$. It is clear by symmetry that this chain will spend 50% of its time in each state. Thus $\pi = (0.5, 0.5)$. But now suppose $\alpha = \beta = 1$. In this case, the chain will oscillate between the two states, but the long-term distribution on states depends on where you start from. If we start in state 1, then on every odd time step (1, 3, 5, ...) we will be in state 1; but if we start in state 2, then on every odd time step we will be in state 2.

This example motivates the following definition. Let us say that a chain has a limiting distribution if $\pi_j = \lim_{n \to \infty} A^n_{ij}$ exists and is independent of the starting state $i$, for all $j$. If this holds, then the long-run distribution over states will be independent of the starting state:

$$
p(X_t = j) = \sum_i p(X_0 = i) A_{ij}(t) \to \pi_j \text{ as } t \to \infty \qquad (2.313)
$$
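A quick numerical illustration of this definition, assuming NumPy and two hypothetical chains: row $i$ of $A^n$ is the distribution of the state after $n$ steps when starting from state $i$, so a limiting distribution exists exactly when all rows of $A^n$ converge to the same vector.

```python
import numpy as np

# Hypothetical irreducible, aperiodic 3-state chain.
A_aperiodic = np.array([[0.0, 1.0, 0.0],
                        [0.5, 0.0, 0.5],
                        [1.0, 0.0, 0.0]])
# The 2-state chain with alpha = beta = 1: it flips state at every step.
A_flip = np.array([[0.0, 1.0],
                   [1.0, 0.0]])

def n_step(A, n):
    """Return A^n; row i is the distribution over states after n steps, starting from i."""
    return np.linalg.matrix_power(A, n)

print(n_step(A_aperiodic, 200))   # every row is approximately [0.4 0.4 0.2]
print(n_step(A_flip, 200))        # identity: you end where you started
print(n_step(A_flip, 201))        # swapped: the answer depends on the parity of n
```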


Let us now characterize when a limiting distribution exists. Define the period of state $i$ to be $d(i) = \gcd\{t : A_{ii}(t) > 0\}$, where gcd stands for greatest common divisor, i.e., the largest integer that divides all the members of the set. For example, in Figure 2.18(a), we have $d(1) = d(2) = \gcd(2, 3, 4, 6, \ldots) = 1$ and $d(3) = \gcd(3, 5, 6, \ldots) = 1$. We say a state $i$ is aperiodic if $d(i) = 1$. (A sufficient condition to ensure this is if state $i$ has a self-loop, but this is not a necessary condition.) We say a chain is aperiodic if all its states are aperiodic. One can show the following important result:
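The period can be estimated by tracking which states are reachable in exactly $t$ steps. This sketch (NumPy; `period` and `max_t` are hypothetical names) only looks at return times up to `max_t`, which is an assumption: it is adequate for small chains like the examples here, but the true period is a gcd over all $t$.

```python
import math
import numpy as np

def period(A, i, max_t=50):
    """Approximate d(i) = gcd{t : A^t[i, i] > 0}, using only return times t <= max_t."""
    step = (A > 0).astype(int)               # which one-step transitions are possible
    reach = np.eye(A.shape[0], dtype=int)    # states reachable in exactly 0 steps
    g = 0
    for t in range(1, max_t + 1):
        reach = (reach @ step > 0).astype(int)   # states reachable in exactly t steps
        if reach[i, i]:
            g = math.gcd(g, t)                   # math.gcd(0, t) == t
    return g

A_cycle = np.array([[0.0, 1.0, 0.0],
                    [0.5, 0.0, 0.5],
                    [1.0, 0.0, 0.0]])   # hypothetical chain: state 0 recurs at t = 2, 3, ...
A_flip = np.array([[0.0, 1.0],
                   [1.0, 0.0]])          # returns only at even t
print([period(A_cycle, i) for i in range(3)])  # [1, 1, 1]
print(period(A_flip, 0))                        # 2
```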


Theorem 2.6.1. Every irreducible (singly connected), aperiodic finite state Markov chain has a limiting distribution, which is equal to $\pi$, its unique stationary distribution.


A special case of this result says that every regular finite state chain has a unique stationary distribution, where a regular chain is one whose transition matrix satisfies $A^n_{ij} > 0$ for some integer $n$ and all $i, j$, i.e., it is possible to get from any state to any other state in $n$ steps. Consequently, after $n$ steps, the chain could be in any state, no matter where it started. One can show that sufficient conditions to ensure regularity are that the chain be irreducible (singly connected) and that every state have a self-transition.

To handle the case of Markov chains whose state space is not finite (e.g., the countable set of all integers, or the uncountable set of all reals), we need to generalize some of the earlier definitions. Since the details are rather technical, we just briefly state the main results without proof. See e.g., [GS92] for details.

For a stationary distribution to exist, we require irreducibility (singly connected) and aperiodicity, as before. But we also require that each state is recurrent, which means that you will return to that state with probability 1. As a simple example of a non-recurrent state (i.e., a transient state), consider Figure 2.18(b): state 3 is transient because one immediately leaves it and either spins around state 4 forever, or oscillates between states 1 and 2 forever. There is no way to return to state 3.

It is clear that any finite-state irreducible chain is recurrent, since you can always get back to where you started from. But now consider an example with an infinite state space. Suppose we perform a random walk on the integers, $\mathcal{X} = \{\ldots, -2, -1, 0, 1, 2, \ldots\}$. Let $A_{i,i+1} = p$ be the probability of moving right, and $A_{i,i-1} = 1 - p$ be the probability of moving left. Suppose we start at $X_1 = 0$. If $p > 0.5$, we will shoot off to $+\infty$; we are not guaranteed to return. Similarly, if $p < 0.5$, we will shoot off to $-\infty$. So in both cases, the chain is not recurrent, even though it is irreducible. If $p = 0.5$, we can return to the initial state with probability 1, so the chain is recurrent. However, the distribution keeps spreading out over a larger and larger set of the integers, so the expected time to return is infinite. This prevents the chain from having a stationary distribution.
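A small Monte Carlo sketch of the random walk (assuming NumPy; the sample sizes and the finite horizon `n_steps` are arbitrary choices). A finite simulation cannot prove recurrence, but it shows the contrast: with $p = 0.5$ almost every walk revisits 0, while with $p = 0.7$ only about 60% do, matching the known return probability $1 - |2p - 1|$ of a biased simple random walk.

```python
import numpy as np

def return_fraction(p, n_walks=2000, n_steps=1000, seed=0):
    """Fraction of simulated walks started at 0 that revisit 0 within n_steps steps."""
    rng = np.random.default_rng(seed)
    steps = np.where(rng.random((n_walks, n_steps)) < p, 1, -1)  # +1 w.p. p, else -1
    positions = np.cumsum(steps, axis=1)                         # positions after each step
    return float(np.mean(np.any(positions == 0, axis=1)))

for p in [0.5, 0.7]:
    print(p, return_fraction(p))
# For p = 0.5 the true return probability is 1, but the finite horizon cuts off a few
# very long excursions; for p = 0.7 it is 1 - |2p - 1| = 0.6.
```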

More formally, we define a state to be non-null recurrent if the expected time to return to this state is finite. We say that a state is ergodic if it is aperiodic, recurrent and non-null. We say that a chain is ergodic if all its states are ergodic. With these definitions, we can now state our main theorem:


Theorem 2.6.2. Every irreducible, ergodic Markov chain has a limiting distribution, which is equal to $\pi$, its unique stationary distribution.


This generalizes Theorem 2.6.1, since for irreducible finite-state chains, all states are recurrent and 
non-null. 


2.6.4.4 Detailed balance 


Establishing ergodicity can be difficult. We now give an alternative condition that is easier to verify. 
We say that a Markov chain $A$ is time reversible if there exists a distribution $\pi$ such that

$$
\pi_i A_{ij} = \pi_j A_{ji} \qquad (2.314)
$$

These are called the detailed balance equations. This says that the flow from $i$ to $j$ must equal the flow from $j$ to $i$, weighted by the appropriate source probabilities.
We have the following important result. 


Theorem 2.6.3. If a Markov chain with transition matrix $A$ is regular and satisfies the detailed balance equations wrt distribution $\pi$, then $\pi$ is a stationary distribution of the chain.

Proof. To see this, note that

$$
\sum_i \pi_i A_{ij} = \sum_i \pi_j A_{ji} = \pi_j \sum_i A_{ji} = \pi_j \qquad (2.315)
$$

and hence $\pi = \pi A$.


Note that this condition is sufficient but not necessary (see Figure 2.18(a) for an example of a 
chain with a stationary distribution which does not satisfy detailed balance). 
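To make the distinction concrete, the following sketch (NumPy; the matrices are hypothetical) checks stationarity and detailed balance separately. The 3-state chain has a stationary distribution that violates detailed balance, whereas a 2-state chain with positive off-diagonal entries always satisfies detailed balance with respect to its stationary distribution, since for two states the global and detailed balance equations coincide.

```python
import numpy as np

def satisfies_detailed_balance(A, pi):
    """Check pi_i A_ij == pi_j A_ji for all i, j (Equation 2.314)."""
    flows = pi[:, None] * A            # flows[i, j] = pi_i * A_ij
    return np.allclose(flows, flows.T)

def is_stationary(A, pi):
    return np.allclose(pi @ A, pi)

# Hypothetical 3-state chain: pi is stationary, but detailed balance fails.
A3 = np.array([[0.0, 1.0, 0.0],
               [0.5, 0.0, 0.5],
               [1.0, 0.0, 0.0]])
pi3 = np.array([0.4, 0.4, 0.2])
print(is_stationary(A3, pi3), satisfies_detailed_balance(A3, pi3))  # True False

# Hypothetical 2-state chain with its stationary distribution [beta, alpha] / (alpha + beta).
alpha, beta = 0.1, 0.3
A2 = np.array([[1 - alpha, alpha], [beta, 1 - beta]])
pi2 = np.array([beta, alpha]) / (alpha + beta)
print(is_stationary(A2, pi2), satisfies_detailed_balance(A2, pi2))  # True True
```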
[Figure 2.19: panels (a) and (b), each titled “Samples from P and Q”.]


Figure 2.19: Samples from two distributions which are (a) different and (b) similar. From a figure from 
[GSJ19]. Used with kind permission of Arthur Gretton. 


2.7 Divergence measures between probability distributions 


In this section, we discuss various ways to compare two probability distributions, $P$ and $Q$, defined on the same space. For example, suppose the distributions are defined in terms of samples, $\mathcal{X} = \{x_1, \ldots, x_N\} \sim P$ and $\mathcal{X}' = \{x'_1, \ldots, x'_M\} \sim Q$. Determining if the samples come from the same distribution is known as a two-sample test (see Figure 2.19 for an illustration). This can be computed by defining some suitable divergence metric $D(P, Q)$ and comparing it to a threshold. (We use the term “divergence” rather than distance since we will not require $D$ to be symmetric.) Alternatively, suppose $P$ is an empirical distribution of data, and $Q$ is the distribution induced by a model. We can check how well the model approximates the data by comparing $D(P, Q)$ to a threshold; this is called a goodness-of-fit test.

There are two main ways to compute the divergence between a pair of distributions: in terms of 
their difference, P — Q (see e.g., [Sug+13]) or in terms of their ratio, P/Q (see e.g., [SSK12]). We 
briefly discuss both of these below. (Our presentation is based, in part, on [GSJ19].) 


2.7.1 f-divergence 


In this section, we compare distributions in terms of their density ratio $r(x) = p(x)/q(x)$. In particular, consider the f-divergence [Mor63; AS66; Csi67; LV06; CS04], which is defined as follows:

$$
D_f(p \| q) = \int q(x) f\left(\frac{p(x)}{q(x)}\right) dx \qquad (2.316)
$$

where $f : \mathbb{R}_+ \to \mathbb{R}$ is a convex function satisfying $f(1) = 0$. From Jensen’s inequality (Section 5.1.2.2), it follows that $D_f(p \| q) \geq 0$, and obviously $D_f(p \| p) = 0$, so $D_f$ is a valid divergence. Below we discuss some important special cases of f-divergences. (Note that f-divergences are also called $\phi$-divergences.)
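For discrete distributions the integral in Equation (2.316) becomes a sum. A minimal sketch, assuming NumPy and strictly positive $q$ (the function names are illustrative), with the KL case as a check:

```python
import numpy as np

def f_divergence(p, q, f):
    """D_f(p || q) = sum_x q(x) f(p(x) / q(x)), for discrete p and q with q(x) > 0."""
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    return float(np.sum(q * f(p / q)))

def f_kl(r):
    """f(r) = r log r, using the convention 0 log 0 = 0."""
    safe_r = np.where(r > 0, r, 1.0)
    return np.where(r > 0, r * np.log(safe_r), 0.0)

p = np.array([0.5, 0.3, 0.2])
q = np.array([0.2, 0.5, 0.3])
print(f_divergence(p, q, f_kl), np.sum(p * np.log(p / q)))  # the same value
print(f_divergence(p, p, f_kl))                              # 0.0, since f(1) = 0
```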
2.7. DIVERGENCE MEASURES BETWEEN PROBABILITY DISTRIBUTIONS

[Figure 2.20 panels: (a) $\alpha = -\infty$, (b) $\alpha = 0$, (c) $\alpha = 0.5$, (d) $\alpha = 1$, (e) $\alpha = \infty$.]

Figure 2.20: The Gaussian $q$ which minimizes $\alpha$-divergence to $p$ (a mixture of two Gaussians), for varying $\alpha$. From Figure 1 of [Min05]. Used with kind permission of Tom Minka.


2.7.1.1 KL divergence 


Suppose we compute the f-divergence using $f(r) = r \log(r)$. In this case, we get a quantity called the Kullback Leibler divergence, defined as follows:

$$
D_{\mathbb{KL}}(p \| q) = \int p(x) \log \frac{p(x)}{q(x)} dx \qquad (2.317)
$$


See Section 5.1 for more details. 


2.7.1.2 Alpha divergence 


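If $f(r) = \frac{4}{1 - \alpha^2}\left(1 - r^{(1+\alpha)/2}\right)$, the f-divergence becomes the alpha divergence [Ama09], which is as follows:

$$
D^A_\alpha(p \| q) \triangleq \frac{4}{1 - \alpha^2} \left(1 - \int p(x)^{(1+\alpha)/2} q(x)^{(1-\alpha)/2} dx\right) \qquad (2.318)
$$

where we assume $\alpha \neq \pm 1$. Another common parameterization, and the one used by Minka in [Min05], is as follows:

$$
D^M_\alpha(p \| q) \triangleq \frac{1}{\alpha(1 - \alpha)} \left(1 - \int p(x)^\alpha q(x)^{1-\alpha} dx\right) \qquad (2.319)
$$

This can be converted to Amari’s notation using $D^A_{\alpha'} = D^M_\alpha$ where $\alpha' = 2\alpha - 1$. (We will use the Minka convention.)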

We see from Figure 2.20 that as $\alpha \to -\infty$, $q$ prefers to match one mode of $p$, whereas when $\alpha \to \infty$, $q$ prefers to cover all of $p$. More precisely, one can show that as $\alpha \to 0$, the alpha-divergence tends towards $D_{\mathbb{KL}}(q \| p)$, and as $\alpha \to 1$, the alpha-divergence tends towards $D_{\mathbb{KL}}(p \| q)$. Also, when $\alpha = 0.5$, the alpha-divergence equals (up to a constant factor) the squared Hellinger distance (Section 2.7.1.3).
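The limits above can be checked numerically for discrete distributions. This sketch (NumPy; `p` and `q` are arbitrary examples) evaluates Minka's form (2.319) near $\alpha = 1$ and $\alpha = 0$, and at $\alpha = 0.5$, where it equals four times the squared Hellinger distance as defined in Equation (2.320):

```python
import numpy as np

def alpha_div_minka(p, q, a):
    """Minka's alpha-divergence, Equation (2.319), for discrete p, q > 0 and a not in {0, 1}."""
    return (1.0 - np.sum(p**a * q**(1.0 - a))) / (a * (1.0 - a))

def kl(p, q):
    return np.sum(p * np.log(p / q))

p = np.array([0.5, 0.3, 0.2])
q = np.array([0.2, 0.5, 0.3])
eps = 1e-5
print(alpha_div_minka(p, q, 1 - eps), kl(p, q))   # close: a -> 1 gives KL(p || q)
print(alpha_div_minka(p, q, eps), kl(q, p))       # close: a -> 0 gives KL(q || p)
print(alpha_div_minka(p, q, 0.5), 4 * (1 - np.sum(np.sqrt(p * q))))  # equal
```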
2.7.1.3 Hellinger distance 


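The (squared) Hellinger distance is defined as follows:

$$
D_H^2(p \| q) \triangleq \frac{1}{2} \int \left(\sqrt{p(x)} - \sqrt{q(x)}\right)^2 dx = 1 - \int \sqrt{p(x) q(x)} dx \qquad (2.320)
$$

The Hellinger distance $D_H$ itself (the square root of this quantity) is a valid distance metric, since it is symmetric, non-negative and satisfies the triangle inequality.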
[Figure 2.21: panels (a) and (b), each titled “Smooth function”.]


Figure 2.21: A smooth witness function for comparing two distributions which are (a) different and (b) 
similar. From a figure from [GSJ19]. Used with kind permission of Arthur Gretton. 


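We see that this is equal (up to constant factors) to the f-divergence with $f(r) = (\sqrt{r} - 1)^2$, since

$$
\int dx\, q(x) \left(\frac{\sqrt{p(x)}}{\sqrt{q(x)}} - 1\right)^2 = \int dx\, q(x) \left(\frac{\sqrt{p(x)} - \sqrt{q(x)}}{\sqrt{q(x)}}\right)^2 = \int dx \left(\sqrt{p(x)} - \sqrt{q(x)}\right)^2 \qquad (2.321)
$$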


2.7.1.4 Chi-squared distance 


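The chi-squared distance $\chi^2$ is defined by

$$
\chi^2(p, q) \triangleq \frac{1}{2} \int \frac{(q(x) - p(x))^2}{q(x)} dx \qquad (2.322)
$$

This is equal (up to constant factors) to an f-divergence where $f(r) = (r - 1)^2$, since

$$
\int dx\, q(x) \left(\frac{p(x)}{q(x)} - 1\right)^2 = \int dx\, q(x) \left(\frac{p(x) - q(x)}{q(x)}\right)^2 = \int dx \frac{(p(x) - q(x))^2}{q(x)} \qquad (2.323)
$$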
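Both identities, (2.321) and (2.323), are easy to confirm numerically for discrete distributions. A short sketch assuming NumPy (`p` and `q` are arbitrary examples); in each case the constant factor is $1/2$:

```python
import numpy as np

def f_divergence(p, q, f):
    """Discrete f-divergence: sum_x q(x) f(p(x) / q(x)), assuming q(x) > 0."""
    return float(np.sum(q * f(p / q)))

p = np.array([0.5, 0.3, 0.2])
q = np.array([0.2, 0.5, 0.3])

# Squared Hellinger distance (2.320) in both of its forms, and as half an f-divergence.
hel2 = 0.5 * np.sum((np.sqrt(p) - np.sqrt(q)) ** 2)
hel2_alt = 1 - np.sum(np.sqrt(p * q))
hel2_f = f_divergence(p, q, lambda r: (np.sqrt(r) - 1) ** 2) / 2

# Chi-squared distance (2.322), and as half the f-divergence with f(r) = (r - 1)^2.
chi2 = 0.5 * np.sum((q - p) ** 2 / q)
chi2_f = f_divergence(p, q, lambda r: (r - 1) ** 2) / 2

print(hel2, hel2_alt, hel2_f)   # all equal
print(chi2, chi2_f)             # equal
```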


2.7.2 Integral probability metrics 




In this section, we compute the divergence between two distributions in terms of $P - Q$ using an integral probability metric or IPM [Sri+09]. This is defined as follows:

$$
D_{\mathcal{F}}(P, Q) \triangleq \sup_{f \in \mathcal{F}} \left| \mathbb{E}_{p(x)}[f(x)] - \mathbb{E}_{q(x')}[f(x')] \right| \qquad (2.324)
$$

where $\mathcal{F}$ is some class of “smooth” functions. The function $f$ that maximizes the difference between these two expectations is called the witness function. See Figure 2.21 for an illustration.

There are several ways to define the function class F. One approach is to use an RKHS, defined in 
terms of a positive definite kernel function; this gives rise to the method known as maximum mean 
discrepancy or MMD. See Section 2.7.3 for details. 
Another approach is to define $\mathcal{F}$ to be the set of functions that have bounded Lipschitz constant, i.e., $\mathcal{F} = \{\|f\|_L \leq 1\}$, where

$$
\|f\|_L = \sup_{x \neq x'} \frac{|f(x) - f(x')|}{\|x - x'\|} \qquad (2.325)
$$

The IPM in this case is equal to the Wasserstein-1 distance

$$
W_1(P, Q) \triangleq \sup_{\|f\|_L \leq 1} \left| \mathbb{E}_{p(x)}[f(x)] - \mathbb{E}_{q(x')}[f(x')] \right| \qquad (2.326)
$$

See Section 6.10.2.4 for details.
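As a hedged illustration in one dimension, the sketch below computes $W_1$ between two empirical distributions with the same number of samples. It uses the primal (optimal transport) characterization, in which sorted samples are matched, rather than the supremum in Equation (2.326); the two agree by Kantorovich-Rubinstein duality. NumPy and equal sample sizes are assumptions of the sketch.

```python
import numpy as np

def wasserstein1_1d(x, y):
    """W1 between two equal-size 1d empirical distributions.

    In 1d the optimal coupling matches sorted samples, so W1 is the mean absolute
    difference between the order statistics.
    """
    x, y = np.sort(np.asarray(x, dtype=float)), np.sort(np.asarray(y, dtype=float))
    assert x.shape == y.shape, "this sketch assumes N == M"
    return float(np.mean(np.abs(x - y)))

rng = np.random.default_rng(0)
x = rng.normal(size=1000)
print(wasserstein1_1d(x, x + 2.0))   # 2.0 up to rounding: every sample moves by 2
```

For a pure shift, the 1-Lipschitz witness $f(x) = x$ attains the supremum in (2.326), giving the same value of 2.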


2.7.3 Maximum mean discrepancy (MMD) 


In this section, we describe the maximum mean discrepancy or MMD method of [Gre+12], 
which defines a discrepancy measure D(P, Q) using samples from the two distributions. The samples 
are compared using positive definite kernels (Section 18.2), which can handle high-dimensional 
inputs. This approach can be used to define two-sample tests, and to train implicit generative models 
(Section 26.2.4). 


2.7.3.1 MMD as an IPM 


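The MMD is an integral probability metric (Section 2.7.2) of the form

$$
\mathrm{MMD}(P, Q; \mathcal{F}) = \sup_{f \in \mathcal{F} : \|f\| \leq 1} \left[ \mathbb{E}_{p(x)}[f(x)] - \mathbb{E}_{q(x')}[f(x')] \right] \qquad (2.327)
$$

where $\mathcal{F}$ is an RKHS (Section 18.3.7.1) defined by a positive definite kernel function $K$. We can represent functions in this set as an infinite sum of basis functions

$$
f(x) = \langle f, \phi(x) \rangle_{\mathcal{F}} = \sum_{l=1}^{\infty} f_l \phi_l(x) \qquad (2.328)
$$

We restrict the set of witness functions $f$ to be those that are in the unit ball of this RKHS, so $\|f\|_{\mathcal{F}}^2 = \sum_{l=1}^{\infty} f_l^2 \leq 1$.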


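By the linearity of expectation, we have

$$
\mathbb{E}_{p(x)}[f(x)] = \langle f, \mathbb{E}_{p(x)}[\phi(x)] \rangle_{\mathcal{F}} = \langle f, \mu_P \rangle_{\mathcal{F}} \qquad (2.329)
$$

where $\mu_P$ is called the kernel mean embedding of distribution $P$ [Mua+17]. Hence

$$
\mathrm{MMD}(P, Q; \mathcal{F}) = \sup_{\|f\| \leq 1} \langle f, \mu_P - \mu_Q \rangle_{\mathcal{F}} = \left\langle \frac{\mu_P - \mu_Q}{\|\mu_P - \mu_Q\|}, \mu_P - \mu_Q \right\rangle_{\mathcal{F}} = \|\mu_P - \mu_Q\|_{\mathcal{F}} \qquad (2.330)
$$

since the unit vector $f$ that maximizes the inner product is parallel to the difference in feature means.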

To get some intuition, suppose $\phi(x) = [x, x^2]$. In this case, the MMD computes the difference in the first two moments of the two distributions. This may not be enough to distinguish all possible distributions. However, using a Gaussian kernel is equivalent to comparing two infinitely large feature vectors, as we show in Section 18.2.6, and hence we are effectively comparing all the moments of the two distributions. Indeed, one can show that MMD = 0 iff $P = Q$, provided we use a non-degenerate kernel.
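The following sketch (NumPy; the two discrete distributions are hypothetical) illustrates the point about moments. Both distributions have mean 0 and second moment 1, so the explicit feature map $\phi(x) = [x, x^2]$ gives an MMD of (numerically) 0, while a Gaussian kernel with $\sigma = 1$ gives a strictly positive value:

```python
import numpy as np

# Two hypothetical discrete distributions with equal mean (0) and second moment (1).
xs, ps = np.array([-1.0, 1.0]), np.array([0.5, 0.5])
ys, qs = np.array([-np.sqrt(2), 0.0, np.sqrt(2)]), np.array([0.25, 0.5, 0.25])

def phi(x):
    """Explicit finite feature map phi(x) = [x, x^2]."""
    return np.stack([x, x**2], axis=1)

def mmd_explicit(xs, ps, ys, qs):
    """||mu_P - mu_Q|| with the explicit feature map, as in Equation (2.330)."""
    mu_p = (ps[:, None] * phi(xs)).sum(axis=0)
    mu_q = (qs[:, None] * phi(ys)).sum(axis=0)
    return float(np.linalg.norm(mu_p - mu_q))

def gauss_k(a, b, sigma=1.0):
    return np.exp(-(a[:, None] - b[None, :]) ** 2 / (2 * sigma**2))

def mmd2_gauss(xs, ps, ys, qs):
    """Squared MMD with a Gaussian kernel, via weighted kernel sums."""
    return float(ps @ gauss_k(xs, xs) @ ps + qs @ gauss_k(ys, ys) @ qs
                 - 2 * ps @ gauss_k(xs, ys) @ qs)

print(mmd_explicit(xs, ps, ys, qs))   # ~0: the first two moments agree
print(mmd2_gauss(xs, ps, ys, qs))     # > 0: the Gaussian kernel separates P and Q
```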
2.7.3.2 Computing the MMD using the kernel trick

In this section, we describe how to compute Equation (2.330) in practice, given two sets of samples, $\mathcal{X} = \{x_n\}_{n=1}^N$ and $\mathcal{X}' = \{x'_m\}_{m=1}^M$, where $x_n \sim P$ and $x'_m \sim Q$. Let $\mu_P = \frac{1}{N} \sum_{n=1}^N \phi(x_n)$ and $\mu_Q = \frac{1}{M} \sum_{m=1}^M \phi(x'_m)$ be empirical estimates of the kernel mean embeddings of the two distributions. Then the squared MMD is given by

$$
\mathrm{MMD}^2(\mathcal{X}, \mathcal{X}') \triangleq \left\| \frac{1}{N} \sum_{n=1}^N \phi(x_n) - \frac{1}{M} \sum_{m=1}^M \phi(x'_m) \right\|^2 \qquad (2.331)
$$

$$
= \frac{1}{N^2} \sum_{n=1}^N \sum_{n'=1}^N \phi(x_n)^\top \phi(x_{n'}) - \frac{2}{NM} \sum_{n=1}^N \sum_{m=1}^M \phi(x_n)^\top \phi(x'_m) + \frac{1}{M^2} \sum_{m=1}^M \sum_{m'=1}^M \phi(x'_m)^\top \phi(x'_{m'}) \qquad (2.332)
$$

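Since Equation (2.332) only involves inner products of the feature vectors, we can use the kernel trick (Section 18.2.5) to rewrite the above as follows:

$$
\mathrm{MMD}^2(\mathcal{X}, \mathcal{X}') = \frac{1}{N^2} \sum_{n=1}^N \sum_{n'=1}^N K(x_n, x_{n'}) - \frac{2}{NM} \sum_{n=1}^N \sum_{m=1}^M K(x_n, x'_m) + \frac{1}{M^2} \sum_{m=1}^M \sum_{m'=1}^M K(x'_m, x'_{m'}) \qquad (2.333)
$$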
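A direct NumPy transcription of Equation (2.333) might look as follows; the Gaussian kernel, its bandwidth, and the synthetic data are assumptions for illustration. This is the biased estimator (it keeps the $n = n'$ and $m = m'$ terms), which is exactly what (2.333) computes. With a linear kernel, the same function returns the squared distance between the sample means, which is a useful sanity check.

```python
import numpy as np

def gauss_kernel(X, Y, sigma=1.0):
    """K[n, m] = exp(-||x_n - y_m||^2 / (2 sigma^2)) for rows of X (N x D) and Y (M x D)."""
    sq = np.sum(X**2, 1)[:, None] + np.sum(Y**2, 1)[None, :] - 2 * X @ Y.T
    return np.exp(-np.maximum(sq, 0.0) / (2 * sigma**2))

def linear_kernel(X, Y):
    return X @ Y.T

def mmd2(X, Y, kernel=gauss_kernel):
    """Biased (V-statistic) estimate of the squared MMD, Equation (2.333)."""
    N, M = len(X), len(Y)
    return float(kernel(X, X).sum() / N**2
                 - 2 * kernel(X, Y).sum() / (N * M)
                 + kernel(Y, Y).sum() / M**2)

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
Y_same = rng.normal(size=(300, 2))
Y_shift = rng.normal(size=(300, 2)) + np.array([1.0, 0.0])
print(mmd2(X, Y_same), mmd2(X, Y_shift))   # the second value is much larger
print(mmd2(X, Y_shift, linear_kernel),
      np.sum((X.mean(0) - Y_shift.mean(0)) ** 2))   # equal
```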


2.7.3.3 Linear time computation 


The MMD takes $O(N^2)$ time to compute, where $N$ is the number of samples from each distribution. In [Chw+15], they present a different test statistic called the unnormalized mean embedding or UME, that can be computed in $O(N)$ time.

The key idea is to notice that evaluating

$$
\mathrm{witness}^2(v) = \left(\mu_Q(v) - \mu_P(v)\right)^2 \qquad (2.334)
$$

at a set of test locations $v_1, \ldots, v_J$ is enough to detect a difference between $P$ and $Q$. Hence we define the (squared) UME as follows:

$$
\mathrm{UME}^2(P, Q) = \frac{1}{J} \sum_{j=1}^J \left[\mu_P(v_j) - \mu_Q(v_j)\right]^2 \qquad (2.335)
$$

where $\mu_P(v) = \mathbb{E}_{p(x)}[K(x, v)]$ can be estimated empirically in $O(N)$ time, and similarly for $\mu_Q(v)$.
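A sketch of the empirical UME of Equation (2.335), assuming NumPy, a Gaussian kernel with $\sigma = 1$, and hand-picked test locations `V` (in [Jit+16] the locations are optimized instead). Each $\mu_P(v_j)$ is an average over the $N$ samples, so the cost is linear in the number of samples for fixed $J$:

```python
import numpy as np

def gauss_kernel(X, V, sigma=1.0):
    """K[n, j] = exp(-||x_n - v_j||^2 / (2 sigma^2))."""
    sq = np.sum(X**2, 1)[:, None] + np.sum(V**2, 1)[None, :] - 2 * X @ V.T
    return np.exp(-np.maximum(sq, 0.0) / (2 * sigma**2))

def ume2(X, Y, V, sigma=1.0):
    """Empirical squared UME, Equation (2.335), at the J test locations in the rows of V."""
    mu_p = gauss_kernel(X, V, sigma).mean(axis=0)   # estimates of mu_P(v_j); cost O(N J)
    mu_q = gauss_kernel(Y, V, sigma).mean(axis=0)   # estimates of mu_Q(v_j); cost O(M J)
    return float(np.mean((mu_p - mu_q) ** 2))

rng = np.random.default_rng(0)
X = rng.normal(size=(2000, 2))
Y_same = rng.normal(size=(2000, 2))
Y_shift = rng.normal(size=(2000, 2)) + np.array([1.0, 0.0])
V = np.array([[0.0, 0.0], [1.0, 0.0], [0.5, 0.5]])   # hand-picked test locations
print(ume2(X, Y_same, V), ume2(X, Y_shift, V))       # the second is much larger
```

The location $[0.5, 0.5]$ is equidistant from the two means, so it contributes almost nothing to the statistic; this is one reason why choosing the locations well matters.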

A normalized version of UME, known as NME, is presented in [Jit+16]. By maximizing NME wrt the locations $v_j$, we can maximize the statistical power of the test, and find locations where $P$ and $Q$ differ the most. This provides an interpretable two-sample test for high dimensional data.


2.7.3.4 Choosing the right kernel 


The effectiveness of MMD (and UME) depends crucially on the right choice of kernel. Even for distinguishing 1d samples, the choice of kernel can be very important. For example, consider a Gaussian kernel, $K_\sigma(x, x') = \exp\left(-\frac{1}{2\sigma^2} \|x - x'\|^2\right)$. The effect of changing $\sigma$ in terms of the ability to distinguish two different sets of 1d samples is shown in Figure 2.22. Fortunately, the MMD is differentiable wrt the kernel parameters, so we can choose the optimal $\sigma^2$ so as to maximize the power of the test [Sut+17]. (See also [Fla+16] for a Bayesian approach, which maximizes the marginal likelihood of a GP representation of the kernel mean embedding.)

[Figure 2.22: three panels, (a), (b) and (c), with horizontal axis $x$.]

Figure 2.22: Effect of bandwidth parameter $\sigma$ on the witness function defined by a Gaussian kernel. From a figure from [GSJ19]. Used with kind permission of Dougal Sutherland.

For high-dimensional data such as images, it can be useful to use a pre-trained CNN model as a way to compute low-dimensional features. For example, we can define $K(x, x') = K_\sigma(h(x), h(x'))$, where $h$ is some hidden layer of a CNN, such as the “Inception” model of [Sze+15]. The resulting MMD metric is known as the kernel inception distance [Biń+18]. This is similar to the Fréchet inception distance [Heu+17a], but has nicer statistical properties, and is better correlated with human perceptual judgement [Zho+19a].


2.7.4 Total variation distance 


The total variation distance between two probability distributions is defined as follows:

$$
D_{TV}(p, q) \triangleq \frac{1}{2} \|p - q\|_1 = \frac{1}{2} \int |p(x) - q(x)| dx \qquad (2.336)
$$

This is equal to an f-divergence where $f(r) = |r - 1|/2$, since

$$
\frac{1}{2} \int q(x) \left| \frac{p(x)}{q(x)} - 1 \right| dx = \frac{1}{2} \int q(x) \left| \frac{p(x) - q(x)}{q(x)} \right| dx = \frac{1}{2} \int |p(x) - q(x)| dx \qquad (2.337)
$$


One can also show that the TV distance is an integral probability metric. In fact, it is the only divergence that is both an IPM and an f-divergence [Sri+09]. See Figure 2.23 for a visual summary.
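A small numerical check of both views of the TV distance for discrete distributions (NumPy; `p` and `q` are arbitrary examples). For the IPM view we take the function class $\{f : |f(x)| \leq 1\}$ with a factor of $1/2$ to match Equation (2.336); for discrete distributions the supremum is attained by the witness $f = \mathrm{sign}(p - q)$.

```python
import numpy as np

p = np.array([0.5, 0.3, 0.2])
q = np.array([0.2, 0.5, 0.3])

# f-divergence view, Equations (2.336)-(2.337): f(r) = |r - 1| / 2.
tv_f = float(np.sum(q * np.abs(p / q - 1) / 2))

# IPM view: half the sup over |f| <= 1 of E_p[f] - E_q[f], attained at f = sign(p - q).
witness = np.sign(p - q)
tv_ipm = float(0.5 * (witness @ p - witness @ q))

tv_direct = float(0.5 * np.sum(np.abs(p - q)))
print(tv_f, tv_ipm, tv_direct)   # all three agree (0.3 up to rounding)
```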


2.7.5 Density ratio estimation using binary classifiers 


In this section, we discuss a simple approach for comparing two distributions that turns out to be 
equivalent to IPMs and f-divergences. 
[Figure 2.23: diagram contrasting integral probability metrics, $D_{\mathcal{H}}(P, Q) = \sup_{g \in \mathcal{H}} |\mathbb{E}_{X \sim P} g(X) - \mathbb{E}_{Y \sim Q} g(Y)|$ (e.g., Wasserstein), with f-divergences, $D_f(P; Q)$ (e.g., KL, Hellinger, Pearson $\chi^2$).]


Figure 2.23: Summary of the two main kinds of divergence measures between two probability distributions P 
and Q. From a figure from [GSJ19]. Used with kind permission of Arthur Gretton. 


Consider a binary classification problem in which points from $P$ have label $y = 1$ and points from $Q$ have label $y = 0$, i.e., $P(x) = p(x|y = 1)$ and $Q(x) = p(x|y = 0)$. Let $p(y = 1) = \pi$ be the class prior. By Bayes’ rule, the density ratio $r(x) = P(x)/Q(x)$ is given by

$$
r(x) = \frac{P(x)}{Q(x)} = \frac{p(x|y=1)}{p(x|y=0)} = \frac{p(y=1|x) p(x)}{p(y=1)} \Big/ \frac{p(y=0|x) p(x)}{p(y=0)} \qquad (2.338)
$$

$$
= \frac{p(y=1|x)}{p(y=0|x)} \frac{1 - \pi}{\pi} \qquad (2.339)
$$

If we assume $\pi = 0.5$, then we can estimate the ratio $r(x)$ by fitting a binary classifier or discriminator $h(x) = p(y = 1|x)$ and then computing $r = h/(1 - h)$. This is called the density ratio estimation or DRE trick.
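A hedged end-to-end sketch of the DRE trick with $\pi = 0.5$, assuming NumPy, a logistic-regression discriminator fit by Newton's method on the log-loss, and two hypothetical 1d Gaussians $P = \mathcal{N}(1, 1)$ and $Q = \mathcal{N}(0, 1)$, for which the true ratio is $r(x) = \exp(x - 1/2)$:

```python
import numpy as np

rng = np.random.default_rng(0)
N = 20000
x_p = rng.normal(1.0, 1.0, N)     # samples from P = N(1, 1), label y = 1
x_q = rng.normal(0.0, 1.0, N)     # samples from Q = N(0, 1), label y = 0 (so pi = 0.5)
x = np.concatenate([x_p, x_q])
y = np.concatenate([np.ones(N), np.zeros(N)])

# Logistic-regression discriminator h(x) = sigmoid(w0 + w1 x), fit by Newton's method.
Phi = np.column_stack([np.ones_like(x), x])
w = np.zeros(2)
for _ in range(20):
    h = 1.0 / (1.0 + np.exp(-Phi @ w))
    grad = Phi.T @ (h - y)                           # gradient of the log-loss
    hess = Phi.T @ (Phi * (h * (1 - h))[:, None])    # Hessian of the log-loss
    w -= np.linalg.solve(hess, grad)

def ratio_hat(x):
    """DRE trick: r(x) = h(x) / (1 - h(x))."""
    h = 1.0 / (1.0 + np.exp(-(w[0] + w[1] * x)))
    return h / (1 - h)

print(w)                              # close to [-0.5, 1.0], since log r(x) = x - 0.5
print(ratio_hat(0.5), np.exp(0.0))    # estimated vs true ratio at x = 0.5
```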

We can optimize the classifier $h$ by minimizing the risk (expected loss). For example, if we use log-loss, we have

$$
R(h) = \mathbb{E}_{p(x|y)p(y)}\left[-y \log h(x) - (1 - y) \log(1 - h(x))\right] \qquad (2.340)
$$

$$
= \pi \mathbb{E}_{P(x)}[-\log h(x)] + (1 - \pi) \mathbb{E}_{Q(x)}[-\log(1 - h(x))] \qquad (2.341)
$$

We can also use other loss functions $\ell(y, h(x))$ (see Section 26.2.2).

Let $R^*_\ell = \inf_{h \in \mathcal{F}} R(h)$ be the minimum risk achievable for loss function $\ell$, where we minimize over some function class $\mathcal{F}$.¹⁰ In [NWJ09], they show that for every f-divergence, there is a loss function $\ell$ such that $-D_f(P, Q) = R^*_\ell$. For example (using the notation $\tilde{y} \in \{-1, 1\}$ instead of $y \in \{0, 1\}$), total-variation distance corresponds to hinge loss, $\ell(\tilde{y}, h) = \max(0, 1 - \tilde{y} h)$; Hellinger distance corresponds to exponential loss, $\ell(\tilde{y}, h) = \exp(-\tilde{y} h)$; and $\chi^2$ divergence corresponds to logistic loss, $\ell(\tilde{y}, h) = \log(1 + \exp(-\tilde{y} h))$.

10. If $P$ is a fixed distribution, and we minimize the above objective wrt $h$, while also maximizing it wrt a model $Q(x)$, we recover a technique known as a generative adversarial network for fitting an implicit model to a distribution of samples $P$ (see Chapter 26 for details). However, in this section, we assume $Q$ is known.

We can also establish a connection between binary classifiers and IPMs [Sri+09]. In particular, let $\ell(\tilde{y}, h) = -2\tilde{y}h$, and $p(\tilde{y} = 1) = p(\tilde{y} = -1) = 0.5$. Then we have


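$$
R^*_\ell = \inf_h \int \ell(\tilde{y}, h(x)) p(x|\tilde{y}) p(\tilde{y}) \, dx \, d\tilde{y} \qquad (2.342)
$$

$$
= \inf_h 0.5 \int \ell(1, h(x)) p(x|\tilde{y} = 1) dx + 0.5 \int \ell(-1, h(x)) p(x|\tilde{y} = -1) dx \qquad (2.343)
$$

$$
= \inf_h \int h(x) Q(x) dx - \int h(x) P(x) dx \qquad (2.344)
$$

$$
= -\sup_h \left[ \int h(x) P(x) dx - \int h(x) Q(x) dx \right] \qquad (2.345)
$$

which matches Equation (2.324) up to sign (assuming the class of functions $h$ is closed under negation, so the absolute value can be dropped). Thus the classifier plays the same role as the witness function.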
3 Bayesian statistics


3.1 Introduction 


Probability theory (which we discussed in Chapter 2) is concerned with the forwards mapping from parameters $\theta$ to data $x$. (In the conditional setting, the forwards mapping is from $(\theta, x)$ to $y$, but we mostly focus on the unconditional setting for notational simplicity.) Statistics is concerned with the inverse problem, in which we want to infer the unknown parameters $\theta$ given observations $x$. (If we have multiple independent samples drawn from $p(x|\theta)$, we denote the dataset by $\mathcal{D} = \{x_n : n = 1 : N\}$.) Indeed, statistics was originally called inverse probability theory. Nowadays, there are two main approaches to statistics, frequentist statistics and Bayesian statistics, as we discuss below.


3.1.1 Frequentist statistics 


The frequentist approach to statistics (also called classical statistics or orthodox statistics) is based on the concept of repeated trials. More precisely, suppose we have an estimator, such as the maximum likelihood estimator, $\hat{\theta}(\mathcal{D}) = \arg\max_\theta p(\mathcal{D}|\theta)$. Now imagine what would happen if the estimator were applied to multiple random datasets of the same size, drawn from the same but unknown “true distribution”, $p^*(\mathcal{D})$. The resulting distribution of estimated values, $\{\hat{\theta}(\mathcal{D}') : \mathcal{D}' \sim p^*\}$, is called the sampling distribution of the estimator. The variability of this distribution can be regarded as a measure of uncertainty about the value that was estimated from the dataset that was actually observed, $\hat{\theta} = \hat{\theta}(\mathcal{D}_{\text{obs}})$. (For more details, see e.g., Section 4.7 of the prequel to this book, [Mur22].)
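The sampling distribution can be simulated if we pretend to know $p^*$. A minimal sketch, assuming NumPy and a hypothetical Bernoulli(0.7) data generator, applies the MLE (the fraction of heads) to many simulated datasets and compares the spread of the resulting estimates with the analytic standard deviation $\sqrt{\theta(1 - \theta)/N}$:

```python
import numpy as np

rng = np.random.default_rng(0)
theta_true, N, n_datasets = 0.7, 20, 10_000   # hypothetical "true distribution" p* = Ber(0.7)

# Draw many datasets D' ~ p* of size N and apply the MLE (fraction of heads) to each one.
datasets = rng.random((n_datasets, N)) < theta_true
theta_hats = datasets.mean(axis=1)

# Spread of the sampling distribution vs the analytic value sqrt(theta (1 - theta) / N).
print(theta_hats.std(), np.sqrt(theta_true * (1 - theta_true) / N))
```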

In the machine learning literature, it is common to describe any method that computes a point 
estimate of the parameters (e.g., the MLE or the MAP estimate) as a “frequentist” method, but this 
is incorrect. It is the use of a sampling distribution to represent uncertainty that makes a method 
frequentist. If we ignore uncertainty, we are just performing an optimization problem, which is 
neither frequentist nor Bayesian. 


3.1.2 Bayesian statistics 


In the Bayesian approach to statistics, we treat the unknown parameters $\theta$ just like any other unknown random variable. Hence we can model it with a probability distribution. The observed data is considered fixed, and we do not need to worry about hypothetical repeated random trials with different data. We represent knowledge about the possible values of $\theta$ given the data using a (conditional) probability distribution, $p(\theta|\mathcal{D})$.
To compute this distribution, we start by specifying our initial knowledge or beliefs using a prior distribution, $p(\theta)$. Our assumptions about how the data depends on the parameters are captured in the likelihood function $p(\mathcal{D}|\theta)$. We can then combine these using Bayes’ rule to compute the posterior distribution, which represents our knowledge about the parameters after seeing the data:

$$
p(\theta|\mathcal{D}) = \frac{p(\theta) p(\mathcal{D}|\theta)}{p(\mathcal{D})} = \frac{p(\theta) p(\mathcal{D}|\theta)}{\int p(\theta') p(\mathcal{D}|\theta') d\theta'} \qquad (3.1)
$$

where the quantity $p(\mathcal{D})$ is called the marginal likelihood or evidence. The task of computing this posterior is called Bayesian inference, posterior inference or just inference. We discuss different algorithms for solving this problem in Part II. For more details, see e.g., [Gel+14a; MKL21; MFR20].
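For a one-dimensional parameter, Equation (3.1) can be evaluated on a grid, with the integral in the denominator replaced by a Riemann sum. A sketch assuming NumPy, a uniform prior, and coin-tossing data with 4 heads and 1 tail (the example used later in this chapter); with a uniform prior the exact posterior is $\mathrm{Beta}(5, 2)$, the evidence of this particular sequence is $1/30$, and the posterior mean is $5/7$:

```python
import numpy as np

N1, N0 = 4, 1                                # 4 heads and 1 tail
n_grid = 1000
theta = (np.arange(n_grid) + 0.5) / n_grid   # midpoints of a uniform grid on (0, 1)
dtheta = 1.0 / n_grid

prior = np.ones_like(theta)                  # p(theta) = 1 on [0, 1]
lik = theta**N1 * (1 - theta) ** N0          # p(D | theta) for this sequence
evidence = np.sum(prior * lik) * dtheta      # p(D) = integral of p(theta) p(D | theta)
post = prior * lik / evidence                # p(theta | D), Equation (3.1)

print(evidence)                              # close to 1/30
print(np.sum(theta * post) * dtheta)         # posterior mean, close to 5/7
```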


3.1.3 Arguments for the Bayesian approach 


The Bayesian approach is more general than the frequentist approach, since it can be applied to problems that don’t repeat (e.g., we can ask what is the probability that the ice caps will melt by 2030, which is not a meaningful question in the frequentist world view). In addition, it avoids certain conceptual pathologies that plague frequentist statistics (see discussions in e.g., [Efr86; Jay03; Cla21]). Furthermore, the Bayesian approach is widely used in engineering and business, and there is also considerable evidence that it is used by humans and other animals [MKG21].¹ There are also some more technical reasons for using the Bayesian approach, which we discuss below.


3.1.3.1 De Finetti’s theorem 


The Bayesian approach is based on the idea of putting a prior on our parameters, and then using this 
extended model to represent our data. In this section, we justify why this is a reasonable thing to do. 

First, a definition. We say that a sequence of random variables $(x_1, x_2, \ldots)$ is infinitely exchangeable if, for any $n$, the joint probability $p(x_1, \ldots, x_n)$ is invariant to permutation of the indices. That is, for any permutation $\pi$, we have

$$
p(x_1, \ldots, x_n) = p(x_{\pi_1}, \ldots, x_{\pi_n}) \qquad (3.2)
$$

Note that iid implies exchangeability but not vice versa. For example, suppose $(x_1, \ldots, x_n)$ is a set of MNIST images, and let $x_0$ be a background image. The sequence $(x_0 + x_1, \ldots, x_0 + x_n)$ is infinitely exchangeable but not iid, since all the variables share a hidden common factor, namely the background $x_0$. Thus the more examples we see, the better we will be able to estimate the shared $x_0$, and thus the better we can predict future elements.

Thus the key advantage of working with the exchangeability assumption is that it allows us to 
learn from similarities between data points. In fact, one can show the following result: 


Theorem 3.1.1 (de Finetti’s theorem). A sequence of random variables $(x_1, x_2, \ldots)$ is infinitely exchangeable iff, for all $n$, we have

$$
p(x_1, \ldots, x_n) = \int \prod_{i=1}^n p(x_i|\theta) p(\theta) d\theta \qquad (3.3)
$$

where $\theta$ is some hidden common random variable (possibly infinite dimensional). That is, $x_i$ are iid conditional on $\theta$.

1. We will see that Bayesian inference can be computationally expensive. Brains have evolved to use various shortcuts to make the Bayesian computations efficient, see e.g., [Lak+17; Gri20].


We often interpret $\theta$ as a parameter. The theorem tells us that, if our data is exchangeable, then there must exist a parameter $\theta$, and a likelihood $p(x_i|\theta)$, and a prior $p(\theta)$. Thus the Bayesian approach follows automatically from exchangeability [O’N09].
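The easy direction of the theorem can be checked directly. In this pure-Python sketch (the $\mathrm{Beta}(1, 1)$ prior is an arbitrary choice), coin tosses are generated by first drawing $\theta$ from a beta prior and then tossing iid given $\theta$, as in Equation (3.3). The resulting joint probability depends only on the number of heads, so it is exchangeable, yet the tosses are not marginally independent:

```python
import math

def log_beta_fn(a, b):
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)

def p_sequence(xs, a=1.0, b=1.0):
    """p(x_1, ..., x_n) = integral of prod_i Ber(x_i | theta) Beta(theta | a, b) dtheta."""
    n1 = sum(xs)
    n0 = len(xs) - n1
    return math.exp(log_beta_fn(a + n1, b + n0) - log_beta_fn(a, b))

print(p_sequence([1, 0, 0]), p_sequence([0, 0, 1]))   # equal: only the counts matter
print(p_sequence([1, 1]), p_sequence([1]) ** 2)       # about 1/3 vs 1/4: not independent
```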


3.1.3.2 The Dutch book theorem 


Bayesian inference forms the foundation of Bayesian decision theory, which we discuss in Section 34.1. It can be shown that any other method for representing uncertainty is guaranteed to cause a decision maker to lose money if deployed in the context of a gambling game. This is called the Dutch book theorem. For details, see e.g., [Háj08]. Of course, this argument extends beyond gambling and applies to any form of decision making under uncertainty.


3.1.3.3 Online learning 


The posterior captures our (un)certainty about the parameter given the data and whatever prior knowledge we had. Once we have computed it, we can “throw away” the data. This makes Bayesian inference particularly well-suited to online learning, where the dataset grows without bound, since we can recursively update our belief state:

$$
p(\theta|\mathcal{D}_{1:t}) \propto p(\mathcal{D}_t|\theta) p(\theta|\mathcal{D}_{1:t-1}) \qquad (3.4)
$$


Being able to rapidly update our beliefs about the world (e.g., model parameters) in response to 
possibly small amounts of new data is essential for creating intelligent agents that can adapt to 
changing distributions (see Chapter 19) and perform sequential decision making under uncertainty 
(see Part VI). 
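For the beta-Bernoulli model, Equation (3.4) reduces to incrementing pseudo counts. A pure-Python sketch with a hypothetical data stream shows that processing one observation at a time gives the same posterior as processing the whole batch:

```python
def update(a, b, x):
    """One step of Equation (3.4) for a Beta(a, b) belief state and a Bernoulli observation x."""
    return a + x, b + (1 - x)

data = [1, 1, 0, 1, 1]       # hypothetical stream of coin tosses
a, b = 1.0, 1.0              # Beta(1, 1) prior
for x in data:               # each observation is used once and can then be discarded
    a, b = update(a, b, x)

# Posterior computed from all the data at once.
a_batch, b_batch = 1.0 + sum(data), 1.0 + len(data) - sum(data)
print((a, b), (a_batch, b_batch))   # (5.0, 2.0) (5.0, 2.0)
```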


3.1.4 Arguments against the Bayesian approach 


There are several arguments against the Bayesian approach. Some are philosophical, and focus on 
the sensitivity to the choice of prior (which we discuss in Section 3.5). However, in practice the 
main obstacle is computational. To see why, note that evaluating the posterior in Equation (3.1) 
can be computationally expensive, because of the need to compute the normalizing constant p(D) in 
the denominator, which often involves a high dimensional integral. In Part II, we will discuss many 
different algorithms for exact and approximate Bayesian computation. 


3.1.5 Why not just use MAP estimation? 


Bayesian inference uses a prior, which can help prevent overfitting. A simpler approach to this problem is to just compute the MAP estimate or maximum a posteriori estimate, since this avoids the need to compute $p(\mathcal{D})$:

$$
\hat{\theta} = \arg\max_\theta \log p(\theta|\mathcal{D}) = \arg\max_\theta \left[\log p(\mathcal{D}|\theta) + \log p(\theta)\right] \qquad (3.5)
$$


Although this is often considered a Bayesian approach, since it is derived from a prior and a likelihood, 
it is not “fully Bayesian”, since it is just a point estimate. We discuss the downsides of MAP estimation 
below. 


3.1.5.1 The MAP estimate gives no measure of uncertainty 


In many statistical applications (especially in science) it is important to know how much one can trust a given parameter estimate. For example, consider tossing a coin. If we get $N_1$ heads and $N_0$ tails, we know that the maximum likelihood estimate is

$$
\hat{\theta}_{\text{mle}} = \frac{N_1}{N_1 + N_0} = \frac{N_1}{N} \qquad (3.6)
$$

where $N = N_1 + N_0$. But this is just a point estimate, with no associated uncertainty. For example, if we toss a coin 5 times, and see 4 heads, we estimate $\hat{\theta}_{\text{mle}} = 4/5$, but do we really believe the coin is that biased? We cannot be sure, because the sample size is so small.

Now suppose we put a prior on $\theta$. As we explain in Section 3.2.1, it is convenient to use a beta distribution as a prior, $p(\theta) = \mathrm{Beta}(\theta|\breve{\alpha}, \breve{\beta})$, where $\breve{\alpha}$ and $\breve{\beta}$ are known as pseudo counts. We will show that the posterior has the form $p(\theta|\mathcal{D}) = \mathrm{Beta}(\theta|\hat{\alpha}, \hat{\beta})$, where $\hat{\alpha} = \breve{\alpha} + N_1$ and $\hat{\beta} = \breve{\beta} + N_0$. The mode of this distribution (i.e., the MAP estimate) is

$$
\hat{\theta}_{\text{map}} = \frac{\hat{\alpha} - 1}{\hat{\alpha} - 1 + \hat{\beta} - 1} \qquad (3.7)
$$

Even though we used a prior, this is still just another point estimate, with no notion of uncertainty. However, we can derive measures of confidence or uncertainty from the posterior distribution. For example, we can compute a 95% credible interval $I = (\ell, u)$, which satisfies $p(\theta \in I|\mathcal{D}) = 0.95$, by setting $\ell = F^{-1}(\alpha/2)$ and $u = F^{-1}(1 - \alpha/2)$, where $\alpha = 0.05$, $F$ is the cdf of the posterior, and $F^{-1}$ is the inverse cdf. Alternatively, we can compute the posterior standard deviation (sometimes called the standard error), given by $\sigma = \sqrt{\mathbb{V}[\theta|\mathcal{D}]}$. See [Mur22, Sec 4.6.2.8] for details.
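The quantities in this subsection can be computed for the 4-heads-in-5-tosses example. This sketch assumes NumPy and a hypothetical $\mathrm{Beta}(2, 2)$ prior, and it inverts the posterior cdf numerically on a grid (a library inverse-cdf routine would be the more usual choice):

```python
import math
import numpy as np

N1, N0 = 4, 1                     # 4 heads in 5 tosses
a0, b0 = 2.0, 2.0                 # hypothetical Beta(2, 2) prior (pseudo counts)
a, b = a0 + N1, b0 + N0           # posterior is Beta(a, b) = Beta(6, 3)

theta_mle = N1 / (N1 + N0)                                  # Equation (3.6): 0.8
theta_map = (a - 1) / (a - 1 + b - 1)                       # Equation (3.7): 5/7
post_std = math.sqrt(a * b / ((a + b) ** 2 * (a + b + 1)))  # standard deviation of a beta

# 95% credible interval: invert the posterior cdf numerically on a fine grid.
grid = np.linspace(0.0, 1.0, 100_001)
pdf = grid ** (a - 1) * (1 - grid) ** (b - 1)   # unnormalized posterior density
cdf = np.cumsum(pdf)
cdf /= cdf[-1]
lo, hi = np.interp([0.025, 0.975], cdf, grid)
print(theta_mle, theta_map, post_std, (lo, hi))
```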


3.1.5.2 The plugin approximation does not capture predictive uncertainty 


In machine learning applications, the parameters of our model are usually of little interest, since 
they are usually unidentifiable and hence uninterpretable. Instead, we are interested in predictive 
uncertainty. This can be useful in applications which involve decision making, such as reinforcement 
learning, active learning, or safety-critical applications. We can derive uncertainty in the predictions 
induced by uncertainty in the parameters by computing the posterior predictive distribution: 


$$
p(y|x, \mathcal{D}) = \int p(y|x, \theta) p(\theta|\mathcal{D}) d\theta \qquad (3.8)
$$


[Figure 3.1: panels (a) and (b), each showing training data and predictions.]

Figure 3.1: Predictions made by a polynomial regression model fit to a small dataset. (a) Plugin approximation to predictive density using the MLE. The curves show the posterior mean, $\mathbb{E}[y|x]$, and the error bars show the posterior standard deviation, $\mathrm{std}[y|x]$, around this mean. (b) Bayesian posterior predictive density, obtained by integrating out the parameters. Generated by linreg_post_pred_plot.ipynb.

By integrating out, or marginalizing out, the unknown parameters, we reduce the chance of overfitting, since we are effectively computing the weighted average of predictions from an infinite number of models. This act of integrating over uncertainty is at the heart of the Bayesian approach to machine learning. (Of course, the Bayesian approach requires a prior, but so too do methods that rely on regularization, so the prior is not so much the distinguishing aspect.)

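Suppose we approximate the posterior by a point estimate. We can represent this using a delta function, $p(\boldsymbol{\theta}|\mathcal{D}) \approx \delta(\boldsymbol{\theta} - \hat{\boldsymbol{\theta}})$, where the value $\hat{\boldsymbol{\theta}}$ could be the MLE or a MAP estimate. If we substitute this into Equation (3.8), and use the sifting property of delta functions, we get the following approximation to the posterior predictive:

$$p(y|\mathbf{x}, \mathcal{D}) \approx \int p(y|\mathbf{x}, \boldsymbol{\theta})\, \delta(\boldsymbol{\theta} - \hat{\boldsymbol{\theta}})\, d\boldsymbol{\theta} = p(y|\mathbf{x}, \hat{\boldsymbol{\theta}}) \qquad (3.9)$$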


This is called a plugin approximation, and is very widely used, due to its simplicity. 

However, the plugin approximation ignores uncertainty in the parameter estimates, which can result in an underestimate of the predictive uncertainty. For example, Figure 3.1a plots the plugin approximation $p(y|\mathbf{x}, \hat{\boldsymbol{\theta}})$ for a linear regression model $p(y|\mathbf{x}, \boldsymbol{\theta}) = \mathcal{N}(y|\mathbf{w}^\top \mathbf{x}, \sigma^2)$, where we plug in the MLEs for $\mathbf{w}$ and $\sigma^2$ (the plot looks similar if we plug in the MAP estimates). We see that the size of the predicted variance is a constant (namely $\hat{\sigma}^2$).

The uncertainty captured by $\sigma$ is called aleatoric uncertainty or intrinsic uncertainty, and would persist even if we knew the true model and true parameters. In practice we don’t know the parameters. This induces an additional, and orthogonal, source of uncertainty, called epistemic uncertainty (since it arises due to a lack of knowledge about the truth). In the Bayesian approach, we take this into account. The result is shown in Figure 3.1b. We see that now the error bars get wider as we move away from the training data. This is because the Bayesian estimate of the parameters is adaptive to the test data. In the Bayesian approach, we predict using $p(y|\mathbf{x}, \mathcal{D}) = \mathcal{N}(y|\widehat{\mathbf{w}}_{\text{bayes}}^\top \mathbf{x}, \widehat{\sigma}^2_{\text{bayes}}(\mathbf{x}))$. In the plugin approach, by contrast, we predict using $p(y|\mathbf{x}, \mathcal{D}) \approx p(y|\mathbf{x}, \hat{\boldsymbol{\theta}}) = \mathcal{N}(y|\hat{\mathbf{w}}_{\text{mle}}^\top \mathbf{x}, \hat{\sigma}^2_{\text{mle}})$.
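To make the contrast concrete, here is a NumPy sketch of a conjugate Bayesian linear regression. It is not the code behind Figure 3.1. The assumptions are:

- the noise variance is known and fixed at 1;
- the prior is $\mathbf{w} \sim \mathcal{N}(\mathbf{0}, 10\mathbf{I})$;
- the features are quadratic and the data are synthetic;
- the helper names `poly_features` and `bayes_linreg_predictive` are hypothetical.

Under these assumptions the plugin variance is just the constant noise variance. The Bayesian variance adds the epistemic term $\mathbf{x}^\top \text{Cov}[\mathbf{w}|\mathcal{D}]\, \mathbf{x}$.

```python
import numpy as np


def poly_features(x, degree=2):
    # Design matrix with columns [1, x, x^2, ...] for a 1d input.
    return np.vander(x, degree + 1, increasing=True)


def bayes_linreg_predictive(X_train, y_train, X_test, noise_var, prior_var):
    """Posterior predictive mean and variance for y ~ N(w^T x, noise_var),
    with prior w ~ N(0, prior_var * I) and noise_var treated as known."""
    d = X_train.shape[1]
    post_prec = np.eye(d) / prior_var + X_train.T @ X_train / noise_var
    post_cov = np.linalg.inv(post_prec)
    post_mean = post_cov @ X_train.T @ y_train / noise_var
    pred_mean = X_test @ post_mean
    # Aleatoric part (noise_var) plus epistemic part x^T Cov[w|D] x.
    pred_var = noise_var + np.einsum("ij,jk,ik->i", X_test, post_cov, X_test)
    return pred_mean, pred_var


noise_var, prior_var = 1.0, 10.0
rng = np.random.default_rng(0)
x_train = rng.uniform(-2.0, 2.0, size=10)
y_train = 1.0 + 0.5 * x_train - 0.3 * x_train**2 + rng.normal(0.0, 1.0, size=10)
x_test = np.array([0.0, 2.0, 5.0])  # near, at the edge of, and far from the training inputs

_, bayes_var = bayes_linreg_predictive(
    poly_features(x_train), y_train, poly_features(x_test), noise_var, prior_var)
plugin_var = np.full_like(bayes_var, noise_var)  # plugin variance does not depend on x
print("Bayesian predictive variance:", bayes_var)
print("plugin predictive variance:  ", plugin_var)
```

The test input at $x = 5$ lies far outside the training range, so its epistemic term dominates.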

For more details on Bayesian linear regression, see Section 15.2. For details on how to derive 
Bayesian predictions for nonlinear models such as neural nets, see Section 17.1. 
[Figure 3.2 plots: (a) the mixture $0.5\,\mathcal{N}(x|0, 2) + 0.5\,\mathcal{N}(x|2, 0.05)$ and its mean; (b) the Gamma(1,1) distribution (a=1.0, b=1.0)]

Figure 3.2: Two distributions in which the mode (highest point) is untypical of the distribution; the mean (vertical red line) is a better summary. (a) A bimodal distribution. Generated by bimodal_dist_plot.ipynb. (b) A skewed Ga(1,1) distribution. Generated by gamma_dist_plot.ipynb.


[Figure 3.3 plot: the states of $\{0,1\}^3$, from (0,0,0) to (1,1,1), drawn as the vertices of a cube]

Figure 3.3: A distribution on a discrete space in which the mode (black point L, with probability $p_1$) is untypical of most of the probability mass (gray circles, with probability $p_0 < p_1$). The small black circle labeled M (near the top left) is the posterior mean, which is not well defined in a discrete state space. C (the top left vertex) is the centroid estimator, made up of the maximizer of the posterior marginals. See text for details. From Figure 1 of [CL07]. Used with kind permission of Luis Carvalho.


3.1.5.3 The MAP estimate is often untypical of the posterior 


The MAP estimate is often easy to compute. However, the mode of a posterior distribution is often a 
very poor choice as a summary statistic, since the mode is usually quite untypical of the distribution, 
unlike the mean or median. This is illustrated in Figure 3.2(a) for a 1D continuous space, where we 
see that the mode is an isolated peak (black line), far from most of the probability mass. By contrast, 
the mean (red line) is near the middle of the distribution. 

Another example is shown in Figure 3.2(b): here the mode is 0, but the mean is non-zero. Such 
skewed distributions often arise when inferring variance parameters, especially in hierarchical models. 
In such cases the MAP estimate (and hence the MLE) is obviously a very bad estimate. 
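Both continuous examples can be checked numerically with the short NumPy sketch below. It assumes the second argument of $\mathcal{N}$ in Figure 3.2(a) is a variance, and that Ga(1,1) has shape $a = 1$ and rate $b = 1$. The mixture mode is found on a grid. The means use the closed forms for a mixture and for a gamma distribution.

```python
import numpy as np


def normal_pdf(x, mean, var):
    return np.exp(-0.5 * (x - mean) ** 2 / var) / np.sqrt(2.0 * np.pi * var)


# Figure 3.2(a): 0.5 N(x|0, 2) + 0.5 N(x|2, 0.05), second argument taken as a variance.
grid = np.linspace(-10.0, 10.0, 200_001)
density = 0.5 * normal_pdf(grid, 0.0, 2.0) + 0.5 * normal_pdf(grid, 2.0, 0.05)
mixture_mode = grid[np.argmax(density)]      # the narrow spike near x = 2
mixture_mean = 0.5 * 0.0 + 0.5 * 2.0         # mean of a mixture = weighted component means

# Figure 3.2(b): Gamma(a, b) with shape a and rate b has mode (a - 1) / b for a >= 1, mean a / b.
a, b = 1.0, 1.0
gamma_mode, gamma_mean = max(a - 1.0, 0.0) / b, a / b
print(mixture_mode, mixture_mean, gamma_mode, gamma_mean)
```

The mode sits on the isolated narrow peak near 2, while the mean (1.0) lies between the two components. For Ga(1,1) the mode is 0 but the mean is 1.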

Similar problems with MAP estimates can arise in discrete spaces, such as when estimating graph 
structures, or long sequences of symbols. Figure 3.3 shows a distribution on $\{0,1\}^3$, where points are arranged such that they are connected to their nearest neighbors, as measured by Hamming distance. The black state (circle) labeled L (configuration (1,1,1)) has probability $p_1$; the 4 gray states have probability $p_0 < p_1$; and the 3 white states have probability 0. Although the black state is the most probable, it is untypical of the posterior: all its nearest neighbors have probability zero, meaning it is very isolated. By contrast, the gray states, although slightly less probable, are all connected to other gray states, and together they constitute much more of the total probability mass.


3.1.5.4 The MAP estimate is only optimal for 0-1 loss 


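The MAP estimate is the optimal estimate when the loss function is 0-1 loss, $\ell(\boldsymbol{\theta}, \hat{\boldsymbol{\theta}}) = \mathbb{I}(\boldsymbol{\theta} \neq \hat{\boldsymbol{\theta}})$, as we show in Section 34.1.2. However, this does not give any “partial credit” for estimating some of the components of $\boldsymbol{\theta}$ correctly. An alternative is to use the Hamming loss: $\ell(\boldsymbol{\theta}, \hat{\boldsymbol{\theta}}) = \sum_{d=1}^{D} \mathbb{I}(\theta_d \neq \hat{\theta}_d)$. In this case, one can show that the optimal estimator is the vector of max marginals

$$\hat{\boldsymbol{\theta}} = \left[\arg\max_{\theta_d} \int_{\boldsymbol{\theta}_{-d}} p(\boldsymbol{\theta}|\mathcal{D})\, d\boldsymbol{\theta}_{-d}\right]_{d=1}^{D} \qquad (3.10)$$

This is also called the maximizer of posterior marginals or MPM estimate. Note that computing the max marginals involves marginalization and maximization, and thus depends on the whole distribution; this tends to be more robust than the MAP estimate [MMP87].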
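The cube example of Figure 3.3 can be reproduced as a small sketch that compares the joint mode (MAP) with the max-marginal (MPM) estimate. The probabilities $p_1 = 0.24$ and $p_0 = 0.19$ are hypothetical values chosen to satisfy $p_0 < p_1$ and $p_1 + 4p_0 = 1$. The white states are taken to be the three neighbors of L, as the text describes. The two estimators disagree here because $2p_0 > p_1$; with other values they can coincide.

```python
import itertools

import numpy as np

# Hypothetical probabilities consistent with the text: p0 < p1 and p1 + 4 * p0 = 1.
p1, p0 = 0.24, 0.19
states = list(itertools.product([0, 1], repeat=3))
zero_states = {(0, 1, 1), (1, 0, 1), (1, 1, 0)}  # white states: all neighbors of L = (1, 1, 1)
probs = {s: (p1 if s == (1, 1, 1) else 0.0 if s in zero_states else p0) for s in states}
assert np.isclose(sum(probs.values()), 1.0)

map_state = max(probs, key=probs.get)  # joint mode, optimal under 0-1 loss

# Max marginals, optimal under Hamming loss: pick the likelier value of each bit separately.
mpm_state = tuple(
    int(sum(p for s, p in probs.items() if s[d] == 1) > sum(p for s, p in probs.items() if s[d] == 0))
    for d in range(3)
)
print("MAP:", map_state, "MPM:", mpm_state)
```

Each bit is 1 with marginal probability $p_1 + p_0 = 0.43$, so the MPM estimate is (0,0,0), one of the gray states. The MAP estimate is the isolated state L.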


3.1.5.5 The MAP estimate is not invariant to reparameterization 


A more subtle problem with MAP estimation is that the result we get depends on how we parameterize 
the probability distribution, which is not very desirable. For example, when representing a Bernoulli 
distribution, we should be able to parameterize it in terms of probability of success, or in terms of 
the log-odds (logit), without that affecting our beliefs. 

More formally, let $\hat{x} = \arg\max_x p_x(x)$ be the MAP estimate for $x$. Now let $y = f(x)$ be a transformation of $x$. In general it is not the case that $\hat{y} = \arg\max_y p_y(y)$ is given by $f(\hat{x})$. For example, let $x \sim \mathcal{N}(6, 1)$ and $y = f(x)$, where $f(x) = \frac{1}{1 + \exp(-x + 5)}$. We can use the change of variables formula (Section 2.5.1) to conclude $p_y(y) = p_x(f^{-1}(y)) \left|\frac{d f^{-1}(y)}{dy}\right|$. Alternatively, we can use a Monte Carlo approximation. The result is shown in Figure 2.12. We see that the original Gaussian for $p(x)$ has become “squashed” by the sigmoid nonlinearity. In particular, we see that the mode of the transformed distribution is not equal to the transform of the original mode.
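The effect can be checked numerically with the change-of-variables density from the text, evaluated on a grid. This is a minimal NumPy sketch. It uses the inverse of the shifted sigmoid, $f^{-1}(y) = 5 + \log\frac{y}{1-y}$, whose derivative is $\frac{1}{y(1-y)}$. The grid resolution is an arbitrary choice.

```python
import numpy as np


def f(x):
    # Shifted sigmoid from the text: f(x) = 1 / (1 + exp(-x + 5)).
    return 1.0 / (1.0 + np.exp(-x + 5.0))


def f_inv(y):
    return 5.0 + np.log(y / (1.0 - y))


def p_x(x):
    return np.exp(-0.5 * (x - 6.0) ** 2) / np.sqrt(2.0 * np.pi)  # N(x | 6, 1)


# Change of variables: p_y(y) = p_x(f^{-1}(y)) |d f^{-1}(y) / dy|, with |d f^{-1}/dy| = 1 / (y (1 - y)).
y_grid = np.linspace(1e-4, 1.0 - 1e-4, 200_001)
p_y = p_x(f_inv(y_grid)) / (y_grid * (1.0 - y_grid))

mode_y = y_grid[np.argmax(p_y)]
f_of_mode_x = f(6.0)  # the mode of N(6, 1) is 6
print(f"mode of p_y: {mode_y:.3f}, f(mode of p_x): {f_of_mode_x:.3f}")
```

The transformed mode lands near 0.844, while $f(\hat{x}) = f(6) \approx 0.731$. The Jacobian term $1/(y(1-y))$ is what pulls the mode away.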

We have seen that the MAP estimate depends on the parameterization. The MLE does not suffer 
from this since the likelihood is a function, not a probability density. Bayesian inference does not 
suffer from this problem either, since the change of measure is taken into account when integrating 
over the parameter space. 


3.1.5.6 MAP estimation cannot handle the cold-start problem 


As the amount of data increases, the posterior $p(\boldsymbol{\theta}|\mathcal{D})$ sometimes shrinks to a point, since the likelihood term $p(\mathcal{D}|\boldsymbol{\theta})$ dominates the fixed prior $p(\boldsymbol{\theta})$. That is, $p(\boldsymbol{\theta}|\mathcal{D}) \to \delta_{\hat{\boldsymbol{\theta}}}(\boldsymbol{\theta})$, where $\hat{\boldsymbol{\theta}}$ is the MLE.
(We will illustrate this phenomenon in more detail later in this chapter.) Thus one may think that in
the era of big data, we do not need to model posterior uncertainty. However, this assumes that the 
data is informative about all of the unknown variables. In many problems there is a long tail of 
data, in which a small number of items occur frequently, but most items occur rarely. Consequently 
there may be a lot of uncertainty about these rare items (see e.g., [Jor11]). 

For example, in a recommender system, we will always be faced with new users, about whom we 
know very little (the so-called cold start problem). Bayesian methods can help create personalized 
recommendations in such cases, by borrowing statistical strength from other similar users for whom 
we have more data (see Section 3.7). Similarly, when performing online learning and sequential 
decision making, an agent will often encounter new data that may not have been seen before (indeed, 
it may actively seek such novelty), so it may often be in a small data regime. For example, see the 
discussion of Bayesian optimization in Section 6.8 and bandits in Section 34.4. 


3.2 Conjugate priors for simple models 


In this section, we consider the problem of inferring the posterior for parameters of simple probability 
models. These will form the foundations for many more complex models. We will use priors that are 
conjugate to the likelihood. This is defined as follows: a prior $p(\boldsymbol{\theta}) \in \mathcal{F}$ is a conjugate prior for a likelihood function $p(\mathcal{D}|\boldsymbol{\theta})$ if the posterior is in the same parameterized family as the prior, i.e., $p(\boldsymbol{\theta}|\mathcal{D}) \in \mathcal{F}$. In other words, $\mathcal{F}$ is closed under Bayesian updating. If the family $\mathcal{F}$ corresponds to
the exponential family (defined in Section 2.3), then the computations can be performed in closed 
form. In the sections below, we give some common examples of this framework, which we will use 
later in the book. 


3.2.1 The binomial model 


One of the simplest examples of conjugate Bayesian analysis is the beta-binomial model. This is 
covered in detail in Section 4.6.2 of the prequel to this book, [Mur22]. For completeness, we just 
state the results here without further discussion. 

For N trials, the binomial likelihood is 


$$p(\mathcal{D}|\theta) = \mathrm{Bin}(y|N, \theta) = \binom{N}{y}\theta^{y}(1 - \theta)^{N - y} \qquad (3.11)$$


If N = 1, this reduces to the Bernoulli likelihood. 
The conjugate prior is a beta distribution: 


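$$p(\theta) = \mathrm{Beta}(\theta|\breve{\alpha}, \breve{\beta}) \propto \theta^{\breve{\alpha} - 1}(1 - \theta)^{\breve{\beta} - 1} \qquad (3.12)$$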


The values $\breve{\alpha}$ and $\breve{\beta}$ are known as pseudo counts, since they play a role analogous to the empirical counts $N_1$ and $N_0$.
If we multiply the likelihood with the beta prior we get a beta posterior: 


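$$p(\theta|\mathcal{D}) \propto \theta^{N_1}(1 - \theta)^{N_0}\,\theta^{\breve{\alpha} - 1}(1 - \theta)^{\breve{\beta} - 1} \qquad (3.13)$$
$$\propto \mathrm{Beta}(\theta|\breve{\alpha} + N_1, \breve{\beta} + N_0) \qquad (3.14)$$
$$= \mathrm{Beta}(\theta|\widehat{\alpha}, \widehat{\beta}) \qquad (3.15)$$

where $\widehat{\alpha} \triangleq \breve{\alpha} + N_1$ and $\widehat{\beta} \triangleq \breve{\beta} + N_0$ are the parameters of the posterior.
The posterior mean is a convex combination of the prior mean, $m = \breve{\alpha}/\breve{N}$ (where $\breve{N} \triangleq \breve{\alpha} + \breve{\beta}$ is the prior strength), and the MLE, $\hat{\theta}_{\text{mle}} = N_1/N_{\mathcal{D}}$:

$$\mathbb{E}[\theta|\mathcal{D}] = \frac{\breve{\alpha} + N_1}{\breve{\alpha} + N_1 + \breve{\beta} + N_0} = \frac{\breve{N} m + N_1}{N_{\mathcal{D}} + \breve{N}} = \frac{\breve{N}}{N_{\mathcal{D}} + \breve{N}} m + \frac{N_{\mathcal{D}}}{N_{\mathcal{D}} + \breve{N}} \frac{N_1}{N_{\mathcal{D}}} = \lambda m + (1 - \lambda)\hat{\theta}_{\text{mle}} \qquad (3.16)$$

where $\lambda = \breve{N}/(\breve{N} + N_{\mathcal{D}})$ is the ratio of the prior to posterior equivalent sample size. So the weaker the prior, the smaller is $\lambda$, and hence the closer the posterior mean is to the MLE. In particular, if we set the prior pseudocounts to 0, then $\lambda = 0$, so the posterior mean becomes the MLE.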
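The convex-combination identity (3.16) is easy to confirm numerically. The sketch below evaluates the posterior mean once directly and once as $\lambda m + (1 - \lambda)\hat{\theta}_{\text{mle}}$. The counts ($N_1 = 3$, $N_0 = 7$) and the Beta(2, 2) prior are hypothetical. The function name `beta_binomial_posterior_mean` is also illustrative.

```python
def beta_binomial_posterior_mean(n1, n0, a_prior, b_prior):
    """Posterior mean of theta under a Beta(a_prior, b_prior) prior, computed two ways."""
    direct = (a_prior + n1) / (a_prior + n1 + b_prior + n0)
    prior_strength = a_prior + b_prior                 # N-breve
    n_data = n1 + n0                                   # N_D
    prior_mean = a_prior / prior_strength              # m
    mle = n1 / n_data
    lam = prior_strength / (prior_strength + n_data)   # lambda in Equation (3.16)
    convex = lam * prior_mean + (1.0 - lam) * mle
    return direct, convex, lam


direct, convex, lam = beta_binomial_posterior_mean(n1=3, n0=7, a_prior=2.0, b_prior=2.0)
print(direct, convex, lam)
```

Here $\lambda = 4/14$, so the prior mean (0.5) gets about 29% of the weight and the MLE (0.3) the rest, giving $5/14$.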
The posterior mode is given by

$$\hat{\theta}_{\text{map}} = \arg\max_{\theta} p(\theta|\mathcal{D}) = \frac{\breve{\alpha} + N_1 - 1}{\breve{\alpha} + N_1 - 1 + \breve{\beta} + N_0 - 1} \qquad (3.17)$$

If we set the prior pseudocounts to 1, the MAP estimate becomes the MLE, which can overfit. To avoid this, it is common to set the prior pseudocounts to 2, which encodes a weak prior that $\theta = 0.5$. The corresponding MAP estimate is then just like the MLE, except that we add 1 to each of the empirical counts before normalizing. This is called add-one smoothing.
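A small sketch of Equation (3.17) shows why add-one smoothing helps on a tiny dataset. The data (three successes, no failures) and the helper name `beta_binomial_map` are hypothetical.

```python
def beta_binomial_map(n1, n0, a_prior, b_prior):
    """Mode of the Beta(a_prior + n1, b_prior + n0) posterior, Equation (3.17).
    Assumes both posterior parameters are >= 1 and not both equal to 1."""
    return (a_prior + n1 - 1.0) / (a_prior + n1 - 1.0 + b_prior + n0 - 1.0)


n1, n0 = 3, 0                                       # three successes, no failures
mle = n1 / (n1 + n0)                                # says failure is impossible
map_uniform = beta_binomial_map(n1, n0, 1.0, 1.0)   # pseudocounts 1: identical to the MLE
add_one = beta_binomial_map(n1, n0, 2.0, 2.0)       # pseudocounts 2: (3 + 1) / (3 + 0 + 2) = 0.8
print(mle, map_uniform, add_one)
```

With pseudocounts of 1 the estimate is 1.0, the same as the MLE. With pseudocounts of 2 it becomes 0.8, so a future failure is no longer ruled out.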

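The variance of the beta posterior is given by

$$\mathbb{V}[\theta|\mathcal{D}] = \frac{\widehat{\alpha}\widehat{\beta}}{(\widehat{\alpha} + \widehat{\beta})^2(\widehat{\alpha} + \widehat{\beta} + 1)} = \mathbb{E}[\theta|\mathcal{D}]^2 \frac{\widehat{\beta}}{\widehat{\alpha}(1 + \widehat{\alpha} + \widehat{\beta})} \qquad (3.18)$$

where $\widehat{\alpha} = \breve{\alpha} + N_1$ and $\widehat{\beta} = \breve{\beta} + N_0$. If $N_{\mathcal{D}} \gg \breve{\alpha} + \breve{\beta}$, this simplifies to

$$\mathbb{V}[\theta|\mathcal{D}] \approx \frac{N_1 N_0}{N_{\mathcal{D}}^3} = \frac{\hat{\theta}(1 - \hat{\theta})}{N_{\mathcal{D}}} \qquad (3.19)$$

where $\hat{\theta}$ is the MLE. Hence the standard error is given by

$$\sigma = \sqrt{\mathbb{V}[\theta|\mathcal{D}]} \approx \sqrt{\frac{\hat{\theta}(1 - \hat{\theta})}{N_{\mathcal{D}}}} \qquad (3.20)$$

We see that the uncertainty goes down at a rate of $1/\sqrt{N_{\mathcal{D}}}$.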
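Finally, we can show that the marginal likelihood for the beta-binomial model is given by

$$p(\mathcal{D}) = \binom{N_{\mathcal{D}}}{N_1} \frac{B(\breve{\alpha} + N_1, \breve{\beta} + N_0)}{B(\breve{\alpha}, \breve{\beta})} = \binom{N_{\mathcal{D}}}{N_1} \frac{B(\widehat{\alpha}, \widehat{\beta})}{B(\breve{\alpha}, \breve{\beta})} \qquad (3.21)$$

where $B(\cdot, \cdot)$ is the beta function. In the Bernoulli case, where $N = 1$, the first term can be omitted.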
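A useful sanity check on Equation (3.21) is that, for a fixed number of trials, it defines a distribution over the number of successes and must therefore sum to 1. The sketch below works in log space with `math.lgamma` to avoid overflow. The prior values (2, 3) and the trial count 10 are hypothetical.

```python
import math


def log_beta_fn(a, b):
    # log B(a, b) = log Gamma(a) + log Gamma(b) - log Gamma(a + b)
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)


def beta_binomial_marginal(n1, n_trials, a_prior, b_prior):
    """p(D) from Equation (3.21), computed in log space for numerical stability."""
    n0 = n_trials - n1
    log_p = (math.log(math.comb(n_trials, n1))
             + log_beta_fn(a_prior + n1, b_prior + n0)
             - log_beta_fn(a_prior, b_prior))
    return math.exp(log_p)


# Summed over every possible number of successes, the marginal likelihood is 1.
total = sum(beta_binomial_marginal(k, 10, 2.0, 3.0) for k in range(11))
# Under a uniform Beta(1, 1) prior, every count 0..N is equally likely a priori.
uniform_case = beta_binomial_marginal(4, 10, 1.0, 1.0)
print(total, uniform_case)
```

Under the uniform prior each of the 11 possible counts gets probability $1/11$.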


3.2.2 The multinomial model 


In this section, we generalize the results from Section 3.2.1 from binary variables (e.g., coins) to 
K-ary variables (e.g., dice). 
Let $y \sim \mathrm{Cat}(\boldsymbol{\theta})$ be a discrete random variable drawn from a categorical distribution. The likelihood has the form

$$p(\mathcal{D}|\boldsymbol{\theta}) = \prod_{n=1}^{N} \mathrm{Cat}(y_n|\boldsymbol{\theta}) = \prod_{n=1}^{N}\prod_{c=1}^{C} \theta_c^{\mathbb{I}(y_n = c)} = \prod_{c=1}^{C} \theta_c^{N_c} \qquad (3.22)$$

where $N_c = \sum_n \mathbb{I}(y_n = c)$. We can generalize this to the multinomial distribution by defining $\mathbf{y} \sim \mathcal{M}(N, \boldsymbol{\theta})$, where $N$ is the number of trials, and $y_c = N_c$ is the number of times value $c$ is observed. The likelihood becomes

$$p(\mathbf{y}|N, \boldsymbol{\theta}) = \binom{N}{y_1 \ldots y_C} \prod_{c=1}^{C} \theta_c^{y_c} \qquad (3.23)$$

This is the same as the categorical likelihood modulo a scaling factor. Going forward, we will work with the categorical model, for notational simplicity.

The conjugate prior for a categorical distribution is the Dirichlet distribution, which we discussed 
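in Section 2.2.8.1. We denote this by $p(\boldsymbol{\theta}) = \mathrm{Dir}(\boldsymbol{\theta}|\breve{\boldsymbol{\alpha}})$, where $\breve{\boldsymbol{\alpha}}$ is the vector of prior pseudo-counts. Often we use a symmetric Dirichlet prior of the form $\breve{\alpha}_k = \breve{\alpha}/K$. In this case, we have $\mathbb{E}[\theta_k] = 1/K$, and $\mathbb{V}[\theta_k] = \frac{K - 1}{K^2(\breve{\alpha} + 1)}$. Thus we see that increasing the prior sample size $\breve{\alpha}$ decreases the variance of the prior, which is equivalent to using a stronger prior.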

We can combine the multinomial likelihood and Dirichlet prior to compute the Dirichlet posterior, 
as follows: 


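$$p(\boldsymbol{\theta}|\mathcal{D}) \propto p(\mathcal{D}|\boldsymbol{\theta})\,\mathrm{Dir}(\boldsymbol{\theta}|\breve{\boldsymbol{\alpha}}) \propto \left[\prod_k \theta_k^{N_k}\right]\left[\prod_k \theta_k^{\breve{\alpha}_k - 1}\right] \qquad (3.24)$$
$$\propto \mathrm{Dir}(\boldsymbol{\theta}|\breve{\alpha}_1 + N_1, \ldots, \breve{\alpha}_K + N_K) = \mathrm{Dir}(\boldsymbol{\theta}|\widehat{\boldsymbol{\alpha}}) \qquad (3.25)$$

where $\widehat{\alpha}_k = \breve{\alpha}_k + N_k$ are the parameters of the posterior. So we see that the posterior can be computed by adding the empirical counts to the prior counts. In particular, the posterior mode is given by

$$\hat{\theta}_k = \frac{\widehat{\alpha}_k - 1}{\sum_{k'=1}^{K}(\widehat{\alpha}_{k'} - 1)} = \frac{N_k + \breve{\alpha}_k - 1}{\sum_{k'=1}^{K}(N_{k'} + \breve{\alpha}_{k'} - 1)} \qquad (3.26)$$

If we set $\breve{\alpha}_k = 1$ we recover the MLE; if we set $\breve{\alpha}_k = 2$, we recover the add-one smoothing estimate.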
The marginal likelihood for the Dirichlet-categorical model is given by the following: 


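$$p(\mathcal{D}) = \frac{B(\mathbf{N} + \breve{\boldsymbol{\alpha}})}{B(\breve{\boldsymbol{\alpha}})} \qquad (3.27)$$

where

$$B(\boldsymbol{\alpha}) = \frac{\prod_{k=1}^{K}\Gamma(\alpha_k)}{\Gamma\left(\sum_k \alpha_k\right)} \qquad (3.28)$$

Hence we can rewrite the above result in the following form, which is what is usually presented in the literature:

$$p(\mathcal{D}) = \frac{\Gamma\left(\sum_k \breve{\alpha}_k\right)}{\Gamma\left(N + \sum_k \breve{\alpha}_k\right)} \prod_k \frac{\Gamma(N_k + \breve{\alpha}_k)}{\Gamma(\breve{\alpha}_k)} \qquad (3.29)$$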


For more details on this model, see [Mur22, Sec 4.6.3]. 
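The sketch below checks the Dirichlet-categorical formulas against each other. It evaluates Equation (3.29) directly. It also computes $p(\mathcal{D})$ by the chain rule, using the posterior predictive $p(y_n = k|y_{1:n-1}) = (\breve{\alpha}_k + N_k)/(\sum_{k'} \breve{\alpha}_{k'} + n - 1)$, which is the posterior mean of $\theta_k$. The two results must agree. The sequence, the uniform prior and the function names are hypothetical. The last lines apply Equations (3.25) and (3.26).

```python
import math


def log_dirichlet_categorical_marginal(counts, alpha):
    """log p(D) from Equation (3.29), given category counts N_k and prior pseudo-counts."""
    n, a0 = sum(counts), sum(alpha)
    log_p = math.lgamma(a0) - math.lgamma(n + a0)
    for n_k, a_k in zip(counts, alpha):
        log_p += math.lgamma(n_k + a_k) - math.lgamma(a_k)
    return log_p


def log_sequential_predictive(sequence, alpha):
    """log p(D) via the chain rule, multiplying one-step-ahead posterior predictives."""
    counts = [0] * len(alpha)
    log_p = 0.0
    for i, y in enumerate(sequence):
        log_p += math.log((alpha[y] + counts[y]) / (sum(alpha) + i))
        counts[y] += 1
    return log_p


sequence = [0, 2, 2, 1, 2, 0, 2]          # hypothetical outcomes of a 3-sided die
alpha = [1.0, 1.0, 1.0]                   # uniform Dirichlet prior
counts = [sequence.count(k) for k in range(len(alpha))]
post_alpha = [a + n for a, n in zip(alpha, counts)]                      # Equation (3.25)
mode = [(a - 1) / sum(b - 1 for b in post_alpha) for a in post_alpha]   # Equation (3.26)
print(log_dirichlet_categorical_marginal(counts, alpha), log_sequential_predictive(sequence, alpha))
print(mode)  # with alpha_k = 1 the mode equals the MLE, counts / N
```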
3.2.3 The univariate Gaussian model


In this section, we derive the posterior $p(\mu, \sigma^2|\mathcal{D})$ for a univariate Gaussian. For simplicity, we consider this in three steps: inferring just $\mu$, inferring just $\sigma^2$, and then inferring both. See Section 3.3 for the multivariate case.


3.2.3.1 Posterior of μ given σ²


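If $\sigma^2$ is a known constant, the likelihood for $\mu$ has the form

$$p(\mathcal{D}|\mu) \propto \exp\left(-\frac{1}{2\sigma^2}\sum_{n=1}^{N}(y_n - \mu)^2\right) \qquad (3.30)$$

One can show that the conjugate prior is another Gaussian, $\mathcal{N}(\mu|\breve{m}, \breve{\tau}^2)$. Applying Bayes’ rule for Gaussians (Equation (2.82)), we find that the corresponding posterior is given by

$$p(\mu|\mathcal{D}, \sigma^2) = \mathcal{N}(\mu|\widehat{m}, \widehat{\tau}^2) \qquad (3.31)$$
$$\widehat{\tau}^2 = \frac{1}{\frac{N}{\sigma^2} + \frac{1}{\breve{\tau}^2}} = \frac{\sigma^2 \breve{\tau}^2}{N\breve{\tau}^2 + \sigma^2} \qquad (3.32)$$
$$\widehat{m} = \widehat{\tau}^2\left(\frac{\breve{m}}{\breve{\tau}^2} + \frac{N\bar{y}}{\sigma^2}\right) = \frac{\sigma^2}{N\breve{\tau}^2 + \sigma^2}\breve{m} + \frac{N\breve{\tau}^2}{N\breve{\tau}^2 + \sigma^2}\bar{y} \qquad (3.33)$$

where $\bar{y} \triangleq \frac{1}{N}\sum_{n=1}^{N} y_n$ is the empirical mean.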

This result is easier to understand if we work in terms of the precision parameters, which are 
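just inverse variances. Specifically, let $\lambda = 1/\sigma^2$ be the observation precision, and $\breve{\lambda} = 1/\breve{\tau}^2$ be the precision of the prior. We can then rewrite the posterior as follows:

$$p(\mu|\mathcal{D}, \lambda) = \mathcal{N}(\mu|\widehat{m}, \widehat{\lambda}^{-1}) \qquad (3.34)$$
$$\widehat{\lambda} = \breve{\lambda} + N\lambda \qquad (3.35)$$
$$\widehat{m} = \frac{N\lambda\bar{y} + \breve{\lambda}\breve{m}}{\widehat{\lambda}} = \frac{N\lambda}{N\lambda + \breve{\lambda}}\bar{y} + \frac{\breve{\lambda}}{N\lambda + \breve{\lambda}}\breve{m} \qquad (3.36)$$

These equations are quite intuitive: the posterior precision $\widehat{\lambda}$ is the prior precision $\breve{\lambda}$ plus $N$ units of measurement precision $\lambda$. Also, the posterior mean $\widehat{m}$ is a convex combination of the empirical mean $\bar{y}$ and the prior mean $\breve{m}$. This makes it clear that the posterior mean is a compromise between the empirical mean and the prior. If the prior is weak relative to the signal strength ($\breve{\lambda}$ is small relative to $\lambda$), we put more weight on the empirical mean. If the prior is strong relative to the signal strength ($\breve{\lambda}$ is large relative to $\lambda$), we put more weight on the prior. This is illustrated in Figure 3.4. Note also that the posterior mean is written in terms of $N\lambda\bar{y}$, so having $N$ measurements each of precision $\lambda$ is like having one measurement with value $\bar{y}$ and precision $N\lambda$.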
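The precision form of the update takes only a few lines of NumPy. The sketch mirrors the strong-versus-weak prior comparison of Figure 3.4, with prior variances 1 and 5 and noise variance 1. The four observations and the name `gauss_mean_posterior` are hypothetical.

```python
import numpy as np


def gauss_mean_posterior(y, noise_var, prior_mean, prior_var):
    """Posterior N(mu | m_hat, 1 / lam_hat) for the mean of N(y | mu, noise_var),
    with noise_var known and prior N(mu | prior_mean, prior_var)."""
    y = np.asarray(y, dtype=float)
    n, ybar = y.size, y.mean()
    lam, lam_prior = 1.0 / noise_var, 1.0 / prior_var                 # precisions
    lam_post = lam_prior + n * lam                                     # Equation (3.35)
    m_post = (n * lam * ybar + lam_prior * prior_mean) / lam_post      # Equation (3.36)
    return m_post, 1.0 / lam_post


y = [2.9, 3.4, 2.1, 3.8]   # hypothetical observations, ybar = 3.05
strong = gauss_mean_posterior(y, noise_var=1.0, prior_mean=0.0, prior_var=1.0)
weak = gauss_mean_posterior(y, noise_var=1.0, prior_mean=0.0, prior_var=5.0)
print(strong, weak)
```

The weak prior leaves the posterior mean closer to $\bar{y}$ (about 2.90, against 2.44 for the strong prior). It also leaves the posterior variance larger.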

[Figure 3.4 plots: (a) Prior with variance of 1; (b) Prior with variance of 5; each panel shows the prior, likelihood, and posterior]

Figure 3.4: Inferring the mean of a univariate Gaussian with known $\sigma^2$. (a) Using strong prior, $p(\mu) = \mathcal{N}(\mu|0, 1)$. (b) Using weak prior, $p(\mu) = \mathcal{N}(\mu|0, 5)$. Generated by gauss_infer_1d.ipynb.

To gain further insight into these equations, consider the posterior after seeing a single data point $y$ (so $N = 1$). Then the posterior mean can be written in the following equivalent ways:


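$$\widehat{m} = \frac{\breve{\lambda}}{\widehat{\lambda}}\breve{m} + \frac{\lambda}{\widehat{\lambda}}y \qquad (3.37)$$
$$= \breve{m} + \frac{\lambda}{\widehat{\lambda}}(y - \breve{m}) \qquad (3.38)$$
$$= y - \frac{\breve{\lambda}}{\widehat{\lambda}}(y - \breve{m}) \qquad (3.39)$$

The first equation is a convex combination of the prior mean and the data. The second equation is the prior mean adjusted towards the data $y$. The third equation is the data adjusted towards the prior mean; this is called a shrinkage estimate. This is easier to see if we define the weight $w = \breve{\lambda}/\widehat{\lambda}$. Then we have

$$\widehat{m} = y - w(y - \breve{m}) = (1 - w)y + w\breve{m} \qquad (3.40)$$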


Note that, for a Gaussian, the posterior mean and posterior mode are the same. Thus we can use 
the above equations to perform MAP estimation. 


3.2.3.2 Posterior of σ² given μ


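If $\mu$ is a known constant, the likelihood for $\sigma^2$ has the form

$$p(\mathcal{D}|\sigma^2) \propto (\sigma^2)^{-N/2}\exp\left(-\frac{1}{2\sigma^2}\sum_{n=1}^{N}(y_n - \mu)^2\right) \qquad (3.41)$$

where we can no longer ignore the $1/(\sigma^2)$ term in front. The standard conjugate prior is the inverse Gamma distribution (Section 2.2.3.4), given by

$$\mathrm{IG}(\sigma^2|\breve{a}, \breve{b}) = \frac{\breve{b}^{\breve{a}}}{\Gamma(\breve{a})}(\sigma^2)^{-(\breve{a} + 1)}\exp\left(-\frac{\breve{b}}{\sigma^2}\right) \qquad (3.42)$$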
[Figure 3.5 plot: prior = $\mathrm{IG}(\nu = 0)$, true $\sigma^2 = 10$]

Figure 3.5: Sequential updating of the posterior for $\sigma^2$ starting from an uninformative prior. The data was generated from a Gaussian with known mean $\mu = 5$ and unknown variance $\sigma^2 = 10$. Generated by gauss_seq_update_sigma_1d.ipynb.


Multiplying the likelihood and the prior, we see that the posterior is also IG: 


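$$p(\sigma^2|\mu, \mathcal{D}) = \mathrm{IG}(\sigma^2|\widehat{a}, \widehat{b}) \qquad (3.43)$$
$$\widehat{a} = \breve{a} + N/2 \qquad (3.44)$$
$$\widehat{b} = \breve{b} + \frac{1}{2}\sum_{n=1}^{N}(y_n - \mu)^2 \qquad (3.45)$$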
n=1 


See Figure 3.5 for an illustration. 
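Conjugacy means the update can be applied in one batch or one observation at a time, as in Figure 3.5, with identical results. The sketch below checks this for Equations (3.44) and (3.45). It uses the figure's data-generating setting ($\mu = 5$, $\sigma^2 = 10$). The IG(1, 1) prior, the sample size and the name `ig_posterior` are assumptions; the figure itself starts from an uninformative prior. The posterior mean uses the standard IG mean $\widehat{b}/(\widehat{a} - 1)$.

```python
import numpy as np


def ig_posterior(y, mu, a_prior, b_prior):
    """IG(a_hat, b_hat) posterior for sigma^2 when the mean mu is known (Eqs. 3.44-3.45)."""
    y = np.asarray(y, dtype=float)
    return a_prior + y.size / 2.0, b_prior + 0.5 * np.sum((y - mu) ** 2)


rng = np.random.default_rng(0)
y = rng.normal(loc=5.0, scale=np.sqrt(10.0), size=500)   # mu = 5, sigma^2 = 10

# One batch update versus 500 single-observation updates.
a_batch, b_batch = ig_posterior(y, mu=5.0, a_prior=1.0, b_prior=1.0)
a_seq, b_seq = 1.0, 1.0
for y_n in y:
    a_seq, b_seq = ig_posterior([y_n], mu=5.0, a_prior=a_seq, b_prior=b_seq)

post_mean = b_batch / (a_batch - 1.0)   # mean of IG(a, b), valid for a > 1
print(a_batch, b_batch, post_mean)
```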

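One small annoyance with using the $\mathrm{IG}(\breve{a}, \breve{b})$ distribution is that the strength of the prior is encoded in both $\breve{a}$ and $\breve{b}$. Therefore, in the Bayesian statistics literature it is common to use an alternative parameterization of the IG distribution, known as the (scaled) inverse chi-squared distribution:

$$\chi^{-2}(\sigma^2|\breve{\nu}, \breve{\tau}^2) \triangleq \mathrm{IG}\left(\sigma^2 \,\middle|\, \frac{\breve{\nu}}{2}, \frac{\breve{\nu}\breve{\tau}^2}{2}\right) \propto (\sigma^2)^{-\breve{\nu}/2 - 1}\exp\left(-\frac{\breve{\nu}\breve{\tau}^2}{2\sigma^2}\right) \qquad (3.46)$$

Here $\breve{\nu}$ (called the degrees of freedom or dof parameter) controls the strength of the prior, and $\breve{\tau}^2$ encodes the prior mean. With this prior, the posterior becomes

$$p(\sigma^2|\mathcal{D}, \mu) = \chi^{-2}(\sigma^2|\widehat{\nu}, \widehat{\tau}^2) \qquad (3.47)$$
$$\widehat{\nu} = \breve{\nu} + N_{\mathcal{D}} \qquad (3.48)$$
$$\widehat{\tau}^2 = \frac{\breve{\nu}\breve{\tau}^2 + \sum_{n=1}^{N_{\mathcal{D}}}(y_n - \mu)^2}{\widehat{\nu}} \qquad (3.49)$$

We see that the posterior dof $\widehat{\nu}$ is the prior dof $\breve{\nu}$ plus $N_{\mathcal{D}}$, and the posterior sum of squares $\widehat{\nu}\widehat{\tau}^2$ is the prior sum of squares $\breve{\nu}\breve{\tau}^2$ plus the data sum of squares.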
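The two parameterizations give the same posterior. The sketch below applies the $\chi^{-2}$ update (3.48)-(3.49). It then maps the result back through $\chi^{-2}(\nu, \tau^2) = \mathrm{IG}(\nu/2, \nu\tau^2/2)$ and compares it with the IG update (3.44)-(3.45). The data and the prior $\breve{\nu} = 3$, $\breve{\tau}^2 = 2$ are hypothetical.

```python
import numpy as np


def scaled_inv_chi2_posterior(y, mu, nu_prior, tau2_prior):
    """Posterior chi^{-2}(nu_hat, tau2_hat) for sigma^2 with known mean mu (Eqs. 3.48-3.49)."""
    y = np.asarray(y, dtype=float)
    nu_post = nu_prior + y.size
    tau2_post = (nu_prior * tau2_prior + np.sum((y - mu) ** 2)) / nu_post
    return nu_post, tau2_post


y = np.array([4.1, 7.9, 2.2, 6.5, 5.3])   # hypothetical data with known mean mu = 5
nu_post, tau2_post = scaled_inv_chi2_posterior(y, mu=5.0, nu_prior=3.0, tau2_prior=2.0)

# The same prior in IG form is IG(3 / 2, 3 * 2 / 2); update it with Eqs. (3.44)-(3.45).
a_post = 3.0 / 2 + y.size / 2
b_post = 3.0 * 2.0 / 2 + 0.5 * np.sum((y - 5.0) ** 2)
print(nu_post / 2, a_post, nu_post * tau2_post / 2, b_post)
```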
3.2.3.3 Posterior of μ and σ²: conjugate prior


Now suppose we want to infer both the mean and variance. The corresponding conjugate prior is the 
normal inverse Gamma: 


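$$\mathrm{NIG}(\mu, \sigma^2|\breve{m}, \breve{\kappa}, \breve{a}, \breve{b}) \triangleq \mathcal{N}(\mu|\breve{m}, \sigma^2/\breve{\kappa})\,\mathrm{IG}(\sigma^2|\breve{a}, \breve{b}) \qquad (3.50)$$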
[Figure 3.6 plots: (a) $\mathrm{NI}\chi^2(\mu_0 = 0, \kappa_0 = 1, \nu_0 = 1, \sigma_0^2 = 1)$; (b) $\mathrm{NI}\chi^2(\mu_0 = 0, \kappa_0 = 5, \nu_0 = 1, \sigma_0^2 = 1)$]

Figure 3.6: The $\mathrm{NI}\chi^2(\mu, \sigma^2|m, \kappa, \nu, \sigma_0^2)$ distribution. $m$ is the prior mean and $\kappa$ is how strongly we believe this; $\sigma_0^2$ is the prior variance and $\nu$ is how strongly we believe this. (a) $m = 0, \kappa = 1, \nu = 1, \sigma_0^2 = 1$. Notice that the contour plot (underneath the surface) is shaped like a “squashed egg”. (b) We increase the strength of our belief in the mean by setting $\kappa = 5$, so the distribution for $\mu$ around $m = 0$ becomes narrower. Generated by nix_plots.ipynb.


However, it is common to use a reparameterization of this known as the normal inverse chi-squared 
or NIX distribution [Gel+14a, p67], which is defined by 


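$$\mathrm{NI}\chi^2(\mu, \sigma^2|\breve{m}, \breve{\kappa}, \breve{\nu}, \breve{\tau}^2) \triangleq \mathcal{N}(\mu|\breve{m}, \sigma^2/\breve{\kappa})\,\chi^{-2}(\sigma^2|\breve{\nu}, \breve{\tau}^2) \qquad (3.51)$$
$$\propto \left(\frac{1}{\sigma^2}\right)^{(\breve{\nu} + 3)/2}\exp\left(-\frac{\breve{\nu}\breve{\tau}^2 + \breve{\kappa}(\mu - \breve{m})^2}{2\sigma^2}\right) \qquad (3.52)$$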


See Figure 3.6 for some plots. Along the $\mu$ axis, the distribution is shaped like a Gaussian, and along the $\sigma^2$ axis, the distribution is shaped like a $\chi^{-2}$; the contours of the joint density have a “squashed egg” appearance. Interestingly, we see that the contours for $\mu$ are more peaked for small values of $\sigma^2$, which makes sense, since if the data is low variance, we will be able to estimate its mean more reliably.

One can show (based on Section 3.3.3.3) that the posterior is given by 


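$$p(\mu, \sigma^2|\mathcal{D}) = \mathrm{NI}\chi^2(\mu, \sigma^2|\widehat{m}, \widehat{\kappa}, \widehat{\nu}, \widehat{\tau}^2) \qquad (3.53)$$
$$\widehat{m} = \frac{\breve{\kappa}\breve{m} + N\bar{y}}{\widehat{\kappa}} \qquad (3.54)$$
$$\widehat{\kappa} = \breve{\kappa} + N \qquad (3.55)$$
$$\widehat{\nu} = \breve{\nu} + N \qquad (3.56)$$
$$\widehat{\nu}\widehat{\tau}^2 = \breve{\nu}\breve{\tau}^2 + \sum_{n=1}^{N}(y_n - \bar{y})^2 + \frac{N\breve{\kappa}}{\breve{\kappa} + N}(\breve{m} - \bar{y})^2 \qquad (3.57)$$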


3.2. CONJUGATE PRIORS FOR SIMPLE MODELS

The interpretation of this is as follows. For $\mu$, the posterior mean $\widehat{m}$ is a convex combination of the prior mean $\breve{m}$ and the MLE $\bar{y}$; the strength of this posterior, $\widehat{\kappa}$, is the prior strength $\breve{\kappa}$ plus the number of data points $N$. For $\sigma^2$, we work instead with the sum of squares: the posterior sum of squares, $\widehat{\nu}\widehat{\tau}^2$, is the prior sum of squares $\breve{\nu}\breve{\tau}^2$ plus the data sum of squares, $\sum_{n=1}^{N}(y_n - \bar{y})^2$, plus a term due to the discrepancy between the prior mean $\breve{m}$ and the MLE $\bar{y}$. The strength of this posterior, $\widehat{\nu}$, is the prior strength $\breve{\nu}$ plus the number of data points $N$.
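Here is a minimal NumPy sketch of the NIχ² update (3.54)-(3.57). The correction term in (3.57) is exactly what makes one-point-at-a-time updating agree with a single batch update, and the sketch checks that. The prior $(\breve{m}, \breve{\kappa}, \breve{\nu}, \breve{\tau}^2) = (0, 1, 1, 1)$ matches Figure 3.6(a). The data and the name `nix_update` are hypothetical.

```python
import numpy as np


def nix_update(y, m, kappa, nu, tau2):
    """Posterior NIX parameters (Eqs. 3.54-3.57) given data y and prior (m, kappa, nu, tau2)."""
    y = np.asarray(y, dtype=float)
    n, ybar = y.size, y.mean()
    kappa_post, nu_post = kappa + n, nu + n
    m_post = (kappa * m + n * ybar) / kappa_post
    ss_post = (nu * tau2 + np.sum((y - ybar) ** 2)
               + n * kappa / (kappa + n) * (m - ybar) ** 2)   # nu_hat * tau2_hat
    return m_post, kappa_post, nu_post, ss_post / nu_post


y = np.array([1.3, -0.4, 2.2, 0.9, 1.7, 0.1])   # hypothetical data
prior = (0.0, 1.0, 1.0, 1.0)                    # (m, kappa, nu, tau2), as in Figure 3.6(a)

batch = nix_update(y, *prior)
seq = prior
for y_n in y:                                   # the same update, one point at a time
    seq = nix_update([y_n], *seq)
print(batch)
print(seq)
```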
The posterior marginal for g? is just 


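$$p(\sigma^2|\mathcal{D}) = \int p(\mu, \sigma^2|\mathcal{D})\, d\mu = \chi^{-2}(\sigma^2|\widehat{\nu}, \widehat{\tau}^2) \qquad (3.58)$$

with the posterior mean given by $\mathbb{E}[\sigma^2|\mathcal{D}] = \frac{\widehat{\nu}}{\widehat{\nu} - 2}\widehat{\tau}^2$.

The posterior marginal for $\mu$ has a Student distribution, which follows from the fact that the Student distribution is a (scaled) mixture of Gaussians:

$$p(\mu|\mathcal{D}) = \int p(\mu, \sigma^2|\mathcal{D})\, d\sigma^2 = \mathcal{T}(\mu|\widehat{m}, \widehat{\tau}^2/\widehat{\kappa}, \widehat{\nu}) \qquad (3.59)$$

with the posterior mean given by $\mathbb{E}[\mu|\mathcal{D}] = \widehat{m}$.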


3.2.3.4 Posterior of μ and σ²: uninformative prior


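If we “know nothing” about the parameters a priori, we can use an uninformative prior. We discuss how to create such priors in Section 3.6. A common approach is to use a Jeffreys prior. In Section 3.6.2.3, we show that the Jeffreys prior for a location and scale parameter has the form

$$p(\mu, \sigma^2) \propto p(\mu)\,p(\sigma^2) \propto \sigma^{-2} \qquad (3.60)$$

We can simulate this with a conjugate prior by using

$$p(\mu, \sigma^2) = \mathrm{NI}\chi^2(\mu, \sigma^2|\breve{m} = 0, \breve{\kappa} = 0, \breve{\nu} = -1, \breve{\tau}^2 = 0) \qquad (3.61)$$

With this prior, the posterior has the form

$$p(\mu, \sigma^2|\mathcal{D}) = \mathrm{NI}\chi^2(\mu, \sigma^2|\widehat{m} = \bar{y}, \widehat{\kappa} = N, \widehat{\nu} = N - 1, \widehat{\tau}^2 = s^2) \qquad (3.62)$$

where

$$s^2 \triangleq \frac{1}{N - 1}\sum_{n=1}^{N}(y_n - \bar{y})^2 = \frac{N}{N - 1}\hat{\sigma}^2_{\text{mle}} \qquad (3.63)$$

$s$ is known as the sample standard deviation. Hence the marginal posterior for the mean is given by

$$p(\mu|\mathcal{D}) = \mathcal{T}\left(\mu \,\middle|\, \bar{y}, \frac{s^2}{N}, N - 1\right) = \mathcal{T}\left(\mu \,\middle|\, \bar{y}, \frac{\sum_{n=1}^{N}(y_n - \bar{y})^2}{N(N - 1)}, N - 1\right) \qquad (3.64)$$

Thus the posterior variance of $\mu$ is

$$\mathbb{V}[\mu|\mathcal{D}] = \frac{\widehat{\nu}}{\widehat{\nu} - 2}\frac{\widehat{\tau}^2}{\widehat{\kappa}} = \frac{N - 1}{N - 3}\frac{s^2}{N} \to \frac{s^2}{N} \qquad (3.65)$$

The square root of this is called the standard error of the mean:

$$\mathrm{se}(\mu) \triangleq \sqrt{\mathbb{V}[\mu|\mathcal{D}]} \approx \frac{s}{\sqrt{N}} \qquad (3.66)$$

Thus we can approximate the 95% credible interval for $\mu$ using

$$I_{.95}(\mu|\mathcal{D}) = \bar{y} \pm 2\frac{s}{\sqrt{N}} \qquad (3.67)$$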
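As a check on Equations (3.61)-(3.67), the following NumPy sketch plugs the improper prior $\breve{m} = 0$, $\breve{\kappa} = 0$, $\breve{\nu} = -1$, $\breve{\tau}^2 = 0$ into the general NIχ² update. It confirms that the posterior parameters reduce to $(\bar{y}, N, N-1, s^2)$. It then forms the posterior variance, the standard error and the approximate 95% interval. The synthetic data and the helper name `nix_update` are assumptions.

```python
import numpy as np


def nix_update(y, m, kappa, nu, tau2):
    """Posterior NIX parameters (Eqs. 3.54-3.57)."""
    y = np.asarray(y, dtype=float)
    n, ybar = y.size, y.mean()
    kappa_post, nu_post = kappa + n, nu + n
    m_post = (kappa * m + n * ybar) / kappa_post
    ss_post = nu * tau2 + np.sum((y - ybar) ** 2) + n * kappa / (kappa + n) * (m - ybar) ** 2
    return m_post, kappa_post, nu_post, ss_post / nu_post


rng = np.random.default_rng(1)
y = rng.normal(loc=3.0, scale=2.0, size=25)   # hypothetical data

# The improper prior of Equation (3.61).
m_hat, kappa_hat, nu_hat, tau2_hat = nix_update(y, 0.0, 0.0, -1.0, 0.0)
s = np.std(y, ddof=1)                                      # sample standard deviation
post_var = nu_hat / (nu_hat - 2) * tau2_hat / kappa_hat    # Equation (3.65)
se = s / np.sqrt(y.size)                                   # Equation (3.66)
interval = (m_hat - 2 * se, m_hat + 2 * se)                # Equation (3.67)
print(m_hat, tau2_hat, s**2, post_var, se**2, interval)
```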


3.3 Conjugate priors for the multivariate Gaussian 


In this section, we derive the posterior $p(\boldsymbol{\mu}, \boldsymbol{\Sigma}|\mathcal{D})$ for a multivariate Gaussian. For simplicity, we consider this in three steps: inferring just $\boldsymbol{\mu}$, inferring just $\boldsymbol{\Sigma}$, and then inferring both.


3.3.1 Posterior of μ given Σ
The likelihood has the form 
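$$p(\mathcal{D}|\boldsymbol{\mu}) \propto \mathcal{N}\left(\bar{\mathbf{y}} \,\middle|\, \boldsymbol{\mu}, \frac{1}{N_{\mathcal{D}}}\boldsymbol{\Sigma}\right) \qquad (3.68)$$

For simplicity, we will use a conjugate prior, which in this case is a Gaussian. In particular, if $p(\boldsymbol{\mu}) = \mathcal{N}(\boldsymbol{\mu}|\breve{\mathbf{m}}, \breve{\mathbf{V}})$ then we can derive a Gaussian posterior for $\boldsymbol{\mu}$ based on the results in Section 2.2.6.2. We get

$$p(\boldsymbol{\mu}|\mathcal{D}, \boldsymbol{\Sigma}) = \mathcal{N}(\boldsymbol{\mu}|\widehat{\mathbf{m}}, \widehat{\mathbf{V}}) \qquad (3.69)$$
$$\widehat{\mathbf{V}}^{-1} = \breve{\mathbf{V}}^{-1} + N_{\mathcal{D}}\boldsymbol{\Sigma}^{-1} \qquad (3.70)$$
$$\widehat{\mathbf{m}} = \widehat{\mathbf{V}}\left(\boldsymbol{\Sigma}^{-1}(N_{\mathcal{D}}\bar{\mathbf{y}}) + \breve{\mathbf{V}}^{-1}\breve{\mathbf{m}}\right) \qquad (3.71)$$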


Figure 3.7 gives a 2d example of these results. 
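The following NumPy sketch implements Equations (3.70)-(3.71) using a setup loosely modeled on Figure 3.7: true mean $[0.5, 0.5]^\top$, noise covariance $0.1[2, 1; 1, 1]$, prior $\mathcal{N}(\mathbf{0}, 0.1\mathbf{I})$ and 10 points. It is not the notebook gauss_infer_2d.ipynb. The random draws and the name `gauss_mean_posterior_mv` are assumptions. As a consistency check, updating point by point gives the same posterior as the batch formula.

```python
import numpy as np


def gauss_mean_posterior_mv(Y, Sigma, m_prior, V_prior):
    """N(mu | m_hat, V_hat) for the mean of N(y | mu, Sigma), Sigma known (Eqs. 3.70-3.71)."""
    Y = np.atleast_2d(Y)
    n, ybar = Y.shape[0], Y.mean(axis=0)
    Sigma_inv, V_prior_inv = np.linalg.inv(Sigma), np.linalg.inv(V_prior)
    V_post = np.linalg.inv(V_prior_inv + n * Sigma_inv)
    m_post = V_post @ (Sigma_inv @ (n * ybar) + V_prior_inv @ m_prior)
    return m_post, V_post


z_true = np.array([0.5, 0.5])
Sigma_y = 0.1 * np.array([[2.0, 1.0], [1.0, 1.0]])
m0, V0 = np.zeros(2), 0.1 * np.eye(2)
rng = np.random.default_rng(0)
Y = rng.multivariate_normal(z_true, Sigma_y, size=10)

m_batch, V_batch = gauss_mean_posterior_mv(Y, Sigma_y, m0, V0)

# Conjugacy: feeding the points in one at a time gives the same posterior.
m_seq, V_seq = m0, V0
for y_n in Y:
    m_seq, V_seq = gauss_mean_posterior_mv(y_n, Sigma_y, m_seq, V_seq)
print("posterior mean:", m_batch)
print("posterior covariance:\n", V_batch)
```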


3.3.2 Posterior of Σ given μ
We now discuss how to compute $p(\boldsymbol{\Sigma}|\mathcal{D}, \boldsymbol{\mu})$.


3.3.2.1 Likelihood 


We can rewrite the likelihood as follows: 


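$$p(\mathcal{D}|\boldsymbol{\mu}, \boldsymbol{\Sigma}) \propto |\boldsymbol{\Sigma}|^{-N_{\mathcal{D}}/2}\exp\left(-\frac{1}{2}\mathrm{tr}(\mathbf{S}_{\boldsymbol{\mu}}\boldsymbol{\Sigma}^{-1})\right) \qquad (3.72)$$

where

$$\mathbf{S}_{\boldsymbol{\mu}} \triangleq \sum_{n=1}^{N_{\mathcal{D}}}(\mathbf{y}_n - \boldsymbol{\mu})(\mathbf{y}_n - \boldsymbol{\mu})^\top \qquad (3.73)$$

is the scatter matrix around $\boldsymbol{\mu}$.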
3.3. CONJUGATE PRIORS FOR THE MULTIVARIATE GAUSSIAN

[Figure 3.7 plots: (a) Data; (b) Prior; (c) Posterior after 10 points]

Figure 3.7: Illustration of Bayesian inference for a 2d Gaussian random vector $\mathbf{z}$. (a) The data is generated from $\mathbf{y}_n \sim \mathcal{N}(\mathbf{z}, \boldsymbol{\Sigma}_y)$, where $\mathbf{z} = [0.5, 0.5]^\top$ and $\boldsymbol{\Sigma}_y = 0.1([2, 1; 1, 1])$. We assume the sensor noise covariance $\boldsymbol{\Sigma}_y$ is known but $\mathbf{z}$ is unknown. The black cross represents $\mathbf{z}$. (b) The prior is $p(\mathbf{z}) = \mathcal{N}(\mathbf{z}|\mathbf{0}, 0.1\mathbf{I}_2)$. (c) We show the posterior after 10 data points have been observed. Generated by gauss_infer_2d.ipynb.


3.3.2.2 Prior 


The conjugate prior is known as the inverse Wishart distribution, which is a distribution over 
positive definite matrices, as we explained in Section 2.2.8.5. This has the following pdf: 


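$$\mathrm{IW}(\boldsymbol{\Sigma}|\breve{\nu}, \breve{\boldsymbol{\Psi}}) \propto |\boldsymbol{\Sigma}|^{-(\breve{\nu} + D + 1)/2}\exp\left(-\frac{1}{2}\mathrm{tr}(\breve{\boldsymbol{\Psi}}\boldsymbol{\Sigma}^{-1})\right) \qquad (3.74)$$

Here $\breve{\nu} > D - 1$ is the degrees of freedom (dof), and $\breve{\boldsymbol{\Psi}}$ is a symmetric pd matrix. We see that $\breve{\boldsymbol{\Psi}}$ plays the role of the prior scatter matrix, and $\breve{N} \triangleq \breve{\nu} + D + 1$ controls the strength of the prior, and hence plays a role analogous to the sample size $N_{\mathcal{D}}$.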


3.3.2.3 Posterior 


Multiplying the likelihood and prior we find that the posterior is also inverse Wishart: 


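$$p(\boldsymbol{\Sigma}|\mathcal{D}, \boldsymbol{\mu}) \propto |\boldsymbol{\Sigma}|^{-N_{\mathcal{D}}/2}\exp\left(-\frac{1}{2}\mathrm{tr}(\boldsymbol{\Sigma}^{-1}\mathbf{S}_{\boldsymbol{\mu}})\right)|\boldsymbol{\Sigma}|^{-(\breve{\nu} + D + 1)/2}\exp\left(-\frac{1}{2}\mathrm{tr}(\boldsymbol{\Sigma}^{-1}\breve{\boldsymbol{\Psi}})\right) \qquad (3.75)$$
$$= |\boldsymbol{\Sigma}|^{-\frac{N_{\mathcal{D}} + \breve{\nu} + D + 1}{2}}\exp\left(-\frac{1}{2}\mathrm{tr}\left[\boldsymbol{\Sigma}^{-1}(\mathbf{S}_{\boldsymbol{\mu}} + \breve{\boldsymbol{\Psi}})\right]\right) \qquad (3.76)$$
$$\propto \mathrm{IW}(\boldsymbol{\Sigma}|\widehat{\nu}, \widehat{\boldsymbol{\Psi}}) \qquad (3.77)$$
$$\widehat{\nu} = \breve{\nu} + N_{\mathcal{D}} \qquad (3.78)$$
$$\widehat{\boldsymbol{\Psi}} = \breve{\boldsymbol{\Psi}} + \mathbf{S}_{\boldsymbol{\mu}} \qquad (3.79)$$

In words, this says that the posterior strength $\widehat{\nu}$ is the prior strength $\breve{\nu}$ plus the number of observations $N_{\mathcal{D}}$, and the posterior scatter matrix $\widehat{\boldsymbol{\Psi}}$ is the prior scatter matrix $\breve{\boldsymbol{\Psi}}$ plus the data scatter matrix $\mathbf{S}_{\boldsymbol{\mu}}$.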
3.3.3 Posterior of Σ and μ


In this section, we compute $p(\boldsymbol{\mu}, \boldsymbol{\Sigma}|\mathcal{D})$ using a conjugate prior.


3.3.3.1 Likelihood 
The likelihood is given by 


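$$p(\mathcal{D}|\boldsymbol{\mu}, \boldsymbol{\Sigma}) \propto |\boldsymbol{\Sigma}|^{-N_{\mathcal{D}}/2}\exp\left(-\frac{1}{2}\sum_{n=1}^{N_{\mathcal{D}}}(\mathbf{y}_n - \boldsymbol{\mu})^\top\boldsymbol{\Sigma}^{-1}(\mathbf{y}_n - \boldsymbol{\mu})\right) \qquad (3.80)$$

One can show that

$$\sum_{n=1}^{N_{\mathcal{D}}}(\mathbf{y}_n - \boldsymbol{\mu})^\top\boldsymbol{\Sigma}^{-1}(\mathbf{y}_n - \boldsymbol{\mu}) = \mathrm{tr}(\boldsymbol{\Sigma}^{-1}\mathbf{S}) + N_{\mathcal{D}}(\bar{\mathbf{y}} - \boldsymbol{\mu})^\top\boldsymbol{\Sigma}^{-1}(\bar{\mathbf{y}} - \boldsymbol{\mu}) \qquad (3.81)$$

where

$$\mathbf{S} \triangleq \mathbf{S}_{\bar{\mathbf{y}}} = \sum_{n=1}^{N_{\mathcal{D}}}(\mathbf{y}_n - \bar{\mathbf{y}})(\mathbf{y}_n - \bar{\mathbf{y}})^\top = \mathbf{Y}^\top\mathbf{C}_N\mathbf{Y} \qquad (3.82)$$

is the empirical scatter matrix, and $\mathbf{C}_N$ is the centering matrix

$$\mathbf{C}_N \triangleq \mathbf{I}_N - \frac{1}{N}\mathbf{1}_N\mathbf{1}_N^\top \qquad (3.83)$$

Hence we can rewrite the likelihood as follows:

$$p(\mathcal{D}|\boldsymbol{\mu}, \boldsymbol{\Sigma}) \propto |\boldsymbol{\Sigma}|^{-N_{\mathcal{D}}/2}\exp\left(-\frac{N_{\mathcal{D}}}{2}(\boldsymbol{\mu} - \bar{\mathbf{y}})^\top\boldsymbol{\Sigma}^{-1}(\boldsymbol{\mu} - \bar{\mathbf{y}})\right)\exp\left(-\frac{1}{2}\mathrm{tr}(\boldsymbol{\Sigma}^{-1}\mathbf{S})\right) \qquad (3.84)$$

We will use this form below.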
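The decomposition (3.81) and the centering-matrix form of the scatter matrix (3.82)-(3.83) can both be verified numerically on random inputs. In this NumPy sketch the data, the mean vector and the covariance are arbitrary choices made only for the check.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D = 8, 3
Y = rng.normal(size=(N, D))          # hypothetical data, one row per y_n
mu = rng.normal(size=D)              # an arbitrary mean vector
A = rng.normal(size=(D, D))
Sigma_inv = np.linalg.inv(A @ A.T + D * np.eye(D))   # inverse of an arbitrary pd covariance

ybar = Y.mean(axis=0)
C_N = np.eye(N) - np.ones((N, N)) / N                # centering matrix, Equation (3.83)
S = Y.T @ C_N @ Y                                    # scatter matrix, Equation (3.82)
S_direct = (Y - ybar).T @ (Y - ybar)                 # sum of (y_n - ybar)(y_n - ybar)^T

lhs = sum((y - mu) @ Sigma_inv @ (y - mu) for y in Y)                  # left side of (3.81)
rhs = np.trace(Sigma_inv @ S) + N * (ybar - mu) @ Sigma_inv @ (ybar - mu)
print(np.allclose(S, S_direct), np.isclose(lhs, rhs))
```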


3.3.3.2 Prior 


The obvious prior to use is the following 
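$$p(\boldsymbol{\mu}, \boldsymbol{\Sigma}) = \mathcal{N}(\boldsymbol{\mu}|\breve{\mathbf{m}}, \breve{\mathbf{V}})\,\mathrm{IW}(\boldsymbol{\Sigma}|\breve{\nu}, \breve{\boldsymbol{\Psi}}) \qquad (3.85)$$

where IW is the inverse Wishart distribution. Unfortunately, $\boldsymbol{\mu}$ and $\boldsymbol{\Sigma}$ appear together in a non-factorized way in the likelihood in Equation (3.84) (see the first exponent term), so the factored prior in Equation (3.85) is not conjugate to the likelihood.²

The above prior is sometimes called conditionally conjugate, since both conditionals, $p(\boldsymbol{\mu}|\boldsymbol{\Sigma})$ and $p(\boldsymbol{\Sigma}|\boldsymbol{\mu})$, are individually conjugate. To create a fully conjugate prior, we need to use a prior where $\boldsymbol{\mu}$ and $\boldsymbol{\Sigma}$ are dependent on each other. We will use a joint distribution of the form $p(\boldsymbol{\mu}, \boldsymbol{\Sigma}) = p(\boldsymbol{\mu}|\boldsymbol{\Sigma})\,p(\boldsymbol{\Sigma})$.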


2. Using the language of directed graphical models, we see that $\boldsymbol{\mu}$ and $\boldsymbol{\Sigma}$ become dependent when conditioned on $\mathcal{D}$ due to explaining away. See Figure 3.8(a).
[Figure 3.8 plots: graphical models (a) and (b)]

Figure 3.8: Graphical models representing different kinds of assumptions about the parameter priors. (a) A semi-conjugate prior for a Gaussian. (b) A conjugate prior for a Gaussian.


Looking at the form of the likelihood equation, Equation (3.84), we see that a natural conjugate 
prior has the form of a Normal-inverse-Wishart or NIW distribution, defined as follows: 


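$$\mathrm{NIW}(\boldsymbol{\mu}, \boldsymbol{\Sigma}|\breve{\mathbf{m}}, \breve{\kappa}, \breve{\nu}, \breve{\boldsymbol{\Psi}}) \triangleq \mathcal{N}\left(\boldsymbol{\mu} \,\middle|\, \breve{\mathbf{m}}, \frac{1}{\breve{\kappa}}\boldsymbol{\Sigma}\right) \times \mathrm{IW}(\boldsymbol{\Sigma}|\breve{\nu}, \breve{\boldsymbol{\Psi}}) \qquad (3.86)$$
$$= \frac{1}{Z_{\mathrm{NIW}}}|\boldsymbol{\Sigma}|^{-\frac{1}{2}}\exp\left(-\frac{\breve{\kappa}}{2}(\boldsymbol{\mu} - \breve{\mathbf{m}})^\top\boldsymbol{\Sigma}^{-1}(\boldsymbol{\mu} - \breve{\mathbf{m}})\right) \times |\boldsymbol{\Sigma}|^{-\frac{\breve{\nu} + D + 1}{2}}\exp\left(-\frac{1}{2}\mathrm{tr}(\boldsymbol{\Sigma}^{-1}\breve{\boldsymbol{\Psi}})\right) \qquad (3.87)$$

where the normalization constant is given by

$$Z_{\mathrm{NIW}} \triangleq 2^{\breve{\nu}D/2}\,\Gamma_D(\breve{\nu}/2)\,(2\pi/\breve{\kappa})^{D/2}\,|\breve{\boldsymbol{\Psi}}|^{-\breve{\nu}/2} \qquad (3.88)$$

The parameters of the NIW can be interpreted as follows: $\breve{\mathbf{m}}$ is our prior mean for $\boldsymbol{\mu}$, and $\breve{\kappa}$ is how strongly we believe this prior; $\breve{\boldsymbol{\Psi}}$ is (proportional to) our prior mean for $\boldsymbol{\Sigma}$, and $\breve{\nu}$ is how strongly we believe this prior.³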


3.3.3.3 Posterior 


To derive the posterior, let us first rewrite the scatter matrix as follows: 
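$$\mathbf{S} = \mathbf{Y}^\top\mathbf{Y} - \frac{1}{N}\left(\sum_{n=1}^{N}\mathbf{y}_n\right)\left(\sum_{n=1}^{N}\mathbf{y}_n\right)^\top = \mathbf{Y}^\top\mathbf{Y} - N\bar{\mathbf{y}}\bar{\mathbf{y}}^\top \qquad (3.89)$$

where $\mathbf{Y}^\top\mathbf{Y} = \sum_n \mathbf{y}_n\mathbf{y}_n^\top$ is the sum of squares matrix.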


3. Note that our uncertainty in the mean is proportional to the covariance. In particular, if we believe that the variance is large, then our uncertainty in $\boldsymbol{\mu}$ must be large too. This makes sense intuitively, since if the data has large spread, it will be hard to pin down its mean.
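Now we can multiply the likelihood and the prior to give

$$p(\boldsymbol{\mu}, \boldsymbol{\Sigma}|\mathcal{D}) \propto |\boldsymbol{\Sigma}|^{-N_{\mathcal{D}}/2}\exp\left(-\frac{N_{\mathcal{D}}}{2}(\boldsymbol{\mu} - \bar{\mathbf{y}})^\top\boldsymbol{\Sigma}^{-1}(\boldsymbol{\mu} - \bar{\mathbf{y}})\right)\exp\left(-\frac{1}{2}\mathrm{tr}(\boldsymbol{\Sigma}^{-1}\mathbf{S})\right) \qquad (3.90)$$
$$\times |\boldsymbol{\Sigma}|^{-\frac{\breve{\nu} + D + 2}{2}}\exp\left(-\frac{\breve{\kappa}}{2}(\boldsymbol{\mu} - \breve{\mathbf{m}})^\top\boldsymbol{\Sigma}^{-1}(\boldsymbol{\mu} - \breve{\mathbf{m}})\right)\exp\left(-\frac{1}{2}\mathrm{tr}(\boldsymbol{\Sigma}^{-1}\breve{\boldsymbol{\Psi}})\right) \qquad (3.91)$$
$$= |\boldsymbol{\Sigma}|^{-(N_{\mathcal{D}} + \breve{\nu} + D + 2)/2}\exp\left(-\frac{1}{2}\mathrm{tr}(\boldsymbol{\Sigma}^{-1}\mathbf{M})\right) \qquad (3.92)$$

where

$$\mathbf{M} \triangleq N_{\mathcal{D}}(\boldsymbol{\mu} - \bar{\mathbf{y}})(\boldsymbol{\mu} - \bar{\mathbf{y}})^\top + \breve{\kappa}(\boldsymbol{\mu} - \breve{\mathbf{m}})(\boldsymbol{\mu} - \breve{\mathbf{m}})^\top + \mathbf{S} + \breve{\boldsymbol{\Psi}} \qquad (3.93)$$
$$= (\breve{\kappa} + N_{\mathcal{D}})\boldsymbol{\mu}\boldsymbol{\mu}^\top - \boldsymbol{\mu}(\breve{\kappa}\breve{\mathbf{m}} + N_{\mathcal{D}}\bar{\mathbf{y}})^\top - (\breve{\kappa}\breve{\mathbf{m}} + N_{\mathcal{D}}\bar{\mathbf{y}})\boldsymbol{\mu}^\top + \breve{\kappa}\breve{\mathbf{m}}\breve{\mathbf{m}}^\top + \mathbf{Y}^\top\mathbf{Y} + \breve{\boldsymbol{\Psi}} \qquad (3.94)$$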


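We can simplify the $\mathbf{M}$ matrix using a trick called completing the square. Applying this to the above, we have

$$(\breve{\kappa} + N_{\mathcal{D}})\boldsymbol{\mu}\boldsymbol{\mu}^\top - \boldsymbol{\mu}(\breve{\kappa}\breve{\mathbf{m}} + N_{\mathcal{D}}\bar{\mathbf{y}})^\top - (\breve{\kappa}\breve{\mathbf{m}} + N_{\mathcal{D}}\bar{\mathbf{y}})\boldsymbol{\mu}^\top \qquad (3.95)$$
$$= (\breve{\kappa} + N_{\mathcal{D}})\left(\boldsymbol{\mu} - \frac{\breve{\kappa}\breve{\mathbf{m}} + N_{\mathcal{D}}\bar{\mathbf{y}}}{\breve{\kappa} + N_{\mathcal{D}}}\right)\left(\boldsymbol{\mu} - \frac{\breve{\kappa}\breve{\mathbf{m}} + N_{\mathcal{D}}\bar{\mathbf{y}}}{\breve{\kappa} + N_{\mathcal{D}}}\right)^\top \qquad (3.96)$$
$$\quad - \frac{(\breve{\kappa}\breve{\mathbf{m}} + N_{\mathcal{D}}\bar{\mathbf{y}})(\breve{\kappa}\breve{\mathbf{m}} + N_{\mathcal{D}}\bar{\mathbf{y}})^\top}{\breve{\kappa} + N_{\mathcal{D}}} \qquad (3.97)$$
$$= \widehat{\kappa}(\boldsymbol{\mu} - \widehat{\mathbf{m}})(\boldsymbol{\mu} - \widehat{\mathbf{m}})^\top - \widehat{\kappa}\widehat{\mathbf{m}}\widehat{\mathbf{m}}^\top \qquad (3.98)$$

Substituting this back into Equation (3.92), we get

$$p(\boldsymbol{\mu}, \boldsymbol{\Sigma}|\mathcal{D}) \propto |\boldsymbol{\Sigma}|^{-(\widehat{\nu} + D + 2)/2}\exp\left(-\frac{1}{2}\mathrm{tr}\left[\boldsymbol{\Sigma}^{-1}\left(\widehat{\kappa}(\boldsymbol{\mu} - \widehat{\mathbf{m}})(\boldsymbol{\mu} - \widehat{\mathbf{m}})^\top + \widehat{\boldsymbol{\Psi}}\right)\right]\right) \qquad (3.99)$$
$$= \mathrm{NIW}(\boldsymbol{\mu}, \boldsymbol{\Sigma}|\widehat{\mathbf{m}}, \widehat{\kappa}, \widehat{\nu}, \widehat{\boldsymbol{\Psi}}) \qquad (3.100)$$

where

$$\widehat{\mathbf{m}} = \frac{\breve{\kappa}\breve{\mathbf{m}} + N_{\mathcal{D}}\bar{\mathbf{y}}}{\widehat{\kappa}} = \frac{\breve{\kappa}}{\breve{\kappa} + N_{\mathcal{D}}}\breve{\mathbf{m}} + \frac{N_{\mathcal{D}}}{\breve{\kappa} + N_{\mathcal{D}}}\bar{\mathbf{y}} \qquad (3.101)$$
$$\widehat{\kappa} = \breve{\kappa} + N_{\mathcal{D}} \qquad (3.102)$$
$$\widehat{\nu} = \breve{\nu} + N_{\mathcal{D}} \qquad (3.103)$$
$$\widehat{\boldsymbol{\Psi}} = \breve{\boldsymbol{\Psi}} + \mathbf{S} + \frac{\breve{\kappa}N_{\mathcal{D}}}{\breve{\kappa} + N_{\mathcal{D}}}(\bar{\mathbf{y}} - \breve{\mathbf{m}})(\bar{\mathbf{y}} - \breve{\mathbf{m}})^\top \qquad (3.104)$$
$$= \breve{\boldsymbol{\Psi}} + \mathbf{Y}^\top\mathbf{Y} + \breve{\kappa}\breve{\mathbf{m}}\breve{\mathbf{m}}^\top - \widehat{\kappa}\widehat{\mathbf{m}}\widehat{\mathbf{m}}^\top \qquad (3.105)$$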


This result is actually quite intuitive: the posterior mean $\widehat{\mathbf{m}}$ is a convex combination of the prior mean and the MLE; the posterior scatter matrix $\widehat{\boldsymbol{\Psi}}$ is the prior scatter matrix $\breve{\boldsymbol{\Psi}}$ plus the empirical scatter matrix $\mathbf{S}$ plus an extra term due to the uncertainty in the mean (which creates its own virtual scatter matrix); and the posterior confidence factors $\widehat{\kappa}$ and $\widehat{\nu}$ are both incremented by the size of the data we condition on.
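The NIW update can be sketched in a few lines of NumPy and checked in two ways. First, the two expressions (3.104) and (3.105) for $\widehat{\boldsymbol{\Psi}}$ should agree. Second, row-by-row updating should reproduce the batch posterior. The 2d data, the prior $(\mathbf{0}, 1, 4, \mathbf{I})$ and the name `niw_update` are hypothetical.

```python
import numpy as np


def niw_update(Y, m, kappa, nu, Psi):
    """Posterior NIW parameters via Eqs. (3.101)-(3.104)."""
    Y = np.atleast_2d(Y)
    n, ybar = Y.shape[0], Y.mean(axis=0)
    S = (Y - ybar).T @ (Y - ybar)                     # empirical scatter matrix
    kappa_post, nu_post = kappa + n, nu + n
    m_post = (kappa * m + n * ybar) / kappa_post
    diff = (ybar - m)[:, None]
    Psi_post = Psi + S + kappa * n / (kappa + n) * diff @ diff.T
    return m_post, kappa_post, nu_post, Psi_post


rng = np.random.default_rng(0)
Y = rng.normal(size=(20, 2)) + np.array([1.0, -2.0])      # hypothetical 2d data
m0, kappa0, nu0, Psi0 = np.zeros(2), 1.0, 4.0, np.eye(2)  # hypothetical prior

m_hat, kappa_hat, nu_hat, Psi_hat = niw_update(Y, m0, kappa0, nu0, Psi0)

# Equation (3.105): the same matrix written with the sum-of-squares matrix Y^T Y.
Psi_alt = Psi0 + Y.T @ Y + kappa0 * np.outer(m0, m0) - kappa_hat * np.outer(m_hat, m_hat)

# Conjugacy: updating one row at a time gives the same posterior as the batch update.
params = (m0, kappa0, nu0, Psi0)
for y_n in Y:
    params = niw_update(y_n, *params)
print(np.allclose(Psi_hat, Psi_alt), np.allclose(params[3], Psi_hat))
```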


3.3.3.4 Posterior marginals 


We have computed the joint posterior 
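$$p(\boldsymbol{\mu}, \boldsymbol{\Sigma}|\mathcal{D}) = p(\boldsymbol{\mu}|\boldsymbol{\Sigma}, \mathcal{D})\,p(\boldsymbol{\Sigma}|\mathcal{D}) = \mathcal{N}\left(\boldsymbol{\mu} \,\middle|\, \widehat{\mathbf{m}}, \frac{1}{\widehat{\kappa}}\boldsymbol{\Sigma}\right)\mathrm{IW}(\boldsymbol{\Sigma}|\widehat{\nu}, \widehat{\boldsymbol{\Psi}}) \qquad (3.106)$$

We now discuss how to compute the posterior marginals, $p(\boldsymbol{\Sigma}|\mathcal{D})$ and $p(\boldsymbol{\mu}|\mathcal{D})$.
It is easy to see that the posterior marginal for $\boldsymbol{\Sigma}$ is

$$p(\boldsymbol{\Sigma}|\mathcal{D}) = \int p(\boldsymbol{\mu}, \boldsymbol{\Sigma}|\mathcal{D})\, d\boldsymbol{\mu} = \mathrm{IW}(\boldsymbol{\Sigma}|\widehat{\nu}, \widehat{\boldsymbol{\Psi}}) \qquad (3.107)$$

For the mean, one can show that

$$p(\boldsymbol{\mu}|\mathcal{D}) = \int p(\boldsymbol{\mu}, \boldsymbol{\Sigma}|\mathcal{D})\, d\boldsymbol{\Sigma} = \mathcal{T}\left(\boldsymbol{\mu} \,\middle|\, \widehat{\mathbf{m}}, \frac{\widehat{\boldsymbol{\Psi}}}{\widehat{\kappa}\nu'}, \nu'\right) \qquad (3.108)$$

where $\nu' \triangleq \widehat{\nu} - D + 1$. Intuitively this result follows because $p(\boldsymbol{\mu}|\mathcal{D})$ is an infinite mixture of Gaussians, where each mixture component has a value of $\boldsymbol{\Sigma}$ drawn from the IW distribution; by mixing these altogether, we induce a Student distribution, which has heavier tails than a single Gaussian.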


3.3.3.5 Posterior mode 


The maximum a posteriori (MAP) estimate of u and X is the mode of the posterior NIW distribution 
with density 


plu, ENY) = N (u| A,R! S\IW(S| 2, W) (3.109) 


To find the mode, we firstly notice that u only appears in the conditional distribution N (u| f, R7! 5), 
and the mode of this normal distribution equals its mean, i.e., =p. Also notice that this holds for 
any choice of X. So we can plug u =f in Equation (3.109) and derive the mode of X. Notice that 


—2 x log p(w =f, E|Y) = (? +D + 2) log(|¥|) + trace(W £7!) + c (3.110) 


where c is a constant irrelevant to X. We then take the derivative over X: 


Jlog p(w =A DY) 
ax 


=(94D42)>1-y>71¢571 (3.111) 


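By setting the derivative to 0 and solving for $\Sigma$, we see that $(\hat{\nu} + D + 2)^{-1} \hat{\Psi}$
is the matrix that minimizes Equation (3.110), and hence maximizes the posterior density. Since $\hat{\Psi}$ is positive definite,
this stationary point is a valid covariance matrix, and we conclude that it is the MAP estimate of $\Sigma$.

In conclusion, the MAP estimates of $\mu$ and $\Sigma$ are
$$
\hat{\mu} = \frac{\breve{\kappa} \breve{m} + N \bar{y}}{\breve{\kappa} + N} = \hat{m} \tag{3.112}
$$
$$
\hat{\Sigma} = \frac{1}{\hat{\nu} + D + 2} \hat{\Psi} \tag{3.113}
$$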
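As a quick numerical check of Equations (3.110)-(3.113), the following minimal NumPy sketch computes the MAP estimate from given posterior hyperparameters and confirms that rescaling $\hat{\Sigma}$ in either direction increases $-2 \log p$. The function names and the hyperparameter values are hypothetical choices for illustration, not the book's notebook code.

```python
import numpy as np

def niw_map(m_post, kappa_post, nu_post, Psi_post):
    """MAP estimate of (mu, Sigma) under a NIW(m, kappa, nu, Psi) posterior, Eqs. (3.112)-(3.113).

    kappa_post does not affect the mode; it is kept only so the signature mirrors the NIW parameters.
    """
    D = Psi_post.shape[0]
    mu_map = np.asarray(m_post, dtype=float)   # mode of N(mu | m, Sigma / kappa) is its mean
    Sigma_map = Psi_post / (nu_post + D + 2)   # stationary point of Eq. (3.111)
    return mu_map, Sigma_map

def neg2_log_post_sigma(Sigma, nu_post, Psi_post):
    """-2 log p(mu = m, Sigma | D), up to the additive constant c of Eq. (3.110)."""
    D = Sigma.shape[0]
    _, logdet = np.linalg.slogdet(Sigma)
    return (nu_post + D + 2) * logdet + np.trace(Psi_post @ np.linalg.inv(Sigma))

# Hypothetical posterior hyperparameters, for illustration only.
m_post = np.array([0.5, -1.0])
kappa_post, nu_post = 12.0, 15.0
Psi_post = np.array([[4.0, 1.0], [1.0, 3.0]])

mu_map, Sigma_map = niw_map(m_post, kappa_post, nu_post, Psi_post)
f_map = neg2_log_post_sigma(Sigma_map, nu_post, Psi_post)

# Rescaling Sigma_map in either direction increases Eq. (3.110).
for s in [0.8, 1.25]:
    print(s, neg2_log_post_sigma(s * Sigma_map, nu_post, Psi_post) > f_map)
```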


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 




3.3.3.6 Posterior predictive 


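We now discuss how to predict future data by integrating out the parameters. If $y \sim \mathcal{N}(\mu, \Sigma)$, where
$(\mu, \Sigma \mid D) \sim \mathrm{NIW}(\hat{m}, \hat{\kappa}, \hat{\nu}, \hat{\Psi})$, then one can show that the posterior predictive distribution, for a
single observation vector, is as follows:
$$
p(y \mid D) = \int \mathcal{N}(y \mid \mu, \Sigma)\, \mathrm{NIW}(\mu, \Sigma \mid \hat{m}, \hat{\kappa}, \hat{\nu}, \hat{\Psi})\, d\mu\, d\Sigma \tag{3.114}
$$
$$
= \mathcal{T}\left(y \,\middle|\, \hat{m}, \frac{\hat{\Psi}(\hat{\kappa} + 1)}{\hat{\kappa}\nu'}, \nu'\right) \tag{3.115}
$$
where $\nu' = \hat{\nu} - D + 1$.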


3.4 Conjugate priors for the exponential family 


We have seen that exact Bayesian analysis is considerably simplified if the prior is conjugate to the
likelihood. Since the posterior must have the same form as the prior, and hence the same number of
parameters, the likelihood function must have fixed-sized sufficient statistics, so that we can write
$p(D \mid \theta) = p(s(D) \mid \theta)$. This suggests that the only family of distributions for which conjugate priors
exist is the exponential family, a result proved in [DY79].⁴ In the sections below, we show how to
perform conjugate analysis for a generic exponential family model.


3.4.0.1 Likelihood 


Recall that the likelihood of the exponential family is given by
$$
p(D \mid \eta) = h(D) \exp\left(\eta^\top s(D) - N_D A(\eta)\right) \tag{3.116}
$$
where $s(D) = \sum_{n=1}^{N_D} s(x_n)$ and $h(D) \triangleq \prod_{n=1}^{N_D} h(x_n)$.


3.4.0.2 Prior 


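We will write the prior in a form that mirrors the likelihood:
$$
p(\eta \mid \breve{\tau}, \breve{\nu}) = \frac{1}{Z(\breve{\tau}, \breve{\nu})} \exp\left(\breve{\tau}^\top \eta - \breve{\nu} A(\eta)\right) \tag{3.117}
$$
where $\breve{\nu}$ is the strength of the prior, $\breve{\tau}/\breve{\nu}$ is the prior mean, and $Z(\breve{\tau}, \breve{\nu})$ is a normalizing factor.
The parameters $\breve{\tau}$ can be derived from virtual samples representing our prior beliefs.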


4. There are some exceptions. For example, the uniform distribution $\mathrm{Unif}(x \mid 0, \theta)$ has finite sufficient statistics
$(N, m = \max_i x_i)$, as discussed in Section 2.3.2.6; hence this distribution has a conjugate prior, namely the Pareto
distribution (Section 2.2.3.5), $p(\theta) = \mathrm{Pareto}(\theta \mid \theta_0, \kappa)$, yielding the posterior $p(\theta \mid x) = \mathrm{Pareto}(\max(\theta_0, m), \kappa + N)$.


3.4.0.3 Posterior 


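The posterior is given by
$$
p(\eta \mid D) = \frac{p(D \mid \eta)\, p(\eta)}{p(D)} \tag{3.118}
$$
$$
= \frac{h(D)}{Z(\breve{\tau}, \breve{\nu})\, p(D)} \exp\left((\breve{\tau} + s(D))^\top \eta - (\breve{\nu} + N_D) A(\eta)\right) \tag{3.119}
$$
$$
= \frac{1}{Z(\hat{\tau}, \hat{\nu})} \exp\left(\hat{\tau}^\top \eta - \hat{\nu} A(\eta)\right) \tag{3.120}
$$
where
$$
\hat{\tau} = \breve{\tau} + s(D) \tag{3.121}
$$
$$
\hat{\nu} = \breve{\nu} + N_D \tag{3.122}
$$
$$
Z(\hat{\tau}, \hat{\nu}) = \frac{Z(\breve{\tau}, \breve{\nu})\, p(D)}{h(D)} \tag{3.123}
$$

We see that this has the same form as the prior, but where we update the sufficient statistics and the
sample size. We also see that the marginal likelihood is given by
$$
p(D) = \frac{Z(\hat{\tau}, \hat{\nu})\, h(D)}{Z(\breve{\tau}, \breve{\nu})} \tag{3.124}
$$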


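The posterior mean is given by a convex combination of the prior mean and the empirical mean
(which is the MLE):
$$
\frac{\hat{\tau}}{\hat{\nu}} = \frac{\breve{\tau} + s(D)}{\breve{\nu} + N_D} = \frac{\breve{\nu}}{\breve{\nu} + N_D} \frac{\breve{\tau}}{\breve{\nu}} + \frac{N_D}{\breve{\nu} + N_D} \frac{s(D)}{N_D} \tag{3.125}
$$
$$
= \lambda\, \mathbb{E}[\eta] + (1 - \lambda)\, \hat{s}_{\text{mle}} \tag{3.126}
$$
where $\lambda = \frac{\breve{\nu}}{\breve{\nu} + N_D}$, $\mathbb{E}[\eta]$ here denotes the prior mean $\breve{\tau}/\breve{\nu}$, and $\hat{s}_{\text{mle}} = s(D)/N_D$.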
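The update in Equations (3.121)-(3.122) and the convex-combination form of the posterior mean can be checked with a few lines of NumPy. This is a sketch under the assumption that the sufficient statistics are supplied as one row per data point; the helper name and the Bernoulli numbers ($s(x) = x$, $\breve{\tau} = 2$, $\breve{\nu} = 5$) are hypothetical.

```python
import numpy as np

def expfam_conjugate_update(tau_prior, nu_prior, suff_stats):
    """Conjugate update of Eqs. (3.121)-(3.122).

    suff_stats holds one row s(x_n) per data point: shape (N_D, K), or (N_D,) when K = 1.
    """
    s = np.asarray(suff_stats, dtype=float).reshape(len(suff_stats), -1)
    tau_post = np.asarray(tau_prior, dtype=float) + s.sum(axis=0)   # tau_hat = tau + s(D)
    nu_post = nu_prior + s.shape[0]                                  # nu_hat = nu + N_D
    return tau_post, nu_post

# Hypothetical Bernoulli example with s(x) = x.
x = np.array([1, 0, 1, 1, 0, 1])
tau_prior, nu_prior = np.array([2.0]), 5.0

tau_post, nu_post = expfam_conjugate_update(tau_prior, nu_prior, x)
lam = nu_prior / (nu_prior + len(x))
convex_comb = lam * tau_prior / nu_prior + (1 - lam) * x.mean()   # Eq. (3.125)
print(tau_post / nu_post, convex_comb)   # both equal 6/11
```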


3.4.0.4 Posterior predictive density 


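We will now derive the predictive density for future observables $D' = (\tilde{x}_1, \ldots, \tilde{x}_{N'})$ given past data
$D = (x_1, \ldots, x_N)$:
$$
p(D' \mid D) = \int p(D' \mid \eta)\, p(\eta \mid D)\, d\eta \tag{3.127}
$$
$$
= \int h(D') \exp\left(\eta^\top s(D') - N' A(\eta)\right) \frac{1}{Z(\hat{\tau}, \hat{\nu})} \exp\left(\eta^\top \hat{\tau} - \hat{\nu} A(\eta)\right) d\eta \tag{3.128}
$$
$$
= h(D') \frac{Z(\breve{\tau} + s(D) + s(D'), \breve{\nu} + N + N')}{Z(\breve{\tau} + s(D), \breve{\nu} + N)} \tag{3.129}
$$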
3.4.0.5 Example: Bernoulli distribution


As a simple example, let us revisit the Beta-Bernoulli model in our new notation.
The likelihood is given by
$$
p(D \mid \theta) = (1 - \theta)^{N_D} \exp\left(\log\left(\frac{\theta}{1 - \theta}\right) \sum_i x_i\right) \tag{3.130}
$$


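Hence the conjugate prior is given by
$$
p(\theta \mid \nu_0, \tau_0) \propto (1 - \theta)^{\nu_0} \exp\left(\log\left(\frac{\theta}{1 - \theta}\right) \tau_0\right) \tag{3.131}
$$
$$
= \theta^{\tau_0} (1 - \theta)^{\nu_0 - \tau_0} \tag{3.132}
$$

If we define $\alpha = \tau_0 + 1$ and $\beta = \nu_0 - \tau_0 + 1$, we see that this is a beta distribution.
We can derive the posterior as follows, where $s = \sum_i \mathbb{I}(x_i = 1)$ is the sufficient statistic:
$$
p(\theta \mid D) \propto \theta^{\tau_0 + s} (1 - \theta)^{\nu_0 - \tau_0 + N_D - s} \tag{3.133}
$$
$$
= \theta^{\tau_n} (1 - \theta)^{\nu_n - \tau_n} \tag{3.134}
$$
where $\tau_n = \tau_0 + s$ and $\nu_n = \nu_0 + N_D$.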


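We can derive the posterior predictive distribution as follows. Assume $p(\theta) = \mathrm{Beta}(\theta \mid \alpha, \beta)$, and
let $s = s(D)$ be the number of heads in the past data of size $n$. We can predict the probability of a given
sequence of future heads, $D' = (\tilde{x}_1, \ldots, \tilde{x}_m)$, with sufficient statistic $s' = \sum_{i=1}^m \mathbb{I}(\tilde{x}_i = 1)$, as follows:
$$
p(D' \mid D) = \int_0^1 p(D' \mid \theta)\, \mathrm{Beta}(\theta \mid \alpha_n, \beta_n)\, d\theta \tag{3.135}
$$
$$
= \frac{\Gamma(\alpha_n + \beta_n)}{\Gamma(\alpha_n)\Gamma(\beta_n)} \int_0^1 \theta^{\alpha_{n+m} - 1} (1 - \theta)^{\beta_{n+m} - 1}\, d\theta \tag{3.136}
$$
$$
= \frac{\Gamma(\alpha_n + \beta_n)}{\Gamma(\alpha_n)\Gamma(\beta_n)} \frac{\Gamma(\alpha_{n+m})\Gamma(\beta_{n+m})}{\Gamma(\alpha_{n+m} + \beta_{n+m})} \tag{3.137}
$$
where
$$
\alpha_{n+m} = \alpha_n + s' = \alpha + s + s' \tag{3.138}
$$
$$
\beta_{n+m} = \beta_n + (m - s') = \beta + (n - s) + (m - s') \tag{3.139}
$$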
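Equation (3.137) is easiest to evaluate in log space using $\log \Gamma$. The sketch below (standard-library Python; the function names are hypothetical) computes the probability of one specific future sequence; with $m = 1$ and $s' = 1$ it reduces to the familiar $\alpha_n/(\alpha_n + \beta_n)$.

```python
from math import lgamma, exp

def log_beta_fn(a, b):
    """log B(a, b) = log Gamma(a) + log Gamma(b) - log Gamma(a + b)."""
    return lgamma(a) + lgamma(b) - lgamma(a + b)

def beta_bernoulli_seq_predictive(alpha, beta, n, s, m, s_new):
    """Probability of one specific future sequence of m flips containing s_new heads, Eq. (3.137).

    alpha, beta: Beta prior; n, s: number of past flips and number of past heads.
    """
    a_n, b_n = alpha + s, beta + (n - s)           # posterior Beta(alpha_n, beta_n) after D
    a_nm, b_nm = a_n + s_new, b_n + (m - s_new)    # Eqs. (3.138)-(3.139)
    return exp(log_beta_fn(a_nm, b_nm) - log_beta_fn(a_n, b_n))

# Uniform prior, 7 heads in 10 past flips (hypothetical numbers).
p_head = beta_bernoulli_seq_predictive(1.0, 1.0, n=10, s=7, m=1, s_new=1)
print(p_head)   # alpha_n / (alpha_n + beta_n) = 8 / 12
```

Because the formula gives the probability of a particular ordered sequence, the probabilities of the four length-2 sequences (HH, HT, TH, TT) sum to one; to get the probability of $s'$ heads in any order, multiply by $\binom{m}{s'}$.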


3.5 Beyond conjugate priors 


In Section 3.2, we saw various examples of conjugate priors, all of which have come from the 
exponential family (see Section 2.3). These priors have the advantage of being easy to interpret 
(in terms of sufficient statistics from a virtual prior dataset), and easy to compute with. However, 
for most models, there is no prior in the exponential family that is conjugate to the likelihood. 
Furthermore, even where there is a conjugate prior, the assumption of conjugacy may be too limiting. 
Therefore in the sections below, we briefly discuss various other kinds of priors. (We defer the 
question of posterior inference with these priors until Section 7.1, where we discuss algorithmic issues, 
since we can no longer use closed-form solutions when the prior is not conjugate.) 


3.5.1 Robust (heavy-tailed) priors 


The assessment of the influence of the prior on the posterior is called sensitivity analysis, or
robustness analysis. There are many ways to create robust priors (see e.g., [IR00]). Here we
consider a simple approach, namely the use of a heavy-tailed distribution.

To motivate this, let us consider an example from [Ber85, p7]. Suppose $x \sim \mathcal{N}(\theta, 1)$. We observe
that $x = 5$ and we want to estimate $\theta$. The MLE is of course $\hat{\theta} = 5$, which seems reasonable. The
posterior mean under a uniform prior is also $\bar{\theta} = 5$. But now suppose we know that the prior median
is 0, and that there is 25% probability that $\theta$ lies in any of the intervals $(-\infty, -1)$, $(-1, 0)$, $(0, 1)$,
$(1, \infty)$. Let us also assume the prior is smooth and unimodal.

One can show that a Gaussian prior of the form $\mathcal{N}(\theta \mid 0, 2.19)$ (i.e., with variance 2.19, so that its quartiles are at $\pm 1$) satisfies these prior constraints.
But in this case the posterior mean is given by $5 \times 2.19/(2.19 + 1) \approx 3.43$, which doesn't seem very satisfactory. An
alternative distribution that captures the same prior information is the Cauchy prior $\mathcal{T}_1(\theta \mid 0, 1)$. With
this prior, we find (using numerical integration: see robust_prior_demo.ipynb for the code)
that the posterior mean is about 4.6, which seems much more reasonable. In general, priors with
heavy tails tend to give results which are more sensitive to the data, which is usually what we desire.

Heavy-tailed priors are usually not conjugate. However, we can often approximate a heavy-tailed 
prior by using a (possibly infinite) mixture of conjugate priors. For example, in Section 28.2.3, we 
show that the Student distribution (of which the Cauchy is a special case) can be written as an 
infinite mixture of Gaussians, where the mixing weights come from a Gamma distribution. This is 
an example of a hierarchical prior; see Section 3.7 for details. 
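The scale-mixture representation can be checked by simulation: draw a precision $\lambda \sim \mathrm{Ga}(\nu/2, \nu/2)$ (shape/rate), then $\theta \mid \lambda \sim \mathcal{N}(0, 1/\lambda)$. The NumPy sketch below uses $\nu = 1$, where the marginal is a standard Cauchy, and checks that half of the mass lies in $(-1, 1)$. The sample size and seed are arbitrary assumptions.

```python
import numpy as np

def sample_student_via_mixture(nu, num_samples, seed=0):
    """Draw from a Student-t (location 0, scale 1) as an infinite mixture of Gaussians.

    lam ~ Gamma(shape=nu/2, rate=nu/2), then theta | lam ~ N(0, 1/lam).
    """
    rng = np.random.default_rng(seed)
    lam = rng.gamma(shape=nu / 2.0, scale=2.0 / nu, size=num_samples)   # NumPy uses scale = 1/rate
    return rng.normal(loc=0.0, scale=1.0 / np.sqrt(lam))

samples = sample_student_via_mixture(nu=1.0, num_samples=200_000)
# For nu = 1 the marginal is a Cauchy with quartiles -1 and +1, so P(|theta| < 1) = 0.5.
print(np.mean(np.abs(samples) < 1.0))
```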


3.5.2 Priors for variance parameters 


In this section, we discuss some commonly used priors for variance parameters. Such priors play
an important role in determining how much regularization a model exhibits. For example, consider
a linear regression model, $p(y \mid x, w, \sigma^2) = \mathcal{N}(y \mid w^\top x, \sigma^2)$. Suppose we use a Gaussian prior on the
weights, $p(w) = \mathcal{N}(w \mid 0, \tau^2 I)$. The value of $\tau^2$ (relative to $\sigma^2$) plays a role similar to the strength
of an $\ell_2$-regularization term in ridge regression. In the Bayesian setting, we need to ensure we use
sensible priors for the variance parameters, $\tau^2$ and $\sigma^2$. This becomes even more important when we
discuss hierarchical models, in Section 3.7.
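To make the analogy with ridge regression concrete: with known $\sigma^2$ and $\tau^2$, the posterior mean of $w$ (which is also the MAP estimate, since the posterior is Gaussian) equals the ridge solution with penalty $\lambda = \sigma^2/\tau^2$. The NumPy sketch below checks this on synthetic data; the data-generating weights and variances are arbitrary assumptions.

```python
import numpy as np

def posterior_mean_weights(X, y, sigma2, tau2):
    """Posterior mean of w under p(w) = N(0, tau2 I) and y ~ N(Xw, sigma2 I)."""
    D = X.shape[1]
    precision = X.T @ X / sigma2 + np.eye(D) / tau2
    return np.linalg.solve(precision, X.T @ y / sigma2)

def ridge(X, y, lam):
    """Ridge regression: argmin_w ||y - Xw||^2 + lam ||w||^2."""
    D = X.shape[1]
    return np.linalg.solve(X.T @ X + lam * np.eye(D), X.T @ y)

rng = np.random.default_rng(0)
X = rng.standard_normal((50, 3))   # synthetic design matrix
y = X @ np.array([1.0, -2.0, 0.5]) + 0.3 * rng.standard_normal(50)
sigma2, tau2 = 0.09, 0.5

w_bayes = posterior_mean_weights(X, y, sigma2, tau2)
w_ridge = ridge(X, y, lam=sigma2 / tau2)   # regularizer strength is the ratio sigma^2 / tau^2
print(np.allclose(w_bayes, w_ridge))       # True
```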


3.5.2.1 Priors for scalar variances 


Consider trying to infer a variance parameter $\sigma^2$ from a Gaussian likelihood with known mean, as in
Section 3.2.3.2. The uninformative prior is $p(\sigma^2) = \mathrm{IG}(\sigma^2 \mid 0, 0)$, which is improper, meaning it does
not integrate to 1. This is fine as long as the posterior is proper. This will be the case if the prior is
on the variance of the noise of $N > 2$ observable variables. Unfortunately, the posterior is not proper,
even if $N \to \infty$, if we use this prior for the variance of the (non-observable) weights in a regression
model [Gel06; PS12], as we discuss in Section 3.7.

One solution to this is to use a weakly informative proper prior such as $\mathrm{IG}(\epsilon, \epsilon)$ for small $\epsilon$.
However, this turns out to not work very well, for reasons that are explained in [Gel06; PS12]. Instead, 
it is recommended to use other priors, such as uniform, exponential, half-normal, half-Student-t, or 
half-Cauchy; all of these are bounded below by 0, and just require 1 or 2 hyperparameters. (The 
term “half” refers to the fact that the distribution is “folded over” onto itself on the positive side of 
the real axis.) 
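The folding construction is simple to write down: a half distribution has density $2p(x)$ for $x \geq 0$ and $0$ otherwise, where $p$ is a density symmetric about 0. The sketch below (NumPy; function names hypothetical) implements the half-normal and half-Cauchy and checks numerically that the total mass is 1 and that the half-Cauchy median equals its scale.

```python
import numpy as np

def half_normal_pdf(x, scale=1.0):
    """Half-normal: fold N(0, scale^2) onto [0, inf), giving 2 N(x | 0, scale^2) for x >= 0."""
    x = np.asarray(x, dtype=float)
    normal = np.exp(-0.5 * (x / scale) ** 2) / (scale * np.sqrt(2.0 * np.pi))
    return np.where(x >= 0.0, 2.0 * normal, 0.0)

def half_cauchy_pdf(x, scale=1.0):
    """Half-Cauchy: fold Cauchy(0, scale) onto [0, inf), giving 2 Cauchy(x | 0, scale) for x >= 0."""
    x = np.asarray(x, dtype=float)
    cauchy = 1.0 / (np.pi * scale * (1.0 + (x / scale) ** 2))
    return np.where(x >= 0.0, 2.0 * cauchy, 0.0)

# Midpoint-rule checks: folding preserves total mass; the half-Cauchy median equals its scale.
dx = 1e-4
grid = (np.arange(200_000) + 0.5) * dx   # midpoints covering [0, 20]
print(np.sum(half_normal_pdf(grid)) * dx)                            # close to 1
print(np.sum(half_cauchy_pdf(grid[grid < 2.0], scale=2.0)) * dx)     # close to 0.5
```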
3.5.2.2 Priors for covariance matrices


The conjugate prior for a covariance matrix is the inverse Wishart (Section 2.2.8.6). However, it can 
be hard to set the parameters for this in an uninformative way. One approach, discussed in [HW13], 
is to use a scale mixture of inverse Wisharts, where the scaling parameters have inverse gamma 
distributions. It is possible to choose shape and scale parameters to ensure that all the correlation 
parameters have uniform (—1,1) marginals, and all the standard deviations have half-Student 
distributions. 

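Unfortunately, the Wishart distribution has heavy tails, which can lead to poor performance
when used in a sampling algorithm.⁵ A more common approach, following Equation (3.140), is to
represent the $D \times D$ covariance matrix $\Sigma$ in terms of a product of the marginal standard deviations,
$\sigma = (\sigma_1, \ldots, \sigma_D)$, and the $D \times D$ correlation matrix $R$, as follows:
$$
\Sigma = \mathrm{diag}(\sigma)\, R\, \mathrm{diag}(\sigma) \tag{3.140}
$$
For example, if $D = 2$, we have
$$
\begin{pmatrix} \sigma_1 & 0 \\ 0 & \sigma_2 \end{pmatrix}
\begin{pmatrix} 1 & \rho \\ \rho & 1 \end{pmatrix}
\begin{pmatrix} \sigma_1 & 0 \\ 0 & \sigma_2 \end{pmatrix}
=
\begin{pmatrix} \sigma_1^2 & \rho\sigma_1\sigma_2 \\ \rho\sigma_1\sigma_2 & \sigma_2^2 \end{pmatrix} \tag{3.141}
$$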
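The decomposition in Equation (3.140) and its inverse take one line each in NumPy. This sketch (hypothetical helper names, arbitrary example values) reproduces the $D = 2$ case of Equation (3.141).

```python
import numpy as np

def cov_from_std_and_corr(sigma, R):
    """Sigma = diag(sigma) R diag(sigma), Eq. (3.140)."""
    sigma = np.asarray(sigma, dtype=float)
    return sigma[:, None] * R * sigma[None, :]   # same as np.diag(sigma) @ R @ np.diag(sigma)

def std_and_corr_from_cov(Sigma):
    """Inverse map: sigma_d = sqrt(Sigma_dd) and R = Sigma / (sigma sigma^T)."""
    sigma = np.sqrt(np.diag(Sigma))
    return sigma, Sigma / np.outer(sigma, sigma)

sigma, rho = np.array([2.0, 0.5]), 0.3
R = np.array([[1.0, rho], [rho, 1.0]])
Sigma = cov_from_std_and_corr(sigma, R)
print(Sigma)   # entries 4.0, 0.3, 0.3, 0.25: off-diagonal rho * sigma_1 * sigma_2 = 0.3 * 2 * 0.5
```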


We can put a factored prior on the standard deviations, following the recommendations of Section 3.5.2.
For example,
$$
p(\sigma) = \prod_{d=1}^{D} \mathrm{Expon}(\sigma_d \mid 1) \tag{3.142}
$$
For the correlation matrix, it is common to use as a prior the LKJ distribution, named after the
authors of [LKJ09]. This has the form
$$
\mathrm{LKJ}(R \mid \eta) \propto |R|^{\eta - 1} \tag{3.143}
$$
so it only has one free parameter. When $\eta = 1$, it is a uniform prior; when $\eta = 2$, it is a “weakly
regularizing” prior that encourages small correlations (close to 0). See Figure 3.9 for a plot.
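For $D = 2$, $|R| = 1 - \rho^2$, so the LKJ prior induces $p(\rho) \propto (1 - \rho^2)^{\eta - 1}$ on the single correlation coefficient, which is the kind of curve plotted in Figure 3.9. The NumPy sketch below normalizes this density numerically; the function name and grid size are assumptions.

```python
import numpy as np

def lkj_2d_corr_density(rho, eta, num_grid=200_001):
    """Density of rho induced by LKJ(R | eta) for D = 2.

    For R = [[1, rho], [rho, 1]], |R| = 1 - rho^2, so p(rho) is proportional to (1 - rho^2)^(eta - 1).
    The normalizer is computed numerically on a grid over (-1, 1).
    """
    grid = np.linspace(-1.0, 1.0, num_grid)[1:-1]   # drop endpoints (density may blow up there for eta < 1)
    dx = grid[1] - grid[0]
    Z = np.sum((1.0 - grid ** 2) ** (eta - 1.0)) * dx
    return (1.0 - np.asarray(rho, dtype=float) ** 2) ** (eta - 1.0) / Z

for eta in [1.0, 2.0, 4.0]:
    print(eta, lkj_2d_corr_density(0.0, eta))   # density at rho = 0 grows with eta
```

For example, the density at $\rho = 0$ is $1/2$ for $\eta = 1$ (uniform on $(-1, 1)$), $3/4$ for $\eta = 2$, and $35/32$ for $\eta = 4$, showing how larger $\eta$ concentrates mass near zero correlation.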

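In practice, it is more common to define $R$ in terms of its Cholesky decomposition, $R = LL^\top$,
where $L$ is a lower triangular matrix with positive diagonal entries whose rows have unit Euclidean norm (since $R$ has a unit diagonal). We then represent the prior using
$$
\mathrm{LKJchol}(L \mid \eta) \propto \prod_{d=2}^{D} L_{dd}^{D - d + 2\eta - 2} \tag{3.144}
$$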


3.6 Noninformative priors 


When we have little or no domain-specific knowledge, it is desirable to use an uninformative,
noninformative, or objective prior, to “let the data speak for itself”. Unfortunately, there is no
unique way to define such priors, and they all encode some kind of knowledge. It is therefore better
to use the term diffuse prior, minimally informative prior, or default prior.

In the sections below, we briefly mention some common approaches for creating default priors. For
further details, see e.g., [KW96] and the Stan website.⁶


5. See comments from Michael Betancourt at https://github.com/pymc-devs/pymc/issues/538.
6. https://github.com/stan-dev/stan/wiki/Prior-Choice-Recommendations.
[Figure 3.9 panel: “LKJ 1D Correlation Coef.”; x-axis: Correlation; y-axis: Density]

Figure 3.9: Distribution on the correlation coefficient $\rho$ induced by a 2d LKJ distribution with varying
parameter $\eta$. Adapted from Figure 14.8 of [McE20]. Generated by lkj_1d.ipynb.


3.6.1 Maximum entropy priors 


A natural way to define an uninformative prior is to use one that has maximum entropy, since 
it makes the least commitments to any particular value in the state space (see Section 5.2 for a 
discussion of entropy). This is a formalization of Laplace’s principle of insufficient reason, in 
which he argued that if there is no reason to prefer one prior over another, we should pick a “flat” 
one. 

For example, in the case of a Bernoulli distribution with rate $\theta \in [0, 1]$, the maximum entropy
prior is the uniform distribution, $p(\theta) = \mathrm{Beta}(\theta \mid 1, 1)$, which makes intuitive sense.

However, in some cases we know something about our random variable $\theta$, and we would like our
prior to match these constraints, but otherwise be maximally entropic. More precisely, suppose
we want to find a distribution $p(\theta)$ with maximum entropy, subject to the constraints that the
expected values of certain features or functions $f_k(\theta)$ match some known quantities $F_k$. This is called
a maxent prior. In Section 2.3.7, we show that such distributions must belong to the exponential
family (Section 2.3).

For example, suppose $\theta \in \{1, 2, \ldots, 10\}$, and let $p_c = p(\theta = c)$ be the corresponding prior. Suppose
we know that the prior mean is 1.5. We can encode this using the following constraint
$$
\mathbb{E}[f_1(\theta)] = \mathbb{E}[\theta] = \sum_c c\, p_c = 1.5 \tag{3.145}
$$
In addition, we have the constraint $\sum_c p_c = 1$. Thus we need to solve the following optimization
problem:
$$
\max_p \mathbb{H}(p) \quad \text{s.t.} \quad \sum_c c\, p_c = 1.5, \quad \sum_c p_c = 1.0 \tag{3.146}
$$
This gives the decaying exponential curve in Figure 3.10. Now suppose we know that $\theta$ is either 3 or
4 with probability 0.8. We can encode this using
$$
\mathbb{E}[f_1(\theta)] = \mathbb{E}[\mathbb{I}(\theta \in \{3, 4\})] = \Pr(\theta \in \{3, 4\}) = 0.8 \tag{3.147}
$$
This gives the inverted U-curve in Figure 3.10. We note that this distribution is flat in as many
places as possible.
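Because the maxent solution has exponential-family form, $p_c \propto \exp(\lambda f(c))$ for a single constraint, and $\mathbb{E}[f]$ is increasing in $\lambda$ (its derivative is the variance of $f$), the Lagrange multiplier can be found by bisection. The NumPy sketch below applies this to both constraints above; it is an illustration under these assumptions (one moment constraint, finite support), not the code in maxent_priors.ipynb.

```python
import numpy as np

def maxent_prior(support, f, target, lo=-50.0, hi=50.0, iters=200):
    """Maxent distribution on a finite support subject to sum_c p_c f(c) = target.

    The solution has the form p_c proportional to exp(lam * f(c)); since E_p[f] increases
    with lam, we find lam by bisection.
    """
    feats = np.array([f(c) for c in support], dtype=float)

    def dist(lam):
        logits = lam * feats
        w = np.exp(logits - logits.max())   # numerically stable softmax
        return w / w.sum()

    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if dist(mid) @ feats < target:
            lo = mid
        else:
            hi = mid
    return dist(0.5 * (lo + hi))

support = np.arange(1, 11)                                         # theta in {1, ..., 10}
p_mean = maxent_prior(support, lambda c: c, 1.5)                   # Eq. (3.145): decaying exponential
p_34 = maxent_prior(support, lambda c: float(c in (3, 4)), 0.8)    # Eq. (3.147)
print(np.round(p_34, 3))   # 0.4 on {3, 4} and 0.025 elsewhere
```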
Figure 3.10: Illustration of 3 different maximum entropy priors. Adapted from Figure 1.10 of [MKL11].
Generated by maxent_priors.ipynb.


3.6.2 Jeffreys priors 


Let $\theta$ be a random variable with prior $p_\theta(\theta)$, and let $\phi = f(\theta)$ be some invertible transformation of
$\theta$. We want to choose a prior that is invariant to this function $f$, so that the posterior does not
depend on how we parameterize the model.

For example, consider a Bernoulli distribution with rate parameter $\theta$. Suppose Alice uses a binomial
likelihood with data $D$, and computes $p(\theta \mid D)$. Now suppose Bob uses the same likelihood and data,
but parameterizes the model in terms of the odds parameter, $\phi = \frac{\theta}{1 - \theta}$. He converts Alice's prior to
$p(\phi)$ using the change of variables formula, and then computes $p(\phi \mid D)$. If he then converts back to
the $\theta$ parameterization, he should get the same result as Alice.

We can achieve this goal provided we use a Jeffreys prior, named after Harold Jeffreys.⁷
In 1d, the Jeffreys prior is given by $p(\theta) \propto \sqrt{F(\theta)}$, where $F$ is the Fisher information (Section 2.4).
In multiple dimensions, the Jeffreys prior has the form $p(\theta) \propto \sqrt{\det F(\theta)}$, where $F$ is the Fisher
information matrix (Section 2.4).


7. Harold Jeffreys, 1891–1989, was an English mathematician, statistician, geophysicist, and astronomer. He is not
to be confused with Richard Jeffrey, a philosopher who advocated the subjective interpretation of probability [Jef04].

To see why the Jeffreys prior is invariant to parameterization, consider the 1d case. Suppose
$p_\theta(\theta) \propto \sqrt{F(\theta)}$. Using the change of variables, we can derive the corresponding prior for $\phi$ as follows:


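$$
p_\phi(\phi) = p_\theta(\theta) \left|\frac{d\theta}{d\phi}\right| \tag{3.148}
$$
$$
\propto \sqrt{F(\theta) \left(\frac{d\theta}{d\phi}\right)^2} = \sqrt{\mathbb{E}\left[\left(\frac{d \log p(x \mid \theta)}{d\theta}\right)^2\right] \left(\frac{d\theta}{d\phi}\right)^2} \tag{3.149}
$$
$$
= \sqrt{\mathbb{E}\left[\left(\frac{d \log p(x \mid \theta)}{d\theta} \frac{d\theta}{d\phi}\right)^2\right]} = \sqrt{\mathbb{E}\left[\left(\frac{d \log p(x \mid \phi)}{d\phi}\right)^2\right]} \tag{3.150}
$$
$$
= \sqrt{F(\phi)} \tag{3.151}
$$

Thus the prior distribution is the same whether we use the $\theta$ parameterization or the $\phi$ parameterization.
We give some examples of Jeffreys priors below.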


3.6.2.1 Jeffreys prior for binomial distribution 
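Let us derive the Jeffreys prior for the binomial distribution using the rate parameterization $\theta$. From
Equation (2.258), we have
$$
p(\theta) \propto \theta^{-\frac{1}{2}} (1 - \theta)^{-\frac{1}{2}} = \frac{1}{\sqrt{\theta(1 - \theta)}} \propto \mathrm{Beta}\left(\theta \,\middle|\, \tfrac{1}{2}, \tfrac{1}{2}\right) \tag{3.152}
$$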


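Now consider the odds parameterization, $\phi = \theta/(1 - \theta)$, so $\theta = \frac{\phi}{\phi + 1}$. The likelihood becomes
$$
\left(\frac{\phi}{\phi + 1}\right)^x \left(1 - \frac{\phi}{\phi + 1}\right)^{n - x} = \phi^x (\phi + 1)^{-x} (\phi + 1)^{-n + x} = \phi^x (\phi + 1)^{-n} \tag{3.153}
$$
Thus the log likelihood is
$$
\ell = x \log \phi - n \log(\phi + 1) \tag{3.154}
$$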


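The first and second derivatives are
$$
\frac{d\ell}{d\phi} = \frac{x}{\phi} - \frac{n}{\phi + 1} \tag{3.155}
$$
$$
\frac{d^2\ell}{d\phi^2} = -\frac{x}{\phi^2} + \frac{n}{(\phi + 1)^2} \tag{3.156}
$$
Since $\mathbb{E}[x] = n\theta = n\frac{\phi}{\phi + 1}$, the Fisher information matrix is given by
$$
F(\phi) = -\mathbb{E}\left[\frac{d^2\ell}{d\phi^2}\right] = \frac{n}{\phi(\phi + 1)} - \frac{n}{(\phi + 1)^2} \tag{3.157}
$$
$$
= \frac{n(\phi + 1) - n\phi}{\phi(\phi + 1)^2} = \frac{n}{\phi(\phi + 1)^2} \tag{3.158}
$$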
[Figure 3.11 panels: Jeffreys' prior for Alice; Jeffreys' prior for Bob; Jeffreys' posterior for Alice; Jeffreys' posterior for Bob]

Figure 3.11: Illustration of Jeffreys' prior for Alice (who uses the rate $\theta$) and Bob (who uses the odds
$\phi = \theta/(1 - \theta)$). Adapted from Figure 1.9 of [MKL11]. Generated by jeffreys_prior_binomial.ipynb.


Hence
$$
p_\phi(\phi) \propto \phi^{-0.5} (1 + \phi)^{-1} \tag{3.159}
$$
See Figure 3.11 for an illustration.
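The invariance argument can be checked numerically for Bob's odds parameterization with a single Bernoulli trial ($n = 1$). The NumPy sketch below computes $F(\phi)$ from its definition, with the score obtained by central finite differences, and compares $\sqrt{F(\phi)}$ with Alice's prior $\sqrt{F(\theta)} = 1/\sqrt{\theta(1 - \theta)}$ pushed through the change of variables, $|d\theta/d\phi| = (\phi + 1)^{-2}$. The step size and the grid of $\phi$ values are assumptions.

```python
import numpy as np

def fisher_info_bernoulli_odds(phi, h=1e-5):
    """F(phi) for one Bernoulli trial with odds phi, from the definition E[(d/dphi log p(x | phi))^2].

    The score is approximated by central finite differences with step h.
    """
    def loglik(x, p):
        theta = p / (p + 1.0)
        return x * np.log(theta) + (1 - x) * np.log1p(-theta)

    theta = phi / (phi + 1.0)
    score = lambda x: (loglik(x, phi + h) - loglik(x, phi - h)) / (2 * h)
    return theta * score(1) ** 2 + (1 - theta) * score(0) ** 2   # expectation over x in {0, 1}

phi = np.linspace(0.1, 5.0, 50)
theta = phi / (phi + 1.0)
jeffreys_direct = np.sqrt(fisher_info_bernoulli_odds(phi))                         # Eq. (3.151)
jeffreys_transformed = (1.0 / np.sqrt(theta * (1 - theta))) / (phi + 1.0) ** 2     # Eq. (3.148)
print(np.allclose(jeffreys_direct, jeffreys_transformed, rtol=1e-6))               # True
```

Both curves equal $\phi^{-1/2}(1 + \phi)^{-1}$, matching Equation (3.159).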


3.6.2.2 Jeffreys prior for multinomial distribution 


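For a categorical random variable with $K$ states, one can show that the Jeffreys prior is given by
$$
p(\theta) \propto \mathrm{Dir}\left(\theta \,\middle|\, \tfrac{1}{2}, \ldots, \tfrac{1}{2}\right) \tag{3.160}
$$
Note that this is different from the more obvious choices of $\mathrm{Dir}(\frac{1}{K}, \ldots, \frac{1}{K})$ or $\mathrm{Dir}(1, \ldots, 1)$.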

3.6.2.3 Jeffreys prior for the mean and variance of a univariate Gaussian

Consider a 1d Gaussian $x \sim \mathcal{N}(\mu, \sigma^2)$ with both parameters unknown, so $\theta = (\mu, \sigma)$. From
Equation (2.263), the Fisher information matrix is
$$
F(\theta) = \begin{pmatrix} 1/\sigma^2 & 0 \\ 0 & 2/\sigma^2 \end{pmatrix} \tag{3.161}
$$
so $\sqrt{\det(F(\theta))} = \sqrt{2/\sigma^4} \propto 1/\sigma^2$. However, the standard Jeffreys uninformative prior for the Gaussian is
defined as the product of independent uninformative priors (see [KW96]), i.e.,
$$
p(\mu, \sigma^2) \propto p(\mu)\, p(\sigma^2) \propto 1/\sigma^2 \tag{3.162}
$$
It turns out that we can emulate this prior with a conjugate NIχ² prior:
$$
p(\mu, \sigma^2) = \mathrm{NI}\chi^2(\mu, \sigma^2 \mid \breve{\mu} = 0, \breve{\kappa} = 0, \breve{\nu} = -1, \breve{\sigma}^2 = 0) \tag{3.163}
$$
This lets us easily reuse the results for conjugate analysis of the Gaussian in Section 3.2.3.3, as we
showed in Section 3.2.3.4.
3.6.3 Invariant priors


If we have “objective” prior knowledge about a problem in the form of invariances, we may be able to 
encode this into a prior, as we show below. 


3.6.3.1 Translation-invariant priors 


A location-scale family is a family of probability distributions parameterized by a location $\mu$ and
scale $\sigma$. If $x$ is an rv in this family, then $y = a + bx$ is also an rv in the same family.

When inferring the location parameter $\mu$, it is intuitively reasonable to want to use a translation-invariant
prior, which satisfies the property that the probability mass assigned to any interval
$[A, B]$ is the same as that assigned to any other shifted interval of the same width, such as $[A - c, B - c]$.
That is,
$$
\int_{A - c}^{B - c} p(\mu)\, d\mu = \int_A^B p(\mu)\, d\mu \tag{3.164}
$$
This can be achieved using
$$
p(\mu) \propto 1 \tag{3.165}
$$
since then
$$
\int_{A - c}^{B - c} 1\, d\mu = (B - c) - (A - c) = (B - A) = \int_A^B 1\, d\mu \tag{3.166}
$$
This is the same as the Jeffreys prior for a Gaussian with unknown mean $\mu$ and fixed variance.
This follows since $F(\mu) = 1/\sigma^2 \propto 1$, from Equation (2.263), and hence $p(\mu) \propto 1$.


3.6.3.2 Scale-invariant prior 


When inferring the scale parameter $\sigma$, we may want to use a scale-invariant prior, which satisfies
the property that the probability mass assigned to any interval $[A, B]$ is the same as that assigned to
any other interval $[A/c, B/c]$, where $c > 0$. That is,
$$
\int_{A/c}^{B/c} p(\sigma)\, d\sigma = \int_A^B p(\sigma)\, d\sigma \tag{3.167}
$$
This can be achieved by using
$$
p(\sigma) \propto 1/\sigma \tag{3.168}
$$
since then
$$
\int_{A/c}^{B/c} \frac{1}{\sigma}\, d\sigma = [\log \sigma]_{A/c}^{B/c} = \log(B/c) - \log(A/c) = \log(B) - \log(A) = \int_A^B \frac{1}{\sigma}\, d\sigma \tag{3.169}
$$
This is the same as the Jeffreys prior for a Gaussian with fixed mean $\mu$ and unknown scale $\sigma$. This
follows since $F(\sigma) = 2/\sigma^2$, from Equation (2.263), and hence $p(\sigma) \propto 1/\sigma$.
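A small numerical check of Equations (3.164)-(3.169): under $p(\mu) \propto 1$ the mass of an interval is unchanged by shifting, and under $p(\sigma) \propto 1/\sigma$ it is unchanged by rescaling. The sketch uses a midpoint rule in NumPy; the interval and the shift/scale values are arbitrary.

```python
import numpy as np

def mass(density, a, b, num=100_000):
    """Unnormalized mass of [a, b] under `density`, via the midpoint rule."""
    dx = (b - a) / num
    x = a + (np.arange(num) + 0.5) * dx
    return float(np.sum(density(x)) * dx)

flat = lambda mu: np.ones_like(mu)   # p(mu) proportional to 1, Eq. (3.165)
inv = lambda sigma: 1.0 / sigma      # p(sigma) proportional to 1/sigma, Eq. (3.168)

A, B = 2.0, 7.0
for c in [0.5, 3.0, 10.0]:
    # Translation invariance, Eq. (3.164): shifting [A, B] by c leaves the mass unchanged.
    print(mass(flat, A - c, B - c), mass(flat, A, B))
    # Scale invariance, Eq. (3.167): rescaling [A, B] by 1/c leaves the mass unchanged (= log(B/A)).
    print(mass(inv, A / c, B / c), mass(inv, A, B))
```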
3.6.3.3 Learning invariant priors


Whenever we have knowledge of some kind of invariance we want our model to satisfy, we can use 
this to encode a corresponding prior. Sometimes this is done analytically (see e.g., [Rob07, Ch.9]). 
When this is intractable, it may be possible to learn invariant priors by solving a variational
optimization problem (see e.g., [NS18]). 


3.6.4 Reference priors 


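One way to define a noninformative prior is as a distribution which is maximally far from all possible
posteriors, when averaged over datasets. This is the basic idea behind a reference prior [Ber05;
BBS09]. More precisely, we say that $p(\theta)$ is a reference prior if it maximizes the expected KL
divergence between posterior and prior:
$$
p^*(\theta) = \operatorname*{argmax}_{p(\theta)} \int_D p(D)\, D_{\mathbb{KL}}\left(p(\theta \mid D) \,\|\, p(\theta)\right) dD \tag{3.170}
$$
where $p(D) = \int p(D \mid \theta)\, p(\theta)\, d\theta$. This is the same as maximizing the mutual information $\mathbb{I}(\theta, D)$.
We can eliminate the integral over datasets by noting that
$$
\int p(D) \int p(\theta \mid D) \log \frac{p(\theta \mid D)}{p(\theta)}\, d\theta\, dD = \int p(\theta) \int p(D \mid \theta) \log \frac{p(D \mid \theta)}{p(D)}\, dD\, d\theta = \mathbb{E}_\theta\left[D_{\mathbb{KL}}\left(p(D \mid \theta) \,\|\, p(D)\right)\right] \tag{3.171}
$$
where we used the fact that $\frac{p(\theta \mid D)}{p(\theta)} = \frac{p(D \mid \theta)}{p(D)}$.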
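Equation (3.171) can be verified for a binomial model by discretizing $\theta$ on a grid, which turns both sides into finite sums that equal the same mutual information $\mathbb{I}(\theta, D)$. The NumPy sketch below evaluates both forms for a uniform and a Jeffreys prior; the grid size and $n = 20$ are assumptions, and the sketch only evaluates the objective in Equation (3.170), it does not optimize over priors.

```python
import numpy as np
from math import comb

def mutual_info_binomial(prior_density, n, num_grid=2000):
    """I(theta; D) for D ~ Bin(n, theta), theta on a grid, computed both ways in Eq. (3.171)."""
    theta = (np.arange(num_grid) + 0.5) / num_grid                   # midpoints in (0, 1)
    w = prior_density(theta)
    w = w / w.sum()                                                   # discretized prior p(theta)
    k = np.arange(n + 1)
    binom = np.array([comb(n, int(j)) for j in k], dtype=float)
    lik = binom * theta[:, None] ** k * (1 - theta[:, None]) ** (n - k)   # p(D = k | theta)
    p_D = w @ lik                                                     # marginal p(D = k)
    post = w[:, None] * lik / p_D                                     # p(theta | D = k)
    lhs = np.sum(p_D * np.sum(post * np.log(post / w[:, None]), axis=0))  # E_D[KL(post || prior)]
    rhs = np.sum(w * np.sum(lik * np.log(lik / p_D), axis=1))            # E_theta[KL(lik || p(D))]
    return lhs, rhs

uniform = lambda t: np.ones_like(t)
jeffreys = lambda t: 1.0 / np.sqrt(t * (1 - t))   # Beta(1/2, 1/2), up to a constant
for name, prior in [("uniform", uniform), ("jeffreys", jeffreys)]:
    print(name, mutual_info_binomial(prior, n=20))
```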

One can show that, in 1d, the corresponding prior is equivalent to the Jeffreys prior. In higher 
dimensions, we can compute the reference prior for one parameter at a time, using the chain rule. 
However, this can become computationally intractable. See [NS17] for a tractable approximation 
based on variational inference (Section 10.1). 


3.7 Hierarchical priors 


Bayesian models require specifying a prior $p(\theta)$ for the parameters. The parameters of the prior are
called hyperparameters, and will be denoted by $\xi$. If these are unknown, we can put a prior on
them; this defines a hierarchical Bayesian model, or multi-level model, which we can visualize
like this: $\xi \to \theta \to D$. We assume the prior on the hyperparameters is fixed (e.g., we may use some
kind of minimally informative prior), so the joint distribution has the form
$$
p(\xi, \theta, D) = p(\xi)\, p(\theta \mid \xi)\, p(D \mid \theta) \tag{3.172}
$$
The hope is that we can learn the hyperparameters by treating the parameters themselves as
datapoints.

A common setting in which such an approach makes sense is when we have $J > 1$ related datasets,
$D_j$, each with their own parameters $\theta_j$. Inferring $p(\theta_j \mid D_j)$ independently for each group $j$ can give
poor results if $D_j$ is a small dataset (e.g., if condition $j$ corresponds to a rare combination of features,
or a sparsely populated region). We could of course pool all the data to compute a single model,
$p(\theta \mid D)$, but that would not let us model the subpopulations. A hierarchical Bayesian model lets us
borrow statistical strength from groups with lots of data (and hence well-informed posteriors
$p(\theta_j \mid D)$) in order to help groups with little data (and hence highly uncertain posteriors $p(\theta_k \mid D)$).
The idea is that well-informed groups $j$ will have a good estimate of $\theta_j$, from which we can infer
$\xi$, which can be used to help estimate $\theta_k$ for groups $k$ with less data. (Information is shared via
the hidden common parent node $\xi$ in the graphical model, as shown in Figure 3.12.) We give some
examples of this below.

After fitting such models, we can compute two kinds of posterior predictive distributions. If we
want to predict observations for an existing group $j$, we need to use
$$
p(y_j \mid D) = \int p(y_j \mid \theta_j)\, p(\theta_j \mid D)\, d\theta_j \tag{3.173}
$$
However, if we want to predict observations for a new group $*$ that has not yet been measured, but
which is comparable to (or exchangeable with) the existing groups $1{:}J$, we need to use
$$
p(y_* \mid D) = \int \int p(y_* \mid \theta_*)\, p(\theta_* \mid \xi)\, p(\xi \mid D)\, d\theta_*\, d\xi \tag{3.174}
$$


We give some examples below. (More information can be found in e.g., [GH07; Gel+14a].) 


3.7.1 A hierarchical binomial model 


Suppose we want to estimate the prevalence of some disease amongst different groups of individuals,
either people or animals. Let $N_j$ be the size of the $j$'th group, and let $y_j$ be the number of positive
cases for group $j = 1{:}J$. We assume $y_j \sim \mathrm{Bin}(N_j, \theta_j)$, and we want to estimate the rates $\theta_j$. Since
some groups may have small population sizes, we may get unreliable results if we estimate each $\theta_j$
separately; for example, we may observe $y_j = 0$ resulting in $\hat{\theta}_j = 0$, even though the true infection
rate is higher.

One solution is to assume all the $\theta_j$ are the same; this is called parameter tying. The resulting
pooled MLE is just $\hat{\theta}_{\text{pooled}} = \frac{\sum_j y_j}{\sum_j N_j}$. But the assumption that all the groups have the same rate is a
rather strong one. A compromise approach is to assume that the $\theta_j$ are similar, but that there may
be group-specific variations. This can be modeled by assuming the $\theta_j$ are drawn from some common
distribution, say $\theta_j \sim \mathrm{Beta}(a, b)$. The full joint distribution can be written as
$$
p(D, \theta, \xi) = p(\xi)\, p(\theta \mid \xi)\, p(D \mid \theta) = p(\xi) \left[\prod_{j=1}^J \mathrm{Beta}(\theta_j \mid \xi)\right] \left[\prod_{j=1}^J \mathrm{Bin}(y_j \mid N_j, \theta_j)\right] \tag{3.175}
$$
where $\xi = (a, b)$. In Figure 3.12 we represent these assumptions using a directed graphical model
(see Section 4.2.8 for an explanation of such diagrams).

It remains to specify the prior $p(\xi)$. Following [Gel+14a, p110], we use
$$
p(a, b) \propto (a + b)^{-5/2} \tag{3.176}
$$
Figure 3.12: PGM for a hierarchical binomial model. (a) “Unrolled” model. (b) Same model, using plate
notation.

Figure 3.13: Data and inferences for the hierarchical binomial model fit using HMC. Generated by
hierarchical_binom_rats.ipynb.

3.7.1.1 Posterior inference


We can perform approximate posterior inference in this model using a variety of methods. In
Section 3.8.1 we discuss an optimization-based approach, but here we discuss one of the most popular
methods in Bayesian statistics, known as HMC or Hamiltonian Monte Carlo. This is described
in Section 12.5, but in short it is a form of MCMC (Markov chain Monte Carlo) that exploits
information from the gradient of the log joint to guide the sampling process. This algorithm generates
samples in an unconstrained parameter space, so we need to define the log joint over all the parameters
$\omega = (\tilde{\theta}, \tilde{\xi}) \in \mathbb{R}^D$ as follows:
$$
\log p(D, \omega) = \log p(D \mid \theta) + \log p(\theta \mid \xi) + \log p(\xi) \tag{3.177}
$$
$$
+ \sum_{j=1}^J \log |\mathrm{Jac}(\sigma)(\tilde{\theta}_j)| + \sum_{i=1}^2 \log |\mathrm{Jac}(\sigma_+)(\tilde{\xi}_i)| \tag{3.178}
$$
where $\theta_j = \sigma(\tilde{\theta}_j)$ is the sigmoid transform, and $\xi_i = \sigma_+(\tilde{\xi}_i)$ is the softplus transform. (We need
to add the Jacobian terms to account for these deterministic transformations.) We can then use
automatic differentiation to compute $\nabla_\omega \log p(D, \omega)$, which we pass to the HMC algorithm. This
algorithm returns a set of (correlated) samples from the posterior, $(\tilde{\xi}^s, \tilde{\theta}^s) \sim p(\omega \mid D)$, which we can
back-transform to $(\xi^s, \theta^s)$. We can then estimate the posterior over quantities of interest by using a
Monte Carlo approximation to $p(f(\theta) \mid D)$ for suitable $f$ (e.g., to compute the posterior mean rate for
group $j$, we set $f(\theta) = \theta_j$).
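The unconstrained log joint of Equations (3.177)-(3.178) can be sketched as follows, using NumPy and math.lgamma. Additive constants (including the normalizer of Equation (3.176)) are dropped, and the three-group data are hypothetical, not the rats dataset. An HMC implementation would also need $\nabla_\omega \log p(D, \omega)$; one option is to rewrite the same function with jax.numpy and jax.scipy.special.gammaln and apply jax.grad, which is not shown here.

```python
import numpy as np
from math import lgamma, log

def softplus(x):
    """log(1 + exp(x)), computed stably."""
    return np.logaddexp(0.0, x)

def log_joint_unconstrained(theta_tilde, xi_tilde, y, N):
    """log p(D, omega) of Eqs. (3.177)-(3.178) for the hierarchical binomial model, up to a constant.

    theta_tilde: (J,) unconstrained rates, theta_j = sigmoid(theta_tilde_j).
    xi_tilde: (2,) unconstrained hyperparameters, (a, b) = softplus(xi_tilde).
    """
    theta_tilde = np.asarray(theta_tilde, dtype=float)
    xi_tilde = np.asarray(xi_tilde, dtype=float)
    y, N = np.asarray(y, dtype=float), np.asarray(N, dtype=float)
    theta = 1.0 / (1.0 + np.exp(-theta_tilde))
    a, b = softplus(xi_tilde)
    log_p_xi = -2.5 * log(a + b)                                   # Eq. (3.176)
    log_beta_ab = lgamma(a) + lgamma(b) - lgamma(a + b)
    log_p_theta = np.sum((a - 1) * np.log(theta) + (b - 1) * np.log1p(-theta)) - len(theta) * log_beta_ab
    log_choose = sum(lgamma(n + 1) - lgamma(k + 1) - lgamma(n - k + 1) for k, n in zip(y, N))
    log_lik = log_choose + np.sum(y * np.log(theta) + (N - y) * np.log1p(-theta))
    # Log-Jacobians: sigmoid'(x) = sigmoid(x) (1 - sigmoid(x)) and softplus'(x) = sigmoid(x),
    # using log sigmoid(x) = -softplus(-x) and log(1 - sigmoid(x)) = -softplus(x).
    log_jac_theta = np.sum(-softplus(-theta_tilde) - softplus(theta_tilde))
    log_jac_xi = np.sum(-softplus(-xi_tilde))
    return float(log_lik + log_p_theta + log_p_xi + log_jac_theta + log_jac_xi)

# Hypothetical data for J = 3 groups (illustrative only).
y_demo, N_demo = [0, 1, 2], [10, 12, 14]
print(log_joint_unconstrained(np.zeros(3), np.zeros(2), y_demo, N_demo))
```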


3.7.1.2 Example: the rats dataset 


In this section, we apply this model to analyse the number of rats that develop a certain kind of
tumor during a particular clinical trial (see [Gel+14a, p102] for details). We show the raw data in
rows 1-2 of Figure 3.13(a). In row 3 of Figure 3.13(a) we show the MLE $\hat{\theta}_j$ for each group. We
see that some groups have $\hat{\theta}_j = 0$, which is much less than the pooled MLE $\hat{\theta}_{\text{pooled}}$ (red line). In
row 4 of Figure 3.13(a) we show the posterior mean $\mathbb{E}[\theta_j \mid D]$ estimated from all the data, as well as
the population mean $\mathbb{E}[\theta \mid D] = \mathbb{E}[a/(a + b) \mid D]$ shown in the red line. We see that groups that had
low counts have their estimates increased towards the population mean, and groups that have large
counts have their estimates decreased towards the population mean. In other words, the groups
regularize each other; this phenomenon is called shrinkage. The amount of shrinkage is controlled
by the prior on $(a, b)$, which is inferred from the data.

In Figure 3.13(b), we show the 95% credible intervals for each parameter, as well as the overall
population mean. (This is known as a forest plot.) We can use this to decide if any group is
significantly different than any specified target value (e.g., the overall average).


3.7.2 A hierarchical Gaussian model 


In this section, we consider a variation of the model in Section 3.7.1, where this time we have
real-valued data instead of binary count data. More specifically, we assume $y_{ij} \sim \mathcal{N}(\theta_j, \sigma^2)$, where
$\theta_j$ is the unknown mean for group $j$, and $\sigma^2$ is the observation variance (assumed to be shared across
groups and fixed, for simplicity). Note that having $N_j$ observations $y_{ij}$ each with variance $\sigma^2$ is like
having one measurement $y_j \triangleq \frac{1}{N_j} \sum_{i=1}^{N_j} y_{ij}$ with variance $\sigma_j^2 \triangleq \sigma^2 / N_j$. This lets us simplify notation
and use one observation per group, with likelihood $y_j \sim \mathcal{N}(\theta_j, \sigma_j^2)$, where we assume the $\sigma_j$'s are
known.

Figure 3.14: 8-schools dataset. (a) Raw data. Each row plots $y_j \pm \sigma_j$. Vertical line is the pooled estimate. (b)
Posterior 95% credible intervals for $\theta_j$. Vertical line is posterior mean $\mathbb{E}[\mu \mid D]$. Generated by schools8.ipynb.

Figure 3.15: Marginal posterior density $p(\tau \mid D)$ for the 8-schools dataset. Generated by schools8.ipynb.
We will use a hierarchical model by assuming each group's parameters come from a common
distribution, $\theta_j \sim \mathcal{N}(\mu, \tau^2)$. The model becomes
$$
p(\mu, \tau^2, \theta_{1:J} \mid D) \propto p(\mu)\, p(\tau^2) \prod_{j=1}^J \mathcal{N}(\theta_j \mid \mu, \tau^2)\, \mathcal{N}(y_j \mid \theta_j, \sigma_j^2) \tag{3.179}
$$
where $p(\mu)\, p(\tau^2)$ is some kind of prior over the hyperparameters. See Figure 3.17a for the graphical
model.


3.7.2.1 Example: the 8-schools dataset


Let us now apply this model to some data. We will consider the eight schools dataset from [Gel+14a,
Sec 5.5]. The goal is to estimate the effects of a new coaching program on SAT scores. Let $y_{nj}$ be
the observed improvement in score for student $n$ in school $j$ compared to a baseline. Since each
school has multiple students, we summarize its data using the empirical mean $\bar{y}_{\cdot j} = \frac{1}{N_j} \sum_{n=1}^{N_j} y_{nj}$
and standard deviation $\sigma_j$. See Figure 3.14a for an illustration of the data. We also show the pooled
MLE for $\theta$, which is a precision-weighted average of the data:
$$
\bar{y}_{\cdot\cdot} = \frac{\sum_{j=1}^J \frac{1}{\sigma_j^2} \bar{y}_{\cdot j}}{\sum_{j=1}^J \frac{1}{\sigma_j^2}} \tag{3.180}
$$
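Equation (3.180) applied to the eight schools summary statistics. The values below are the widely reproduced estimated effects and standard errors from [Gel+14a, Sec 5.5], indexed from 0 to match the text; the sketch (NumPy, hypothetical function name) also returns the standard error $1/\sqrt{\sum_j \sigma_j^{-2}}$ of the pooled estimate.

```python
import numpy as np

# Eight schools summary data: estimated treatment effects y_j and their standard errors sigma_j.
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

def pooled_estimate(y, sigma):
    """Precision-weighted average of Eq. (3.180) and its standard error."""
    prec = 1.0 / sigma ** 2
    return float(np.sum(prec * y) / np.sum(prec)), float(1.0 / np.sqrt(np.sum(prec)))

mean, se = pooled_estimate(y, sigma)
print(round(mean, 2), round(se, 2))   # about 7.69 and 4.07
```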
We see that school 0 has an unusually large improvement (28 points) compared to the overall mean,
suggesting that estimating $\theta_0$ just based on $D_0$ might be unreliable. However, we can easily apply
our hierarchical model. We will use HMC to do approximate inference. (See Section 3.8.2 for a faster
approximate method.)

After computing the (approximate) posterior, we can compute the marginal posteriors $p(\theta_j \mid D)$
for each school. These distributions are shown in Figure 3.14b. Once again, we see shrinkage
towards the global mean $\bar{\mu} = \mathbb{E}[\mu \mid D]$, which is close to the pooled estimate $\bar{y}_{\cdot\cdot}$. In fact, if we fix the
hyperparameters to their posterior mean values, and use the approximation
$$
p(\mu, \tau^2 \mid D) \approx \delta(\mu - \bar{\mu})\, \delta(\tau^2 - \bar{\tau}^2) \tag{3.181}
$$
then we can use the results from Section 3.2.3.1 to compute the marginal posteriors
$$
p(\theta_j \mid D) \approx p(\theta_j \mid D_j, \bar{\mu}, \bar{\tau}^2) \tag{3.182}
$$
In particular, we can show that the posterior mean $\mathbb{E}[\theta_j \mid D]$ is in between the MLE $\hat{\theta}_j = \bar{y}_j$ and the
global mean $\bar{\mu} = \mathbb{E}[\mu \mid D]$:
$$
\mathbb{E}[\theta_j \mid D, \bar{\mu}, \bar{\tau}^2] = w_j \bar{\mu} + (1 - w_j) \hat{\theta}_j \tag{3.183}
$$
where the amount of shrinkage towards the global mean is given by
$$
w_j = \frac{\sigma_j^2}{\sigma_j^2 + \bar{\tau}^2} \tag{3.184}
$$
Thus we see that there is more shrinkage for groups with smaller measurement precision (e.g., due to
smaller sample size), which makes intuitive sense. There is also more shrinkage if $\bar{\tau}^2$ is smaller; of
course $\tau^2$ is unknown, but we can compute a posterior for it, as shown in Figure 3.15.
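Equations (3.183)-(3.184) can be evaluated directly once plug-in values for $\bar{\mu}$ and $\bar{\tau}$ are chosen. In the NumPy sketch below those values are hypothetical (they are not the posterior means behind Figure 3.14), so the output only illustrates the qualitative behavior: noisier schools are pulled harder towards $\bar{\mu}$.

```python
import numpy as np

y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])       # eight schools y_j
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])  # and their sigma_j

def shrunk_means(y, sigma, mu_bar, tau_bar):
    """Plug-in posterior means E[theta_j | D, mu_bar, tau_bar^2], Eqs. (3.183)-(3.184)."""
    w = sigma ** 2 / (sigma ** 2 + tau_bar ** 2)   # shrinkage weight towards mu_bar
    return w * mu_bar + (1.0 - w) * y, w

# Hypothetical plug-in values, NOT the posterior means reported in the text.
mu_bar, tau_bar = 7.7, 5.0
post_mean, w = shrunk_means(y, sigma, mu_bar, tau_bar)
print(np.round(w, 2))          # largest weights for the schools with the largest sigma_j
print(np.round(post_mean, 1))  # every estimate lies between y_j and mu_bar
```

The weights coincide with the usual precision-weighted combination of the likelihood $\mathcal{N}(y_j \mid \theta_j, \sigma_j^2)$ and the prior $\mathcal{N}(\theta_j \mid \bar{\mu}, \bar{\tau}^2)$.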


3.7.2.2 Non-centered parameterization 


It turns out that posterior inference in this model is difficult for many algorithms because of the
tight dependence between the variance hyperparameter $\tau^2$ and the group means $\theta_j$, as illustrated
by the funnel shape in Figure 3.16. In particular, consider making local moves through parameter
space. The algorithm can only “visit” the place where $\tau^2$ is small (corresponding to strong shrinkage
to the prior) if all the $\theta_j$ are close to the prior mean $\mu$. It may be hard to move into the area where
$\tau^2$ is small unless all groups simultaneously move their $\theta_j$ estimates closer to $\mu$.

A standard solution to this problem is to rewrite the model using the following non-centered
parameterization:
$$
\theta_j = \mu + \tau \eta_j \tag{3.185}
$$
$$
\eta_j \sim \mathcal{N}(0, 1) \tag{3.186}
$$
Figure 3.16: Posterior $p(\theta_0, \log(\tau) \mid D)$ for the 8 schools model using (a) centered parameterization and (b)
non-centered parameterization. Generated by schools8.ipynb.

Figure 3.17: A hierarchical Gaussian Bayesian model. (a) Centered parameterization. (b) Non-centered
parameterization.


See Figure 3.17b for the corresponding graphical model. By writing $\theta_j$ as a deterministic function of
its parents plus a local noise term, we have reduced the dependence between $\theta_j$ and $\tau$, and hence
between $\theta_j$ and the other $\theta_k$ variables, which can improve the computational efficiency of inference
algorithms, as we discuss in Section 12.6.5. This kind of reparameterization is widely used in
hierarchical Bayesian models.
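To see why the non-centered form helps, the following sketch draws from a toy version of the prior only (no data). The scale prior $\log\tau \sim \mathcal{N}(0, 1.5^2)$, the choice $\mu = 0$, and all function names are assumptions made for illustration. In the centered form, the magnitude of $\theta_j - \mu$ is strongly correlated with $\log\tau$ (the funnel); in the non-centered form, $\eta_j$ is independent of $\tau$ by construction.

```python
import math
import random

def sample_toy_prior(num_samples, mu=0.0, seed=0):
    """Draw (log_tau, theta, eta) from a toy hierarchical prior (assumed, not the book's).

    log tau ~ N(0, 1.5^2); eta ~ N(0, 1); theta = mu + tau * eta  (Eqs. 3.185-3.186),
    so theta | tau ~ N(mu, tau^2), exactly as in the centered parameterization.
    """
    rng = random.Random(seed)
    draws = []
    for _ in range(num_samples):
        log_tau = rng.gauss(0.0, 1.5)
        eta = rng.gauss(0.0, 1.0)
        theta = mu + math.exp(log_tau) * eta
        draws.append((log_tau, theta, eta))
    return draws

def pearson_corr(xs, ys):
    """Sample Pearson correlation coefficient."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sxx = sum((x - mx) ** 2 for x in xs)
    syy = sum((y - my) ** 2 for y in ys)
    return sxy / math.sqrt(sxx * syy)

draws = sample_toy_prior(20000)          # mu = 0, so theta - mu is just theta
log_tau = [d[0] for d in draws]
# Dependence between the scale and the (log) magnitude of each coordinate.
corr_centered = pearson_corr(log_tau, [math.log(abs(d[1]) + 1e-300) for d in draws])
corr_noncentered = pearson_corr(log_tau, [math.log(abs(d[2]) + 1e-300) for d in draws])
print(f"centered:     corr(log tau, log|theta - mu|) = {corr_centered:.3f}")
print(f"non-centered: corr(log tau, log|eta|)        = {corr_noncentered:.3f}")
```

In this toy setup the first correlation should come out at roughly 0.8 and the second near zero, which is the geometric reason local samplers find the non-centered version easier to explore.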


3.7.3 Hierarchical conditional models


In Section 15.5, we discuss hierarchical Bayesian GLM models, which learn conditional distributions
$p(y|x, \theta_j)$ for each group $j$, using a prior of the form $p(\theta_j|\xi)$. In Section 17.6, we discuss hierarchical
Bayesian neural networks, which generalize this idea to nonlinear predictors.


3.8 Empirical Bayes 


In Section 3.7, we discussed hierarchical Bayes as a way to infer parameters from data. Unfortunately,
posterior inference in such models can be computationally challenging. In this section, we discuss
a computationally convenient approximation, in which we first compute a point estimate of the
hyperparameters, $\hat{\xi}$, and then compute the conditional posterior, $p(\theta|\hat{\xi}, \mathcal{D})$, rather than the joint
posterior, $p(\theta, \xi|\mathcal{D})$.

To estimate the hyper-parameters, we can maximize the marginal likelihood: 


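$$\hat{\xi}_{\text{mml}}(\mathcal{D}) = \arg\max_{\xi} p(\mathcal{D}|\xi) = \arg\max_{\xi} \int p(\mathcal{D}|\theta)\,p(\theta|\xi)\,d\theta \tag{3.187}$$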


This technique is known as type II maximum likelihood, since we are optimizing the hyperparameters,
rather than the parameters. Once we have estimated $\hat{\xi}$, we compute the posterior
$p(\theta|\hat{\xi}, \mathcal{D})$ in the usual way. This is easy to do if the model is conjugate conditional on the
hyperparameters.

Since we are estimating the prior parameters from data, this approach is called empirical Bayes (EB)
[CL96]. This violates the principle that the prior should be chosen independently of the data.
However, we can view it as a computationally cheap approximation to inference in the full hierarchical
Bayesian model, just as we viewed MAP estimation as an approximation to inference in the one-level
model $\theta \to \mathcal{D}$. In fact, we can construct a hierarchy in which the more integrals one performs, the
“more Bayesian” one becomes, as shown below.


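$$
\begin{array}{ll}
\textbf{Method} & \textbf{Definition} \\
\text{Maximum likelihood} & \hat{\theta} = \arg\max_{\theta} p(\mathcal{D}|\theta) \\
\text{MAP estimation} & \hat{\theta} = \arg\max_{\theta} p(\mathcal{D}|\theta)\,p(\theta|\xi) \\
\text{ML-II (empirical Bayes)} & \hat{\xi} = \arg\max_{\xi} \int p(\mathcal{D}|\theta)\,p(\theta|\xi)\,d\theta \\
\text{MAP-II} & \hat{\xi} = \arg\max_{\xi} \int p(\mathcal{D}|\theta)\,p(\theta|\xi)\,p(\xi)\,d\theta \\
\text{Full Bayes} & p(\theta, \xi|\mathcal{D}) \propto p(\mathcal{D}|\theta)\,p(\theta|\xi)\,p(\xi)
\end{array}
$$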


Note that ML-II is less likely to overfit than “regular” maximum likelihood, because there are
typically fewer hyperparameters $\xi$ than there are parameters $\theta$. We give some simple examples
below, and will see some ML applications later in the book.


3.8.1 EB for the hierarchical binomial model 


In this section, we revisit the hierarchical binomial model from Section 3.7.1, but we use empirical
Bayes instead of full Bayesian inference. We can analytically integrate out the $\theta_j$’s, and write down
the marginal likelihood directly. The resulting expression is

$$p(\mathcal{D}|\xi) = \prod_j \int \mathrm{Bin}(y_j|N_j, \theta_j)\,\mathrm{Beta}(\theta_j|a, b)\,d\theta_j \tag{3.188}$$
$$\propto \prod_j \frac{B(a + y_j,\, b + N_j - y_j)}{B(a, b)} \tag{3.189}$$
$$= \prod_j \frac{\Gamma(a+b)\,\Gamma(a+y_j)\,\Gamma(b+N_j-y_j)}{\Gamma(a)\,\Gamma(b)\,\Gamma(a+b+N_j)} \tag{3.190}$$

where the proportionality drops the binomial coefficients $\binom{N_j}{y_j}$, which do not depend on $\xi = (a, b)$.

Figure 3.18: Data and inferences for the hierarchical binomial model fit using empirical Bayes. Generated by
eb_binom.ipynb.


Various ways of maximizing this marginal likelihood wrt a and b are discussed in [Min00c]. 

Having estimated the hyperparameters $\hat{a}$ and $\hat{b}$, we can plug them in to compute the posterior
$p(\theta_j|\hat{a}, \hat{b}, \mathcal{D})$ for each group, using conjugate analysis in the usual way. We show the results in
Figure 3.18; they are very similar to the full Bayesian analysis shown in Figure 3.13, but the EB
method is much faster.
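A minimal sketch of type II maximum likelihood for this model: evaluate the log of Eq. (3.190) with `math.lgamma` and maximize it by a crude grid search over the prior mean $a/(a+b)$ and prior strength $a+b$. The grid search is our stand-in for the better optimizers discussed in [Min00c], and the counts below are hypothetical.

```python
import math

def log_marglik(a, b, y, N):
    """log p(D | a, b) from Eq. (3.190); binomial coefficients are dropped (constant in a, b)."""
    total = 0.0
    for y_j, N_j in zip(y, N):
        total += (math.lgamma(a + b) + math.lgamma(a + y_j) + math.lgamma(b + N_j - y_j)
                  - math.lgamma(a) - math.lgamma(b) - math.lgamma(a + b + N_j))
    return total

def fit_eb_grid(y, N):
    """Type II ML by brute-force grid search over prior mean m = a/(a+b) and strength s = a+b."""
    best_ll, best_a, best_b = -math.inf, None, None
    for i in range(1, 100):              # m in {0.01, ..., 0.99}
        m = i / 100
        for k in range(-20, 41):         # s in {10^-2, ..., 10^4} on a log grid
            s = 10 ** (k / 10)
            a, b = m * s, (1 - m) * s
            ll = log_marglik(a, b, y, N)
            if ll > best_ll:
                best_ll, best_a, best_b = ll, a, b
    return best_a, best_b

# Hypothetical data: y_j successes out of N_j trials in group j.
y = [0, 2, 5, 1, 3, 4, 10, 1]
N = [20, 20, 25, 19, 22, 20, 30, 15]

a_hat, b_hat = fit_eb_grid(y, N)
# Conjugate update per group: theta_j | a_hat, b_hat, D ~ Beta(a_hat + y_j, b_hat + N_j - y_j).
post_means = [(a_hat + y_j) / (a_hat + b_hat + N_j) for y_j, N_j in zip(y, N)]
print(f"a_hat={a_hat:.3f}, b_hat={b_hat:.3f}")
print([round(p, 3) for p in post_means])
```

Each posterior mean is a convex combination of the fitted prior mean $\hat{a}/(\hat{a}+\hat{b})$ and the group MLE $y_j/N_j$, with more weight on the prior for groups with few trials.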


3.8.2 EB for the hierarchical Gaussian model 


In this section, we revisit the hierarchical Gaussian model from Section 3.7.2.1. However, we fit the 
model using empirical Bayes. 


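For simplicity, we will assume that $\sigma_j^2 = \sigma^2$ is the same for all groups. When the variances are
equal, we can derive the EB estimate in closed form, as we now show. We have

$$p(y_j|\mu, \tau^2, \sigma^2) = \int \mathcal{N}(y_j|\theta_j, \sigma^2)\,\mathcal{N}(\theta_j|\mu, \tau^2)\,d\theta_j = \mathcal{N}(y_j|\mu, \tau^2 + \sigma^2) \tag{3.191}$$

Hence the marginal likelihood is

$$p(\mathcal{D}|\mu, \tau^2, \sigma^2) = \prod_{j=1}^{J} \mathcal{N}(y_j|\mu, \tau^2 + \sigma^2) \tag{3.192}$$

Thus we can estimate the hyperparameters using the usual MLEs for a Gaussian. For $\mu$, we have

$$\hat{\mu} = \frac{1}{J}\sum_{j=1}^{J} y_j = \bar{y} \tag{3.193}$$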


which is the overall mean. For $\tau^2$, we can use moment matching, which is equivalent to the MLE for
a Gaussian. This means we equate the model variance to the empirical variance:

$$\hat{\tau}^2 + \sigma^2 = \frac{1}{J}\sum_{j=1}^{J} (y_j - \bar{y})^2 \triangleq v \tag{3.194}$$

so $\hat{\tau}^2 = v - \sigma^2$. Since we know $\tau^2$ must be non-negative, it is common to use the following revised estimate:

$$\hat{\tau}^2 = \max\{0, v - \sigma^2\} = (v - \sigma^2)_+ \tag{3.195}$$

Given this, the posterior mean becomes

$$\hat{\theta}_j = \lambda \hat{\mu} + (1 - \lambda) y_j = \hat{\mu} + (1 - \lambda)(y_j - \hat{\mu}) \tag{3.196}$$

where $\lambda_j = \lambda = \sigma^2 / (\sigma^2 + \hat{\tau}^2)$.

Unfortunately, we cannot use the above method on the 8-schools dataset in Section 3.7.2.1, since it
uses unequal $\sigma_j$. However, we can still use the EM algorithm or other optimization-based methods.
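The closed-form recipe in Eqs. (3.193) to (3.196) fits in a few lines. A sketch, assuming the common measurement variance $\sigma^2 > 0$ is known; the function name `eb_gaussian` and the toy inputs are ours.

```python
def eb_gaussian(y, sigma2):
    """Empirical Bayes for the equal-variance hierarchical Gaussian model.

    Returns (mu_hat, tau2_hat, lam, theta_hat) following Eqs. (3.193)-(3.196).
    Assumes sigma2 > 0 is the known, shared measurement variance.
    """
    J = len(y)
    mu_hat = sum(y) / J                                  # Eq. (3.193)
    v = sum((y_j - mu_hat) ** 2 for y_j in y) / J        # empirical variance, Eq. (3.194)
    tau2_hat = max(0.0, v - sigma2)                      # Eq. (3.195), truncated at zero
    lam = sigma2 / (sigma2 + tau2_hat)                   # shrinkage factor lambda
    theta_hat = [mu_hat + (1 - lam) * (y_j - mu_hat) for y_j in y]   # Eq. (3.196)
    return mu_hat, tau2_hat, lam, theta_hat

mu_hat, tau2_hat, lam, theta_hat = eb_gaussian([1.0, 2.0, 3.0, 4.0, 5.0], sigma2=1.0)
print(mu_hat, tau2_hat, lam, theta_hat)   # 3.0 1.0 0.5 [2.0, 2.5, 3.0, 3.5, 4.0]
```

Here $v = 2$, so $\hat{\tau}^2 = 1$ and $\lambda = 0.5$: each estimate moves halfway toward $\hat{\mu} = 3$. With `sigma2=3.0` instead, $v < \sigma^2$, so $\hat{\tau}^2 = 0$, $\lambda = 1$, and every $\hat{\theta}_j$ collapses to the pooled mean.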


3.8.3 EB for Markov models (n-gram smoothing) 


The main problem with add-one smoothing, discussed in Section 2.6.3.3, is that it assumes that
all n-grams are equally likely, which is not very realistic. A more sophisticated approach, called
deleted interpolation [CG96], defines the transition matrix as a convex combination of the bigram
frequencies $f_{jk} = N_{jk}/N_j$ and the unigram frequencies $f_k = N_k/N$:

$$A_{jk} = (1 - \lambda) f_{jk} + \lambda f_k = (1 - \lambda)\frac{N_{jk}}{N_j} + \lambda \frac{N_k}{N} \tag{3.197}$$

The term $\lambda$ is usually set by cross validation. There is also a closely related technique called backoff
smoothing; the idea is that if $f_{jk}$ is too small, we “back off” to a more reliable estimate, namely $f_k$.
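A minimal sketch of Eq. (3.197) on a toy token sequence. We take $N_j$ to be the number of times token $j$ appears as a context (i.e., is followed by another token), so each row of $f_{jk}$ sums to one; the function name and the corpus are hypothetical, and $\lambda$ is fixed by hand rather than by cross validation.

```python
from collections import Counter

def deleted_interpolation(tokens, lam):
    """Return prob(k, j) = A_{jk} = (1 - lam) f_{jk} + lam f_k, Eq. (3.197)."""
    N = len(tokens)
    unigram = Counter(tokens)                          # N_k
    context = Counter(tokens[:-1])                     # N_j: times j is followed by a token
    bigram = Counter(zip(tokens[:-1], tokens[1:]))     # N_{jk}

    def prob(k, j):
        f_jk = bigram[(j, k)] / context[j] if context[j] else 0.0
        f_k = unigram[k] / N
        return (1 - lam) * f_jk + lam * f_k

    return prob

tokens = "the cat sat on the mat the cat ran".split()   # hypothetical corpus
A = deleted_interpolation(tokens, lam=0.2)
vocab = sorted(set(tokens))
print({k: round(A(k, "the"), 3) for k in vocab})
```

Unseen bigrams such as ("the", "sat") still get probability $\lambda f_k > 0$, which is the point of interpolating with the unigram distribution.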
Figure 3.19: A Markov chain in which we put a different Dirichlet prior on every row of the transition matrix
A, but the hyperparameters of the Dirichlet are shared.


We now show that this heuristic can be interpreted as an empirical Bayes approximation to a 
hierarchical Bayesian model for the parameter vectors corresponding to each row of the transition 
matrix A. Our presentation follows [MP95]. 

First, let us use an independent Dirichlet prior on each row of the transition matrix: 


$$A_j \sim \text{Dir}(\alpha_0 m_1, \ldots, \alpha_0 m_K) = \text{Dir}(\alpha_0 m) = \text{Dir}(\alpha) \tag{3.198}$$


where $A_j$ is row $j$ of the transition matrix, $m$ is the prior mean (satisfying $\sum_k m_k = 1$) and $\alpha_0$ is the
prior strength (see Figure 3.19). In terms of the earlier notation, we have $\theta_j = A_j$ and $\xi = (\alpha_0, m)$.
The posterior is given by $A_j \sim \text{Dir}(\alpha + N_j)$, where $N_j = (N_{j1}, \ldots, N_{jK})$ is the vector that records
the number of times we have transitioned out of state $j$ to each of the other states. The posterior
predictive density is


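$$p(X_{t+1} = k|X_t = j, \mathcal{D}) = \frac{N_{jk} + \alpha_0 m_k}{N_j + \alpha_0} = \frac{f_{jk} N_j + \alpha_0 m_k}{N_j + \alpha_0} \tag{3.199}$$
$$= (1 - \lambda_j) f_{jk} + \lambda_j m_k \tag{3.200}$$

where

$$\lambda_j = \frac{\alpha_0}{N_j + \alpha_0} \tag{3.201}$$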
j 


This is very similar to Equation (3.197) but not identical. The main difference is that the Bayesian
model uses a context-dependent weight $\lambda_j$ to combine $m_k$ with the empirical frequency $f_{jk}$, rather
than a fixed weight $\lambda$. This is like adaptive deleted interpolation. Furthermore, rather than backing
off to the empirical marginal frequencies $f_k$, we back off to the model parameter $m_k$.
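The equivalence of Eq. (3.199) and the interpolated form in Eqs. (3.200) and (3.201) can be checked numerically. In this sketch the prior mean `m`, the strength `alpha0`, and the row counts are hypothetical; both functions should return the same predictive row, and $\lambda_j$ shrinks as the row count $N_j$ grows.

```python
def dirichlet_predictive(counts_j, alpha0, m):
    """Eq. (3.199): p(X_{t+1} = k | X_t = j, D) = (N_jk + alpha0 m_k) / (N_j + alpha0)."""
    N_j = sum(counts_j)
    return [(n_jk + alpha0 * m_k) / (N_j + alpha0) for n_jk, m_k in zip(counts_j, m)]

def interpolated_predictive(counts_j, alpha0, m):
    """Eqs. (3.200)-(3.201): (1 - lambda_j) f_jk + lambda_j m_k."""
    N_j = sum(counts_j)
    lam_j = alpha0 / (N_j + alpha0)                          # context-dependent weight
    f_j = [n / N_j if N_j > 0 else 0.0 for n in counts_j]    # empirical row frequencies
    return [(1 - lam_j) * f_jk + lam_j * m_k for f_jk, m_k in zip(f_j, m)]

m = [0.5, 0.3, 0.2]   # hypothetical prior mean over K = 3 next states
alpha0 = 2.0          # hypothetical prior strength
rows = {"frequent context": [80, 10, 10], "rare context": [1, 0, 0], "unseen context": [0, 0, 0]}
for name, counts in rows.items():
    p = dirichlet_predictive(counts, alpha0, m)
    lam_j = alpha0 / (sum(counts) + alpha0)
    print(f"{name:16s} lambda_j={lam_j:.3f}  p={[round(x, 3) for x in p]}")
```

A frequently seen context relies mostly on its own counts, while an unseen context falls back entirely on $m$.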


The only remaining question is: what values should we use for $\alpha_0$ and $m$? Let’s use empirical Bayes.


Since we assume each row of the transition matrix is a priori independent given $\alpha$, the marginal
likelihood for our Markov model is given by

$$p(\mathcal{D}|\alpha) = \prod_j \frac{B(N_j + \alpha)}{B(\alpha)} \tag{3.202}$$


where $N_j = (N_{j1}, \ldots, N_{jK})$ are the counts for leaving state $j$ and $B(\alpha)$ is the generalized beta
function.

We can fit this using the methods discussed in [Min00c]. However, we can also use the following 
approximation [MP95, p12]: 


$$m_k \propto |\{j : N_{jk} > 0\}| \tag{3.203}$$


This says that the prior probability of word k is given by the number of different contexts in which 
it occurs, rather than the number of times it occurs. To justify the reasonableness of this result, 
MacKay and Peto [MP95] give the following example. 


Imagine, you see, that the language, you see, has, you see, a 
frequently occuring couplet ’you see’, you see, in which the second 
word of the couplet, see, follows the first word, you, with very high 
probability, you see. Then the marginal statistics, you see, are going 
to become hugely dominated, you see, by the words you and see, with 
equal frequency, you see. 


If we use the standard smoothing formula, Equation (3.197), then P(you|novel) and P(see|novel),
for some novel context word not seen before, would turn out to be the same, since the marginal
frequencies of ‘you’ and ‘see’ are the same (11 times each). However, this seems unreasonable. ‘You’
appears in many contexts, so P(you|novel) should be high, but ‘see’ almost always follows ‘you’, so P(see|novel)
should be low. If we use the Bayesian formula Equation (3.200), we will get this effect for free, since
we back off to $m_k$, not $f_k$, and $m_k$ will be large for ‘you’ and small for ‘see’ by Equation (3.203).
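The “you see” example can be checked mechanically. The sketch below tokenizes the quotation above (lowercasing and dropping punctuation, a simplification we chose), then compares raw unigram counts $N_k$ with the number of distinct contexts $|\{j : N_{jk} > 0\}|$ used by the approximation (3.203).

```python
import re
from collections import Counter

quote = ("Imagine, you see, that the language, you see, has, you see, a "
         "frequently occuring couplet 'you see', you see, in which the second "
         "word of the couplet, see, follows the first word, you, with very high "
         "probability, you see. Then the marginal statistics, you see, are going "
         "to become hugely dominated, you see, by the words you and see, with "
         "equal frequency, you see.")

tokens = re.findall(r"[a-z]+", quote.lower())
unigram = Counter(tokens)                                 # N_k: raw frequency
bigrams = set(zip(tokens[:-1], tokens[1:]))               # distinct (j, k) with N_jk > 0
num_contexts = Counter(k for (_, k) in bigrams)           # |{j : N_jk > 0}|
total = sum(num_contexts.values())
m = {k: c / total for k, c in num_contexts.items()}       # Eq. (3.203), normalized

for word in ("you", "see"):
    print(f"{word}: N_k={unigram[word]}, contexts={num_contexts[word]}, m_k={m[word]:.3f}")
```

Both words occur 11 times, but ‘you’ follows 11 different words while ‘see’ follows only 3 (mostly ‘you’), so $m_{\text{you}} > m_{\text{see}}$ as the argument above requires.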

Although elegant, this Bayesian model does not beat the state-of-the-art language model, known 
as interpolated Kneser-Ney [KN95; CG98]. By using ideas from nonparametric Bayes, one 
can create a language model that outperforms such heuristics, as discussed in [Teh06; Woo+09a].
However, one can get even better results using recurrent neural nets (Section 16.3.4); the key to their 
success is that they don’t treat each symbol “atomically”, but instead learn a distributed embedding 
representation, which encodes the assumption that some symbols are more similar to each other than 
others. 


3.9 Model selection and evaluation 
All models are wrong, but some are useful. (George Box [BD87, p424].)⁸


In this section, we assume we have a set of different models M, each of which may fit the data to 
different degrees, and each of which may make different assumptions. We discuss how to pick the 
best model from this set, or to identify that none of them may be adequate. 


3.9.1 Bayesian model selection 


The natural way to pick the best model is to pick the most probable model according to Bayes’ rule:

$$\hat{m} = \arg\max_{m \in \mathcal{M}} p(m|\mathcal{D}) \tag{3.204}$$


8. George Box (1919–2013) was a statistics professor at the University of Wisconsin.
where

$$p(m|\mathcal{D}) = \frac{p(\mathcal{D}|m)\,p(m)}{\sum_{m' \in \mathcal{M}} p(\mathcal{D}|m')\,p(m')} \tag{3.205}$$


is the posterior over models. This is called Bayesian model selection. If the prior over models is
uniform, $p(m) = 1/|\mathcal{M}|$, then the MAP model is given by

$$\hat{m} = \arg\max_{m \in \mathcal{M}} p(\mathcal{D}|m) \tag{3.206}$$

The quantity $p(\mathcal{D}|m)$ is given by

$$p(\mathcal{D}|m) = \int p(\mathcal{D}|\theta, m)\,p(\theta|m)\,d\theta \tag{3.207}$$


This is known as the marginal likelihood, or the evidence for model m. (See Section 3.9.2 for 
details on how to compute this quantity.) If the model assigns high prior predictive density to the 
observed data, then we deem it a good model. If, however, the model has too much flexibility, then 
some prior settings will not match the data; this probability mass will be “wasted”, lowering the 
expected likelihood. This implicit regularization effect is called the Bayesian Occam’s razor. 
Note that Bayesian hypothesis testing can be considered as a special case of Bayesian model
selection when we just have two models, commonly called the null hypothesis, $M_0$, and the
alternative hypothesis, $M_1$. Let us define the Bayes factor as the ratio of marginal likelihoods:

$$B_{1,0} \triangleq \frac{p(\mathcal{D}|M_1)}{p(\mathcal{D}|M_0)} = \frac{p(M_1|\mathcal{D})}{p(M_0|\mathcal{D})} \Big/ \frac{p(M_1)}{p(M_0)} \tag{3.208}$$


(This is like a likelihood ratio, except we integrate out the parameters, which allows us to compare
models of different complexity.) If $B_{1,0} > 1$ then we prefer model 1, otherwise we prefer model 0. By
choosing an appropriate threshold on the Bayes factor, we can trade off the false positive rate against
the false negative rate.
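A minimal worked example of Eq. (3.208), in a coin-tossing setting that is our assumption rather than the text’s: $M_0$ says the coin is fair ($\theta = 0.5$), while $M_1$ puts a uniform Beta(1, 1) prior on $\theta$, so its evidence is $B(1 + N_1, 1 + N_0)/B(1, 1)$. Both evidences are for the specific observed sequence, so binomial coefficients would cancel in the ratio anyway.

```python
import math

def log_evidence_fair(n_heads, n_tails):
    """log p(D | M0) for a fixed fair coin (theta = 0.5)."""
    return (n_heads + n_tails) * math.log(0.5)

def log_evidence_beta(n_heads, n_tails, a=1.0, b=1.0):
    """log p(D | M1) with theta ~ Beta(a, b): log B(a + N1, b + N0) - log B(a, b)."""
    def log_beta(x, y):
        return math.lgamma(x) + math.lgamma(y) - math.lgamma(x + y)
    return log_beta(a + n_heads, b + n_tails) - log_beta(a, b)

def bayes_factor(n_heads, n_tails):
    """B_{1,0} = p(D | M1) / p(D | M0), Eq. (3.208)."""
    return math.exp(log_evidence_beta(n_heads, n_tails) - log_evidence_fair(n_heads, n_tails))

for n1, n0 in [(9, 1), (5, 5)]:   # hypothetical head/tail counts
    print(f"N1={n1}, N0={n0}: B_10 = {bayes_factor(n1, n0):.3f}")
```

For 9 heads and 1 tail this gives $B_{1,0} = 1024/110 \approx 9.3$, favoring the biased-coin model; for 5 heads and 5 tails it drops to about 0.37, favoring the simpler fair-coin model, which is the Bayesian Occam’s razor at work.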


3.9.2 Estimating the marginal likelihood


If we use a conjugate prior, we can compute the marginal likelihood analytically, as we discussed in
Section 3.2. However, in general, we must use numerical methods to approximate the integral in
Equation (3.207).


A particularly simple estimator, known as the harmonic mean estimator, was proposed in
[NR94]. It is defined as follows:

$$p(\mathcal{D}) \approx \left(\frac{1}{S}\sum_{s=1}^{S} \frac{1}{p(\mathcal{D}|\theta_s)}\right)^{-1} \tag{3.209}$$
Figure 3.20: Schematic of 5-fold cross validation.


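where $\theta_s \sim p(\theta|\mathcal{D})$ are samples from the posterior. This follows from the following identity:

$$\mathbb{E}_{p(\theta|\mathcal{D})}\left[\frac{1}{p(\mathcal{D}|\theta)}\right] = \int \frac{1}{p(\mathcal{D}|\theta)}\,p(\theta|\mathcal{D})\,d\theta \tag{3.210}$$
$$= \int \frac{1}{p(\mathcal{D}|\theta)}\,\frac{p(\mathcal{D}|\theta)\,p(\theta)}{p(\mathcal{D})}\,d\theta \tag{3.211}$$
$$= \frac{1}{p(\mathcal{D})}\int p(\theta)\,d\theta = \frac{1}{p(\mathcal{D})} \tag{3.212}$$

(We have assumed the prior is proper, so it integrates to 1.) Unfortunately, the number of samples
needed to get a good estimate is generally very large: the average of $1/p(\mathcal{D}|\theta_s)$ is dominated by the
rare posterior samples that have low likelihood, so the estimator can have extremely high (even infinite)
variance, making this approach useless in practice. Indeed, Radford Neal made a blog post in
which he described this method as “The Worst Monte Carlo Method Ever”.⁹ Fortunately, various
better estimators are available, such as variational Bayes (Section 10.2.3), sequential Monte Carlo
(Chapter 13), etc. (See e.g., [GM98; FW12] for more details.)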
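To see the problem concretely, the sketch below applies Eq. (3.209) to a Beta-Bernoulli model whose exact evidence is known. We assume a uniform Beta(1, 1) prior and 9 heads out of 10 tosses (hypothetical data), and draw exact posterior samples with `random.betavariate`. In this example $\mathbb{E}_{p(\theta|\mathcal{D})}[1/p(\mathcal{D}|\theta)^2] = \frac{1}{p(\mathcal{D})}\int_0^1 \theta^{-9}(1-\theta)^{-1}\,d\theta = \infty$, so the estimator has infinite variance even though its target is finite.

```python
import math
import random

def log_beta(x, y):
    return math.lgamma(x) + math.lgamma(y) - math.lgamma(x + y)

def exact_log_evidence(n1, n0, a=1.0, b=1.0):
    """log p(D) = log B(a + n1, b + n0) - log B(a, b) for a Beta-Bernoulli model."""
    return log_beta(a + n1, b + n0) - log_beta(a, b)

def harmonic_mean_log_evidence(n1, n0, S, seed, a=1.0, b=1.0):
    """Eq. (3.209) with exact posterior samples theta_s ~ Beta(a + n1, b + n0)."""
    rng = random.Random(seed)
    neg_loglik = []
    for _ in range(S):
        theta = rng.betavariate(a + n1, b + n0)
        theta = min(max(theta, 1e-12), 1 - 1e-12)      # guard against log(0)
        neg_loglik.append(-(n1 * math.log(theta) + n0 * math.log(1 - theta)))
    # log of (1/S) sum_s 1/p(D | theta_s), computed stably with log-sum-exp
    mx = max(neg_loglik)
    log_mean_inv = mx + math.log(sum(math.exp(v - mx) for v in neg_loglik)) - math.log(S)
    return -log_mean_inv

print("exact:", exact_log_evidence(9, 1))
for seed in range(3):
    print(f"seed {seed}:", harmonic_mean_log_evidence(9, 1, S=1000, seed=seed))
```

Comparing runs across seeds and sample sizes against the exact value $\log(1/110)$ is a simple way to observe the instability; in higher-dimensional models the problem is typically far worse.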


3.9.3 Connection between cross validation and marginal likelihood 


A standard approach to model evaluation is to estimate its predictive performance (in terms of log
likelihood) on a validation set, which is distinct from the training set used to fit the model.
If we don’t have such a separate validation set, we can make one by partitioning the training set into
$K$ subsets or “folds”, and then training on $K - 1$ folds and testing on the remaining one; we repeat this $K$ times,
holding out each fold in turn, as shown in Figure 3.20. This is known as cross validation.

If we set $K = N$, the method is known as leave-one-out cross validation or LOO-CV, since
we train on $N - 1$ points and test on the remaining one, and we do this $N$ times. More precisely, we
have

$$L_{\text{LOO}}(m) \triangleq \sum_{n=1}^{N} \log p(\mathcal{D}_n|\hat{\theta}(\mathcal{D}_{-n}), m) \tag{3.213}$$

where $\hat{\theta}(\mathcal{D}_{-n})$ is the parameter estimate computed when we omit $\mathcal{D}_n$ from the training set. (We discuss
fast approximations to this in Section 3.9.5.)
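A minimal sketch of $K$-fold cross validation and Eq. (3.213), assuming a 1d Gaussian fit by maximum likelihood as the model (our choice for illustration, with hypothetical data). Folds are contiguous blocks of indices; in practice one would usually shuffle first. With $K = N$ every fold is a single point, so the $K$-fold score reduces to $L_{\text{LOO}}$.

```python
import math

def gauss_logpdf(x, mu, var):
    """log N(x | mu, var)."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mu) ** 2 / var)

def fit_gaussian_mle(xs):
    """MLE of mean and variance (divide by N)."""
    mu = sum(xs) / len(xs)
    return mu, sum((x - mu) ** 2 for x in xs) / len(xs)

def kfold_loglik(xs, K):
    """Sum of held-out log-likelihoods over K contiguous folds (assumes K <= len(xs))."""
    n = len(xs)
    total = 0.0
    for k in range(K):
        held_out = range(k * n // K, (k + 1) * n // K)
        held = set(held_out)
        train = [x for i, x in enumerate(xs) if i not in held]
        mu, var = fit_gaussian_mle(train)
        total += sum(gauss_logpdf(xs[i], mu, var) for i in held_out)
    return total

def loo_loglik(xs):
    """Eq. (3.213) with a Gaussian MLE plug-in: refit N times, leaving one point out."""
    total = 0.0
    for n in range(len(xs)):
        mu, var = fit_gaussian_mle(xs[:n] + xs[n + 1:])
        total += gauss_logpdf(xs[n], mu, var)
    return total

data = [0.3, -1.2, 2.5, 0.8, 1.1, -0.4]   # hypothetical data
print(kfold_loglik(data, K=3), kfold_loglik(data, K=len(data)), loo_loglik(data))
```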


9. https://bit.ly/3t7idOk. 
Interestingly, the LOO-CV version of log likelihood is closely related to the log marginal likelihood.
To see this, let us write the log marginal likelihood (LML) in sequential form as follows:

$$\text{LML}(m) \triangleq \log p(\mathcal{D}|m) = \log \prod_{n=1}^{N} p(\mathcal{D}_n|\mathcal{D}_{1:n-1}, m) = \sum_{n=1}^{N} \log p(\mathcal{D}_n|\mathcal{D}_{1:n-1}, m) \tag{3.214}$$

where

$$p(\mathcal{D}_n|\mathcal{D}_{1:n-1}, m) = \int p(\mathcal{D}_n|\theta)\,p(\theta|\mathcal{D}_{1:n-1}, m)\,d\theta \tag{3.215}$$


Note that we evaluate the posterior on the first $n - 1$ data points and use this to predict the $n$’th;
this is called prequential analysis [DV99].

Suppose we use a point estimate for the parameters at time $n$, rather than the full posterior. We
can then use a plugin approximation to the $n$’th predictive distribution:

$$p(\mathcal{D}_n|\mathcal{D}_{1:n-1}, m) \approx \int p(\mathcal{D}_n|\theta)\,\delta(\theta - \hat{\theta}_m(\mathcal{D}_{1:n-1}))\,d\theta = p(\mathcal{D}_n|\hat{\theta}_m(\mathcal{D}_{1:n-1})) \tag{3.216}$$

Then Equation (3.214) simplifies to

$$\log p(\mathcal{D}|m) \approx \sum_{n=1}^{N} \log p(\mathcal{D}_n|\hat{\theta}(\mathcal{D}_{1:n-1}), m) \tag{3.217}$$


This is very similar to Equation (3.213), except it is evaluated sequentially. A complex model will 
overfit the “early” examples and will then predict the remaining ones poorly, and thus will get low 
marginal likelihood as well as a low cross-validation score. See [FH20] for further discussion. 
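The sequential decomposition in Eq. (3.214) can be checked exactly for a conjugate model. The sketch below assumes a Beta(1, 1)-Bernoulli model (not from the text), accumulates the one-step-ahead predictive terms of Eq. (3.215), and compares their sum with the closed-form evidence $B(a + N_1, b + N_0)/B(a, b)$. The total does not depend on the order of the data, even though the individual terms do.

```python
import math

def log_beta(x, y):
    """log B(x, y) computed with log-gamma functions."""
    return math.lgamma(x) + math.lgamma(y) - math.lgamma(x + y)

def prequential_terms(xs, a=1.0, b=1.0):
    """log p(x_n | x_{1:n-1}) for n = 1..N under a Beta(a, b)-Bernoulli model."""
    terms, heads = [], 0
    for n, x in enumerate(xs):                 # n = number of points already seen
        p_head = (a + heads) / (a + b + n)     # posterior predictive P(x_n = 1 | past)
        terms.append(math.log(p_head if x == 1 else 1.0 - p_head))
        heads += x
    return terms

def closed_form_lml(xs, a=1.0, b=1.0):
    n1 = sum(xs)
    return log_beta(a + n1, b + len(xs) - n1) - log_beta(a, b)

for order in ([1, 1, 1, 0, 0], [1, 0, 1, 0, 1]):
    terms = prequential_terms(order)
    print(order, [round(t, 3) for t in terms], round(sum(terms), 6), round(closed_form_lml(order), 6))
```

Both orders give $\log(1/60)$ in total, but the early terms differ, which is exactly what the CLML below exploits.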


3.9.4 Conditional marginal likelihood


The marginal likelihood answers the question “what is the likelihood of generating the training data
from my prior?”. This can be suitable for hypothesis testing between different fixed priors, but is
less useful for selecting models based on their posteriors. In the latter case, we are more interested
in the question “what is the probability that the posterior could generate withheld points from the
data distribution?”, which is related to the generalization performance of the (fitted) model. In
fact, [Lot+22] showed that the marginal likelihood can sometimes be negatively correlated with the
generalization performance, because the first few terms in the LML decomposition may be large and
negative for a model that has a poor prior but which otherwise adapts quickly to the data (by virtue
of the prior being weak).


A better approach is to use the conditional log marginal likelihood, which is defined as follows
[Lot+22]:

$$\text{CLML}(m) = \sum_{n=K+1}^{N} \log p(\mathcal{D}_n|\mathcal{D}_{1:n-1}, m) \tag{3.218}$$

where $K \in \{1, \ldots, N-1\}$ is a parameter of the algorithm. This evaluates the log marginal likelihood of the last $N - K$
datapoints, under the posterior given by the first $K$ datapoints. We can reduce the dependence on
the ordering of the datapoints by averaging over orders; if we set $K = N - 1$ and average over all
orders, we get the LOO estimate (up to a factor of $1/N$).

The CLML is much more predictive of generalization performance than the LML, and is much
less sensitive to prior hyperparameters. Furthermore, it is easier to calculate, since we can use a
straightforward Monte Carlo estimate of the integral, where we sample from the posterior $p(\theta|\mathcal{D}_{<n})$;
this does not suffer from the same problems as the harmonic mean estimator in Section 3.9.2.
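By the chain rule, $\sum_{n=K+1}^{N} \log p(\mathcal{D}_n|\mathcal{D}_{1:n-1}) = \log p(\mathcal{D}_{1:N}) - \log p(\mathcal{D}_{1:K})$, so for a conjugate model the CLML can be computed from two closed-form evidences. The sketch below assumes a Beta(1, 1)-Bernoulli model and hypothetical data (our choices). It shows that the CLML depends on the data order, averages it over random orders, and checks that $K = N - 1$, averaged over which point comes last, gives the mean LOO log predictive density.

```python
import math
import random

def log_evidence(xs, a=1.0, b=1.0):
    """Closed-form log p(xs) for a 0/1 sequence under a Beta(a, b)-Bernoulli model."""
    def log_beta(x, y):
        return math.lgamma(x) + math.lgamma(y) - math.lgamma(x + y)
    n1 = sum(xs)
    return log_beta(a + n1, b + len(xs) - n1) - log_beta(a, b)

def clml(xs, K, a=1.0, b=1.0):
    """Eq. (3.218) via the chain rule: log p(D_{1:N}) - log p(D_{1:K})."""
    return log_evidence(xs, a, b) - log_evidence(xs[:K], a, b)

def clml_random_orders(xs, K, num_orders=200, seed=0):
    """Reduce order dependence by averaging the CLML over random permutations."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_orders):
        perm = list(xs)
        rng.shuffle(perm)
        total += clml(perm, K)
    return total / num_orders

def mean_loo(xs):
    """K = N - 1, averaged over which point is held out last: mean of log p(x_n | D_{-n})."""
    N = len(xs)
    return sum(clml(xs[:n] + xs[n + 1:] + [xs[n]], N - 1) for n in range(N)) / N

xs = [1, 1, 1, 0, 0]
print(clml(xs, 2), clml([1, 0, 1, 0, 1], 2))   # same data, different orders
print(clml_random_orders(xs, 2), mean_loo(xs))
```

The two orders give $\log(1/20)$ and $\log(1/10)$ respectively, even though their full LMLs are identical.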


3.9.5 Bayesian leave-one-out (LOO) estimate 


In this section we discuss a computationally efficient method, based on importance sampling, to 
approximate the leave-one-out (LOO) estimate without having to fit the model N times. We focus 
on conditional (supervised) models, so $p(\mathcal{D}|\theta) = p(y|x, \theta)$.

Suppose we have computed the posterior given the full dataset for model $m$. We can use this
to evaluate the resulting predictive distribution $p(y_n|x_n, \mathcal{D}, m)$ for each datapoint $n$ in the dataset.
This gives the log-pointwise predictive-density or LPPD score:


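$$\text{LPPD}(m) \triangleq \sum_{n=1}^{N} \log p(y_n|x_n, \mathcal{D}, m) = \sum_{n=1}^{N} \log \int p(y_n|x_n, \theta, m)\,p(\theta|\mathcal{D}, m)\,d\theta \tag{3.219}$$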


We can approximate LPPD with Monte Carlo:

$$\text{LPPD}(m) \approx \sum_{n=1}^{N} \log\left(\frac{1}{S}\sum_{s=1}^{S} p(y_n|x_n, \theta_s, m)\right) \tag{3.220}$$

where $\theta_s \sim p(\theta|\mathcal{D}, m)$ is a posterior sample.
The trouble with LPPD is that it predicts the $n$’th data point $y_n$ using all the data, including $y_n$.
What we would like to compute is the expected LPPD (ELPD) on future data, $(x_*, y_*)$:

$$\text{ELPD}(m) \triangleq \mathbb{E}_{x_*, y_*}\left[\log p(y_*|x_*, \mathcal{D}, m)\right] \tag{3.221}$$


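Of course, the future data is unknown, but we can use a LOO approximation:

$$\text{ELPD}_{\text{LOO}}(m) \triangleq \sum_{n=1}^{N} \log p(y_n|x_n, \mathcal{D}_{-n}, m) = \sum_{n=1}^{N} \log \int p(y_n|x_n, \theta, m)\,p(\theta|\mathcal{D}_{-n}, m)\,d\theta \tag{3.222}$$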


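This is a Bayesian version of Equation (3.213). We can approximate this integral using Monte Carlo:

$$\text{ELPD}_{\text{LOO}}(m) \approx \sum_{n=1}^{N} \log\left(\frac{1}{S}\sum_{s=1}^{S} p(y_n|x_n, \theta_{s,-n}, m)\right) \tag{3.223}$$

where $\theta_{s,-n} \sim p(\theta|\mathcal{D}_{-n}, m)$.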

The above procedure requires computing $N$ different posteriors, leaving one data point out at a
time, which is slow. A faster alternative is to compute $p(\theta|\mathcal{D}, m)$ once, and then use importance
sampling (Section 11.5) to approximate the above integral. More precisely, let $f(\theta) = p(\theta|\mathcal{D}_{-n}, m)$ be
the target distribution of interest, and let $g(\theta) = p(\theta|\mathcal{D}, m)$ be the proposal. Define the importance
weight for each sample $s$ when leaving out example $n$ to be


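$$w_{s,-n} \triangleq \frac{f(\theta_s)}{g(\theta_s)} = \frac{p(\theta_s|\mathcal{D}_{-n})}{p(\theta_s|\mathcal{D})} = \frac{p(\mathcal{D}_{-n}|\theta_s)\,p(\theta_s)}{p(\mathcal{D}_{-n})}\,\frac{p(\mathcal{D})}{p(\mathcal{D}|\theta_s)\,p(\theta_s)} \tag{3.224}$$
$$\propto \frac{p(\mathcal{D}_{-n}|\theta_s)}{p(\mathcal{D}|\theta_s)} = \frac{p(\mathcal{D}_{-n}|\theta_s)}{p(\mathcal{D}_n|\theta_s)\,p(\mathcal{D}_{-n}|\theta_s)} = \frac{1}{p(\mathcal{D}_n|\theta_s)} \tag{3.225}$$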


We then normalize the weights to get

$$\hat{w}_{s,-n} = \frac{w_{s,-n}}{\sum_{s'=1}^{S} w_{s',-n}} \tag{3.226}$$
ON get Wet an 


and use them to get the estimate

$$\text{ELPD}_{\text{IS-LOO}}(m) = \sum_{n=1}^{N} \log\left(\sum_{s=1}^{S} \hat{w}_{s,-n}\,p(y_n|x_n, \theta_s)\right) \tag{3.227}$$


Unfortunately, the importance weights may have high variance, where some weights are much
larger than others. To reduce this effect, we fit a Pareto distribution (Section 2.2.3.5) to the set
of weights $\{w_{s,-n}\}_{s=1}^{S}$ for each left-out data point $n$, and use this to smooth the weights. This technique is called Pareto
smoothed importance sampling or PSIS [Veh+15; VGG17]. The Pareto distribution has the
form

$$p(r|u, \sigma, k) = \sigma^{-1}\left(1 + k(r - u)\sigma^{-1}\right)^{-1/k - 1} \tag{3.228}$$

where $u$ is the location, $\sigma$ is the scale, and $k$ is the shape. The parameter values $k_n$ (for each data
point $n$) can be used to assess how well this approximation works. If we find $k_n > 0.5$ for any given
point, it is likely an outlier, and the resulting LOO estimate is likely to be quite poor. See [Siv+20]
for further discussion.
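A minimal sketch of the importance-sampling LOO estimate of Eqs. (3.224) to (3.227), without the Pareto smoothing step. We assume a Beta(1, 1)-Bernoulli model with hypothetical data (7 ones, 3 zeros), so that exact posterior samples and the exact LOO predictive $p(x_n|\mathcal{D}_{-n})$ are available for comparison. With weights $w_{s,-n} = 1/p(x_n|\theta_s)$, the self-normalized estimate reduces to a harmonic mean of $p(x_n|\theta_s)$; here the weights happen to have finite variance, but in general they may not, which is what PSIS is designed to diagnose and mitigate.

```python
import math
import random

def is_loo_predictive(xs, n, S=20000, a=1.0, b=1.0, seed=0):
    """Estimate p(x_n | D_{-n}) by self-normalized importance sampling.

    Proposal: full posterior Beta(a + N1, b + N0). Raw weights w_s = 1 / p(x_n | theta_s)
    (Eq. 3.225), normalized as in Eq. (3.226). No Pareto smoothing is applied.
    """
    rng = random.Random(seed)
    n1 = sum(xs)
    n0 = len(xs) - n1
    weights, liks = [], []
    for _ in range(S):
        theta = rng.betavariate(a + n1, b + n0)          # theta_s ~ p(theta | D)
        lik = theta if xs[n] == 1 else 1.0 - theta       # p(x_n | theta_s)
        weights.append(1.0 / lik)
        liks.append(lik)
    total = sum(weights)
    return sum(w / total * lik for w, lik in zip(weights, liks))   # inner sum of Eq. (3.227)

def exact_loo_predictive(xs, n, a=1.0, b=1.0):
    """Exact p(x_n | D_{-n}) for the Beta(a, b)-Bernoulli model."""
    rest = xs[:n] + xs[n + 1:]
    p_head = (a + sum(rest)) / (a + b + len(rest))
    return p_head if xs[n] == 1 else 1.0 - p_head

xs = [1] * 7 + [0] * 3   # hypothetical data
elpd_is = sum(math.log(is_loo_predictive(xs, n, seed=n)) for n in range(len(xs)))
elpd_exact = sum(math.log(exact_loo_predictive(xs, n)) for n in range(len(xs)))
print(f"IS-LOO ELPD = {elpd_is:.4f}, exact LOO ELPD = {elpd_exact:.4f}")
```

Only one posterior is sampled, yet every left-out predictive is recovered by reweighting, which is the computational saving described above.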


3.9.6 Information criteria 


An alternative approach to cross validation is to score models using the negative log likelihood (or
LPPD) on the training set plus a complexity penalty term:

$$\mathcal{L}(m) = -\log p(\mathcal{D}|\hat{\theta}, m) + C(m) \tag{3.229}$$

This is called an information criterion. Different methods use different complexity terms $C(m)$,
as we discuss below. See e.g., [GHV14] for further details.


A note on notation: it is conventional, when working with information criteria, to multiply the log
likelihood by $-2$ (equivalently, the NLL by 2) to get the deviance:

$$\text{deviance}(m) = -2 \log p(\mathcal{D}|\hat{\theta}, m) \tag{3.230}$$

This makes the math “prettier” for certain Gaussian models.


3.9.6.1 Minimum description length (MDL) 


We can think about the problem of scoring different models in terms of information theory (Chapter 5).
The goal is for the sender to communicate the data to the receiver. First the sender needs to specify
which model $m$ to use; this takes $C(m) = -\log p(m)$ bits (see Section 5.2). Then the receiver can
fit the model, by computing $\hat{\theta}_m$, and can thus approximately reconstruct the data. To perfectly
reconstruct the data, the sender needs to send the residual errors that cannot be explained by the
model; this takes

$$-L(m) = -\log p(\mathcal{D}|\hat{\theta}, m) = -\sum_{n} \log p(y_n|x_n, \hat{\theta}, m) \tag{3.231}$$

bits. (We are ignoring the cost of sending the input features $x_n$, if present.) The total cost is

$$\mathcal{L}_{\text{MDL}}(m) = -\log p(\mathcal{D}|\hat{\theta}, m) + C(m) \tag{3.232}$$

Choosing the model which minimizes this cost is known as the minimum description length or
MDL principle. See e.g., [HY01] for details.


3.9.6.2 The Bayesian information criterion (BIC) 
The Bayesian information criterion or BIC [Sch78] is similar to the MDL, and has the form

$$\mathcal{L}_{\text{BIC}}(m) = -2 \log p(\mathcal{D}|\hat{\theta}, m) + D_m \log N \tag{3.233}$$

where $D_m$ is the degrees of freedom of model $m$.

We can derive the BIC score as a simple approximation to the log marginal likelihood. In particular,
suppose we make a Gaussian approximation to the posterior, as discussed in Section 7.4.3. Then we
get (from Equation (7.22)) the following:

$$\log p(\mathcal{D}|m) \approx \log p(\mathcal{D}|\hat{\theta}_{\text{map}}) + \log p(\hat{\theta}_{\text{map}}) - \frac{1}{2}\log|\mathbf{H}| \tag{3.234}$$

where $\mathbf{H}$ is the Hessian of the negative log joint $-\log p(\mathcal{D}, \theta)$ evaluated at the MAP estimate $\hat{\theta}_{\text{map}}$. We
see that Equation (3.234) is the log likelihood plus some penalty terms. If we have a uniform prior,
$p(\theta) \propto 1$, we can drop the prior term, and replace the MAP estimate with the MLE, $\hat{\theta}$, yielding

$$\log p(\mathcal{D}|m) \approx \log p(\mathcal{D}|\hat{\theta}) - \frac{1}{2}\log|\mathbf{H}| \tag{3.235}$$


We now focus on approximating the $\log|\mathbf{H}|$ term, which is sometimes called the Occam factor,
since it is a measure of model complexity (volume of the posterior distribution). We have
$\mathbf{H} = \sum_{i=1}^{N} \mathbf{H}_i$, where $\mathbf{H}_i = -\nabla\nabla \log p(\mathcal{D}_i|\theta)$. Let us approximate each $\mathbf{H}_i$ by a fixed matrix $\hat{\mathbf{H}}$. Then we
have

$$\log|\mathbf{H}| = \log|N\hat{\mathbf{H}}| = \log(N^D|\hat{\mathbf{H}}|) = D\log N + \log|\hat{\mathbf{H}}| \tag{3.236}$$

where $D = \dim(\theta)$ and we have assumed $\mathbf{H}$ is full rank. We can drop the $\log|\hat{\mathbf{H}}|$ term, since it is
independent of $N$, and thus will get overwhelmed by the likelihood. Putting all the pieces together,
we get the BIC score that we want to maximize:

$$J_{\text{BIC}}(m) = \log p(\mathcal{D}|\hat{\theta}, m) - \frac{D_m}{2}\log N \tag{3.237}$$
We can also define the BIC loss, that we want to minimize, by multiplying by $-2$:

$$\mathcal{L}_{\text{BIC}}(m) = -2 \log p(\mathcal{D}|\hat{\theta}, m) + D_m \log N \tag{3.238}$$


3.9.6.3 Akaike information criterion 


The Akaike information criterion [Aka74] is closely related to BIC. It has the form

$$\mathcal{L}_{\text{AIC}}(m) = -2 \log p(\mathcal{D}|\hat{\theta}, m) + 2D_m \tag{3.239}$$


This penalizes complex models less heavily than BIC, since the regularization term is independent of 
N. This estimator can be derived from a frequentist perspective. 
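A minimal sketch comparing AIC (3.239) and BIC (3.238) on two nested Gaussian models fit by maximum likelihood: $M_0$ fixes the mean at zero (one free parameter, the variance) and $M_1$ also fits the mean (two free parameters). The data and model pair are hypothetical; lower loss is better for both criteria.

```python
import math

def gauss_loglik(xs, mu, var):
    """Sum of log N(x | mu, var) over the data."""
    return sum(-0.5 * (math.log(2 * math.pi * var) + (x - mu) ** 2 / var) for x in xs)

def aic_bic(loglik, num_params, N):
    """AIC loss (Eq. 3.239) and BIC loss (Eq. 3.238); lower is better."""
    return -2.0 * loglik + 2 * num_params, -2.0 * loglik + num_params * math.log(N)

xs = [0.9, 1.4, 0.2, 1.8, 1.1, 0.7, 1.5, 0.4, 1.2, 0.8]   # hypothetical data
N = len(xs)

# M0: N(0, sigma^2), one free parameter (the variance), fit by MLE.
var0 = sum(x ** 2 for x in xs) / N
aic0, bic0 = aic_bic(gauss_loglik(xs, 0.0, var0), 1, N)

# M1: N(mu, sigma^2), two free parameters, fit by MLE.
mu1 = sum(xs) / N
var1 = sum((x - mu1) ** 2 for x in xs) / N
aic1, bic1 = aic_bic(gauss_loglik(xs, mu1, var1), 2, N)

print(f"M0: AIC={aic0:.2f} BIC={bic0:.2f}   M1: AIC={aic1:.2f} BIC={bic1:.2f}")
```

Since $\log N > 2$ once $N \geq 8$, BIC charges each parameter more than AIC does, consistent with the remark above; here the data mean is far from zero, so both criteria still prefer $M_1$.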


3.9.6.4 Widely applicable information criterion (WAIC) 


The main problem with MDL, BIC, and AIC is that it can be hard to compute the degrees of
freedom of a model, needed to define the complexity term, since in many models the parameters are highly correlated
and not uniquely identifiable from the likelihood. In particular, if the mapping from parameters
to the likelihood is not one-to-one, then the model is known as a singular statistical model, since
the corresponding Fisher information matrix (Section 2.4), and hence the Hessian $\mathbf{H}$ above, may be
singular (have determinant 0). An alternative criterion that works even in the singular case is known
as the widely applicable information criterion (WAIC), also known as the Watanabe–Akaike
information criterion [Wat10; Wat13].

WAIC is like other information criteria, except it is more Bayesian. First, it replaces the log
likelihood $L(m)$, which uses a point estimate of the parameters, with the LPPD, which marginalizes
them out (see Equation (3.220)). For the complexity term, WAIC uses the posterior variance of the
log likelihood of each datapoint:

$$C(m) = \sum_{n=1}^{N} \mathbb{V}_{\theta|\mathcal{D}, m}[\log p(y_n|x_n, \theta, m)] \approx \sum_{n=1}^{N} \mathbb{V}\{\log p(y_n|x_n, \theta_s, m) : s = 1:S\} \tag{3.240}$$

The intuition for this is as follows: if, for a given datapoint $n$, the different posterior samples $\theta_s$
make very different predictions, then the model is uncertain, and likely too flexible. The complexity
term essentially counts how often this occurs. The final WAIC loss is

$$\mathcal{L}_{\text{WAIC}}(m) = -2\,\text{LPPD}(m) + 2C(m) \tag{3.241}$$


Interestingly, it can be shown that the PSIS LOO estimate in Section 3.9.5 is asymptotically equivalent
to WAIC [VGG17].
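A minimal sketch of WAIC given a matrix of pointwise log-likelihoods `loglik[s][n]` $= \log p(y_n|x_n, \theta_s)$ from $S$ posterior samples: the LPPD follows Eq. (3.220) (computed with a log-mean-exp for numerical stability), the penalty uses the sample variance over $s$ as in Eq. (3.240), and the loss is Eq. (3.241). The Beta-Bernoulli model and data used to generate the samples are our assumptions, not the book’s example.

```python
import math
import random
import statistics

def waic(loglik):
    """loglik[s][n] = log p(y_n | x_n, theta_s). Returns (LPPD, C, WAIC loss)."""
    S, N = len(loglik), len(loglik[0])
    lppd, penalty = 0.0, 0.0
    for n in range(N):
        col = [loglik[s][n] for s in range(S)]
        mx = max(col)
        lppd += mx + math.log(sum(math.exp(v - mx) for v in col) / S)   # log-mean-exp, Eq. (3.220)
        penalty += statistics.variance(col)                              # Eq. (3.240)
    return lppd, penalty, -2.0 * lppd + 2.0 * penalty                   # Eq. (3.241)

# Hypothetical example: Beta(1, 1)-Bernoulli model with exact posterior samples.
rng = random.Random(0)
ys = [1] * 7 + [0] * 3
thetas = [rng.betavariate(1 + sum(ys), 1 + len(ys) - sum(ys)) for _ in range(4000)]
loglik = [[math.log(t) if y == 1 else math.log(1.0 - t) for y in ys] for t in thetas]
lppd, C, loss = waic(loglik)
print(f"LPPD = {lppd:.3f}, C = {C:.3f}, WAIC loss = {loss:.3f}")
```

If every posterior sample made identical predictions, the penalty would be exactly zero and WAIC would reduce to $-2$ times the log likelihood.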
3.9.7 Posterior predictive checks


Bayesian inference and decision making is optimal, but only if the modeling assumptions are correct.
In this section, we discuss some ways to assess if a model is reasonable.

From a Bayesian perspective, this can seem a bit odd, since if we knew there was a better model,
why don’t we just use that? Here we assume that we do not have a specific alternative model in
mind (so we are not performing model selection, unlike Section 3.9.1). Instead we are just trying to
see if the data we observe is “typical” of what we might expect if our model were correct. This is
called model checking.

In particular, suppose we knew the true parameters $\theta^*$, which we use to generate $S$ synthetic
datasets, $\tilde{\mathcal{D}}^s = \{\tilde{y}_n^s \sim p(\cdot|\theta^*) : n = 1:N\}$; these represent “plausible hallucinations” of the model. (Since $\theta^*$ is
unknown, in practice we sample the parameters from the posterior $p(\theta|\mathcal{D})$, or use a plug-in estimate, as in the example below.) To
assess the quality of our model, we can compute how “typical” our observed data $\mathcal{D}$ is compared to
the model’s hallucinations. To perform this comparison, we create one or more scalar test statistics,
$\text{test}(\tilde{\mathcal{D}}^s)$, and compare them to the test statistics on the actual data, $\text{test}(\mathcal{D})$. These statistics should
measure features of interest (since it will not, in general, be possible to capture every aspect of the
data with a given model). If there is a large difference between the distribution of $\text{test}(\tilde{\mathcal{D}}^s)$ across
different $s$ and the value of $\text{test}(\mathcal{D})$, it suggests the model is not a good one. This approach is called a
posterior predictive check [Rub84].


3.9.7.1 Example: 1d Gaussian 


To make things clearer, let us consider an example from [Gel+04]. In 1882, Newcomb measured the
speed of light using a certain method and obtained $N = 66$ measurements, shown in Figure 3.21(a).
There are clearly two outliers in the left tail, suggesting that the distribution is not Gaussian. Let
us nonetheless fit a Gaussian to it. For simplicity, we will just compute the MLE, and use a plug-in
approximation to the posterior predictive density:

$$p(\tilde{y}|\mathcal{D}) \approx \mathcal{N}(\tilde{y}|\hat{\mu}, \hat{\sigma}^2), \quad \hat{\mu} = \frac{1}{N}\sum_{n=1}^{N} y_n, \quad \hat{\sigma}^2 = \frac{1}{N}\sum_{n=1}^{N}(y_n - \hat{\mu})^2 \tag{3.242}$$


Let $\tilde{\mathcal{D}}^s$ be the $s$’th dataset of size $N = 66$ sampled from this distribution, for $s = 1:1000$. The
histogram of $\tilde{\mathcal{D}}^s$ for some of these samples is shown in Figure 3.21(b). It is clear that none of the
samples contain the large negative examples that were seen in the real data. This suggests the model
cannot capture the long tails present in the data. (We are assuming that these extreme values are
scientifically interesting, and something we want the model to capture.)

A more formal way to test fit is to define a test statistic. Since we are interested in small values, 
let us use 


$$\text{test}(\mathcal{D}) = \min\{y : y \in \mathcal{D}\} \tag{3.243}$$

The empirical distribution of $\text{test}(\tilde{\mathcal{D}}^s)$ for $s = 1:1000$ is shown in Figure 3.21(c). For the real data,
$\text{test}(\mathcal{D}) = -44$, but the test statistics of the generated data, $\text{test}(\tilde{\mathcal{D}})$, are much larger. Indeed, we see
that $-44$ is in the left tail of the predictive distribution, $p(\text{test}(\tilde{\mathcal{D}})|\mathcal{D})$.


3.9.8 Bayesian p-values 


If some test statistic of the observed data, $\text{test}(\mathcal{D})$, occurs in the left or right tail of the predictive
distribution, then it is very unlikely under the model. We can quantify this using a Bayesian
p-value, also called a posterior-predictive p-value:

$$p_B = \Pr(\text{test}(\tilde{\mathcal{D}}) > \text{test}(\mathcal{D})|\mathcal{D}) \tag{3.244}$$

In contrast, a frequentist p-value is defined as

$$p_C = \Pr(\text{test}(\tilde{\mathcal{D}}) > \text{test}(\mathcal{D})|\theta^*) \tag{3.245}$$


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 
Figure 3.21: (a) Histogram of Newcomb’s data. (b) Histograms of data sampled from Gaussian model.
(c) Histogram of test statistic on data sampled from the model, which represents $p(\text{test}(\tilde{\mathcal{D}}^s)|\mathcal{D})$, where
$\text{test}(\mathcal{D}) = \min\{y \in \mathcal{D}\}$. The vertical line is the test statistic on the true data, $\text{test}(\mathcal{D})$. (d) Same as (c) except
$\text{test}(\mathcal{D}) = \mathbb{V}\{y \in \mathcal{D}\}$. Generated by newcomb_plugin_demo.ipynb.


where $\theta^*$ is the true but unknown parameter. The key difference between the Bayesian and classical
approach is that the Bayesian always conditions on what is known (namely the data $\mathcal{D}$), and never
conditions on what is unknown (namely $\theta^*$).


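We can approximate the Bayesian p-value using Monte Carlo integration, as follows:

$$p_B = \int \mathbb{I}\left(\text{test}(\tilde{\mathcal{D}}) > \text{test}(\mathcal{D})\right) p(\tilde{\mathcal{D}}|\theta)\,p(\theta|\mathcal{D})\,d\theta\,d\tilde{\mathcal{D}} \approx \frac{1}{S}\sum_{s=1}^{S} \mathbb{I}\left(\text{test}(\tilde{\mathcal{D}}^s) > \text{test}(\mathcal{D})\right) \tag{3.246}$$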


Any extreme value for $p_B$ (i.e., a value near 0 or 1) means that the observed data is unlikely under
the model, as assessed via the test statistic test. However, if $\text{test}(\mathcal{D})$ is a sufficient statistic of the model,
it is likely to be well estimated, and the p-value will be near 0.5. For example, in the speed of light
example, if we define our test statistic to be the variance of the data, $\text{test}(\mathcal{D}) = \mathbb{V}\{y : y \in \mathcal{D}\}$, we get
a p-value of 0.48. (See Figure 3.21(d).) This shows that the Gaussian model is capable of representing
the variance in the data, even though it is not capable of representing the support (range) of the
data.
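A minimal sketch of Eq. (3.246) with the plug-in Gaussian of Eq. (3.242). The data below are hypothetical (64 well-behaved values plus two low outliers, loosely mimicking the shape of Newcomb’s data), not the actual measurements, and the function name is ours. As in the text, the minimum should give an extreme p-value while the variance should not.

```python
import random
import statistics

def bayes_p_value(data, test_stat, S=1000, seed=0):
    """Monte Carlo estimate of p_B = Pr(test(D_rep) > test(D) | D), Eq. (3.246).

    Uses the plug-in predictive N(mu_hat, sigma_hat^2) of Eq. (3.242) instead of
    sampling parameters from the full posterior.
    """
    rng = random.Random(seed)
    mu_hat = statistics.fmean(data)
    sigma_hat = statistics.pstdev(data)          # MLE: divides by N
    t_obs = test_stat(data)
    exceed = 0
    for _ in range(S):
        replicate = [rng.gauss(mu_hat, sigma_hat) for _ in range(len(data))]
        if test_stat(replicate) > t_obs:
            exceed += 1
    return exceed / S

# Hypothetical data (not Newcomb's measurements): 64 values in [20, 32] plus two low outliers.
data = [20 + (i % 13) for i in range(64)] + [-44, -2]
p_min = bayes_p_value(data, min)                     # test(D) = min{y : y in D}
p_var = bayes_p_value(data, statistics.pvariance)    # test(D) = V{y : y in D}
print(f"p_B(min) = {p_min:.3f}, p_B(var) = {p_var:.3f}")
```

The minimum yields $p_B$ at or near 1 (no replicate is as extreme as the observed minimum), flagging the tails, whereas the variance, which the plug-in Gaussian matches by construction, yields a moderate value.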

The above example illustrates the very important point that we should not try to assess whether
the data comes from a given model (for which the answer is nearly always that it does not), but
rather, we should just try to assess whether the model captures the features we care about. See
[Gel+04, ch.6] for a more extensive discussion of this topic.

Figure 3.22: Posterior predictive distribution for divorce rate vs actual divorce rate for 50 US states. Both
axes are standardized (i.e., z-scores). A few outliers are annotated. Adapted from Figure 5.5 of [McE20].
Generated by linreg_divorce_ppc.ipynb.


3.9.8.1 Example: linear regression 


When fitting conditional models, $p(y|x)$, we will have a different prediction for each input $x$. We can
compare the predictive distribution $p(y|x_n)$ to the observed $y_n$ to detect places where the model
does poorly.

As an example of this, we consider the “waffle divorce” dataset from [McE20, Sec 5.1]. This contains the divorce rate $D_n$, marriage rate $M_n$ and age $A_n$ at first marriage for 50 different US states. We use a linear regression model to predict the divorce rate, $p(y = d|\boldsymbol{x} = (a, m)) = \mathcal{N}(d|\alpha + \beta_a a + \beta_m m, \sigma^2)$, using vague priors for the parameters. (In this example, we use a Laplace approximation to the posterior, discussed in Section 7.4.3.) We then compute the posterior predictive distribution $p(y|\boldsymbol{x}_n, \mathcal{D})$, which is a 1d Gaussian, and plot this vs each observed outcome $y_n$.

The result is shown in Figure 3.22. We see several outliers, some of which have been annotated. 
In particular, we see that both Idaho (ID) and Utah (UT) have a much lower divorce rate than 
predicted. This is because both of these states have an unusually large proportion of Mormons. 

Of course, we expect errors in our predictive models. However, ideally the predictive error bars 
for the inputs where the model is wrong would be larger, rather than the model confidently making 
errors. In this case, the overconfidence arises from our incorrect use of a linear model. 
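A check of this kind can be sketched as follows, under stated assumptions. Instead of the waffle divorce data it uses synthetic standardized inputs with one planted outlier, and instead of the Laplace approximation it swaps in the flat-prior Gaussian posterior over the weights with the noise variance fixed at its unbiased estimate, so the predictive for case n is $\mathcal{N}(\hat{\boldsymbol{w}}^\top \boldsymbol{x}_n, \hat\sigma^2(1 + \boldsymbol{x}_n^\top (\mathbf{X}^\top\mathbf{X})^{-1}\boldsymbol{x}_n))$. The function names and the "true" coefficients are hypothetical.

```python
import random


def solve(a, b):
    """Solve a @ x = b for a small dense square matrix (Gaussian elimination, partial pivoting)."""
    n = len(b)
    m = [list(row) + [b[i]] for i, row in enumerate(a)]  # augmented matrix [a | b]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(col + 1, n):
            f = m[r][col] / m[col][col]
            for c in range(col, n + 1):
                m[r][c] -= f * m[col][c]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):
        s = sum(m[r][c] * x[c] for c in range(r + 1, n))
        x[r] = (m[r][n] - s) / m[r][r]
    return x


def posterior_predictive_check(xs, ys, threshold=2.0):
    """Return the weights and the indices n whose y_n lies more than `threshold`
    predictive standard deviations from the posterior predictive mean."""
    design = [[1.0] + list(x) for x in xs]        # intercept column plus the inputs
    n, d = len(design), len(design[0])
    xtx = [[sum(r[i] * r[j] for r in design) for j in range(d)] for i in range(d)]
    xty = [sum(r[i] * y for r, y in zip(design, ys)) for i in range(d)]
    w = solve(xtx, xty)                            # posterior mean under a flat prior
    resid = [y - sum(wi * ri for wi, ri in zip(w, r)) for r, y in zip(design, ys)]
    sigma2 = sum(e * e for e in resid) / (n - d)   # plug-in noise variance
    flagged = []
    for idx, (r, e) in enumerate(zip(design, resid)):
        v = solve(xtx, r)                          # (X^T X)^{-1} x_n
        pred_var = sigma2 * (1.0 + sum(ri * vi for ri, vi in zip(r, v)))
        if abs(e) > threshold * pred_var ** 0.5:
            flagged.append(idx)
    return w, flagged


rng = random.Random(0)
xs = [(rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)) for _ in range(50)]  # hypothetical (a, m) z-scores
ys = [-0.6 * a + 0.1 * m + rng.gauss(0.0, 0.5) for a, m in xs]         # hypothetical "true" model
ys[0] -= 5.0   # plant one "state" with a much lower rate than the model predicts
w, flagged = posterior_predictive_check(xs, ys)
```

The planted case shows up in `flagged`, along with a few ordinary cases that fall outside the 2-standard-deviation band by chance.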
4 Probabilistic graphical models


4.1 Introduction 


I basically know of two principles for treating complicated systems in simple ways: the first is 
the principle of modularity and the second is the principle of abstraction. I am an apologist 
for computational probability in machine learning because I believe that probability theory 
implements these two principles in deep and intriguing ways — namely through factorization 
and through averaging. Exploiting these two mechanisms as fully as possible seems to me to 
be the way forward in machine learning. — Michael Jordan, 1997 (quoted in [Fre98]). 


Probabilistic graphical models (PGMs) provide a convenient formalism for defining joint 
distributions on sets of random variables. In such graphs, the nodes represent random variables, and 
the (lack of) edges represent conditional independence (CI) assumptions between these variables. 
A better name for these models would be “independence diagrams”, but the term “graphical models” 
is now entrenched. 

There are several kinds of graphical model, depending on whether the graph is directed, undirected, 
or some combination of directed and undirected, as we discuss in the sections below. More details on 
graphical models can be found in e.g., [KF09a].


4.2 Directed graphical models (Bayes nets) 


In this section, we discuss directed probabilistic graphical models, or DPGM, which are based 
on directed acyclic graphs or DAGs (graphs that do not have any directed cycles). PGMs 
based on a DAG are often called Bayesian networks or Bayes nets for short; however, there is 
nothing inherently “Bayesian” about Bayesian networks: they are just a way of defining probability 
distributions. They are also sometimes called belief networks. The term “belief” here refers to
subjective probability. However, the probabilities used in these models are no more (and no less) 
subjective than in any other kind of probabilistic model. 


4.2.1 Representing the joint distribution 


Figure 4.1: Illustration of first and second order Markov models.

The key property of a DAG is that the nodes can be ordered such that parents come before children. This is called a topological ordering. Given such an order, we define the ordered Markov property to be the assumption that a node is conditionally independent of all its predecessors in the ordering given its parents, i.e.,

$$x_i \perp \boldsymbol{x}_{\text{pred}(i) \setminus \text{pa}(i)} \mid \boldsymbol{x}_{\text{pa}(i)} \tag{4.1}$$

where pa(i) are the parents of node i, and pred(i) are the predecessors of node i in the ordering. Consequently, we can represent the joint distribution as follows (assuming we use node ordering $1 : N_G$):


$$p(\boldsymbol{x}_{1:N_G}) = p(x_1) p(x_2|x_1) p(x_3|x_1, x_2) \cdots p(x_{N_G}|x_1, \ldots, x_{N_G - 1}) = \prod_{i=1}^{N_G} p(x_i|\boldsymbol{x}_{\text{pa}(i)}) \tag{4.2}$$


where $p(x_i|\boldsymbol{x}_{\text{pa}(i)})$ is the conditional probability distribution or CPD for node i. (The parameters of this distribution are omitted from the notation for brevity.)

The key advantage of the representation used in Equation (4.2) is that, by virtue of the conditional independence assumptions encoded in the graph, it needs substantially fewer parameters than an unstructured joint distribution. To see this, suppose all the variables are discrete and have K states each. Then an unstructured joint distribution needs $O(K^{N_G})$ parameters to specify the probability of every configuration. By contrast, with a DAG in which each node has at most $N_P$ parents, we only need $O(N_G K^{N_P + 1})$ parameters, which can be exponentially fewer if the DAG is sparse.
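To make the savings concrete, the following sketch counts free parameters for a DAG with tabular CPDs (each CPT row over $K_i$ states has $K_i - 1$ free entries, and there are $J_i = \prod_{p \in \text{pa}(i)} K_p$ rows) and compares the total with the $K^{N_G} - 1$ free parameters of an unstructured joint. The chain graph and the function name are illustrative choices, not taken from the text.

```python
from math import prod


def num_free_params(parents, card):
    """Free parameters of a DAG with tabular CPDs.

    parents: dict node -> list of parent nodes
    card:    dict node -> number of states K_i
    """
    return sum(prod(card[p] for p in ps) * (card[i] - 1) for i, ps in parents.items())


# A first-order Markov chain over 10 binary variables (each node has at most 1 parent).
n, k = 10, 2
chain = {t: ([t - 1] if t > 0 else []) for t in range(n)}
card = {t: k for t in range(n)}
dag_params = num_free_params(chain, card)   # 1 + 9 * 2 = 19
full_params = k ** n - 1                    # 1023
```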

We give some examples of DPGMs in Section 4.2.2, and in Section 4.2.4, we discuss how to read off other conditional independence properties from the graph.


4.2.2 Examples


In this section, we give several examples of models that can be usefully represented as DPGMs.


4.2.2.1 Markov chains 


We can represent the conditional independence assumptions of a first-order Markov model using the chain-structured DPGM shown in Figure 4.1(a). Consider a variable at a single time step t, which we call the “present”. From the diagram, we see that information cannot flow from the past, $\boldsymbol{x}_{1:t-1}$, to the future, $\boldsymbol{x}_{t+1:T}$, except via the present, $x_t$. (We formalize this in Section 4.2.4.) This means that $x_t$ is a sufficient statistic for the past, so the model is first-order Markov. This implies that the corresponding joint distribution can be written as follows:

$$p(\boldsymbol{x}_{1:T}) = p(x_1) p(x_2|x_1) p(x_3|x_2) \cdots p(x_T|x_{T-1}) = p(x_1) \prod_{t=2}^{T} p(x_t|x_{t-1}) \tag{4.3}$$


Draft of “Probabilistic Machine Learning: Advanced Topics”. August 15, 2022 




4.2. DIRECTED GRAPHICAL MODELS (BAYES NETS) 


For discrete random variables, we can represent the corresponding CPDs, $p(x_t = k|x_{t-1} = j)$, as a 2d table, known as a conditional probability table or CPT, $p(x_t = k|x_{t-1} = j) = \theta_{jk}$, where $0 \le \theta_{jk} \le 1$ and $\sum_k \theta_{jk} = 1$ (i.e., each row sums to 1).
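With a CPT in hand, Equation (4.3) can be evaluated directly: the log probability of a sequence is the log initial probability plus the sum of the log transition entries $\theta_{x_{t-1} x_t}$. The two-state initial distribution and CPT below are made-up values; working in log space is a design choice that avoids underflow for long sequences.

```python
import math


def markov_chain_log_prob(seq, init, trans):
    """log p(x_{1:T}) = log p(x_1) + sum_{t=2}^T log theta[x_{t-1}][x_t]  (Eq. 4.3)."""
    if any(abs(sum(row) - 1.0) > 1e-9 for row in trans):
        raise ValueError("each CPT row must sum to 1")
    logp = math.log(init[seq[0]])
    for prev, cur in zip(seq, seq[1:]):
        logp += math.log(trans[prev][cur])
    return logp


init = [0.5, 0.5]          # p(x_1), hypothetical
trans = [[0.9, 0.1],       # theta_{0k} = p(x_t = k | x_{t-1} = 0), hypothetical
         [0.2, 0.8]]       # theta_{1k} = p(x_t = k | x_{t-1} = 1), hypothetical
p = math.exp(markov_chain_log_prob([0, 0, 1, 1], init, trans))
# p = 0.5 * 0.9 * 0.1 * 0.8 = 0.036
```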

The first-order Markov assumption is quite restrictive. If we want to allow for dependencies two steps into the past, we can create a Markov model of order 2. This is shown in Figure 4.1(b). The corresponding joint distribution has the form

$$p(\boldsymbol{x}_{1:T}) = p(x_1, x_2) p(x_3|x_1, x_2) p(x_4|x_2, x_3) \cdots p(x_T|x_{T-2}, x_{T-1}) = p(x_1, x_2) \prod_{t=3}^{T} p(x_t|\boldsymbol{x}_{t-2:t-1}) \tag{4.4}$$

As we increase the order of the Markov model, we need to add more edges. In the limit, the DAG becomes fully connected (subject to being acyclic), as shown in Figure 22.1. However, in this case, there are no useful conditional independencies, so the graphical model has no value.


4.2.2.2 The “Student” network 


Figure 4.2 shows a model for capturing the inter-dependencies between 5 discrete random variables 
related to a hypothetical student taking a class: D = difficulty of class (easy, hard), I = intelligence 
(low, high), G = grade (A, B, C), S = SAT score (bad, good), L = letter of recommendation (bad, 
good). (This is a simplification of the “Student network” from [KF09a, p.281].) The chain rule 
tells us that we can represent the joint as follows: 


$$p(D, I, G, L, S) = p(L|S, G, D, I) \times p(S|G, D, I) \times p(G|D, I) \times p(D|I) \times p(I) \tag{4.5}$$


where we have ordered the nodes topologically as I, D, G, S, L. Note that L is conditionally 
independent of all the other nodes earlier in this ordering given its parent G, so we can replace 
p(L|S, G, D, I) by p(L|G). We can simplify the other terms in a similar way to get 


$$p(D, I, G, L, S) = p(L|G) \times p(S|I) \times p(G|D, I) \times p(D) \times p(I) \tag{4.6}$$


The ability to simplify a joint distribution into a product of small local pieces is the key idea behind graphical models.

In addition to the graph structure, we need to specify the conditional probability distributions 
(CPDs) at each node. For discrete random variables, we can represent the CPD as a table, which 
means we have a separate row (i.e., a separate categorical distribution) for each conditioning case, 
i.e., for each combination of parent values. We can represent the ith CPT as follows: 


$$\theta_{ijk} = p(x_i = k|\boldsymbol{x}_{\text{pa}(i)} = j) \tag{4.7}$$


The matrix $\boldsymbol{\theta}_{i,:,:}$ is a row stochastic matrix that satisfies the properties $0 \le \theta_{ijk} \le 1$ and $\sum_k \theta_{ijk} = 1$ for each row j. Here i indexes nodes, $i \in [N_G]$; k indexes node states, $k \in [K_i]$, where $K_i$ is the number of states for node i; and j indexes joint parent states, $j \in [J_i]$, where $J_i = \prod_{p \in \text{pa}(i)} K_p$.

The CPTs for the student network are shown next to each node in Figure 4.2. For example, we see 
that if the class is hard (D = 1) and the student has low intelligence (I = 0), the distribution over 
grades A, B and C we expect is p(G|D = 1, I = 0) = [0.05, 0.25, 0.7]; but if the student is intelligent, 
we get p(G|D = 1, I = 1) = [0.5, 0.3, 0.2]. 
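One way to store the CPTs of Equation (4.7) in code is as one 2d array per node, indexed by the joint parent state j and the child state k. The sketch below uses only entries that are legible in this text: three rows of p(G | I, D) and the Letter table of Figure 4.2 (the (Low, Easy) grade row is omitted rather than guessed). It checks that each row is a distribution and computes $J_i$ for every node of the student network. The mixed-radix encoding of j is one possible convention, chosen here for illustration.

```python
from math import prod

card = {"D": 2, "I": 2, "G": 3, "S": 2, "L": 2}
parents = {"D": [], "I": [], "G": ["I", "D"], "S": ["I"], "L": ["G"]}


def num_parent_states(node):
    """J_i = product of the cardinalities of the parents of node i."""
    return prod(card[p] for p in parents[node])


def parent_state_index(node, assignment):
    """Map a joint parent assignment to j in {0, ..., J_i - 1} (first parent most significant)."""
    j = 0
    for p in parents[node]:
        j = j * card[p] + assignment[p]
    return j


# theta_{L,j,k} = p(L = k | G = j): rows for G = A, B, C; columns (Bad, Good).
theta_L = [[0.10, 0.90], [0.40, 0.60], [0.99, 0.01]]
# Legible rows of p(G | I, D) over grades (A, B, C), keyed by (I, D).
theta_G_rows = {(0, 1): [0.05, 0.25, 0.70],   # Low intelligence, Hard class
                (1, 0): [0.90, 0.08, 0.02],   # High, Easy
                (1, 1): [0.50, 0.30, 0.20]}   # High, Hard

for row in theta_L + list(theta_G_rows.values()):
    assert all(0.0 <= v <= 1.0 for v in row) and abs(sum(row) - 1.0) < 1e-9

J = {node: num_parent_states(node) for node in card}   # D: 1, I: 1, G: 4, S: 2, L: 3
j = parent_state_index("G", {"I": 1, "D": 1})          # 3
p_C = theta_G_rows[(1, 1)][2]                          # p(G = C | I = 1, D = 1) = 0.2
```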
[Figure 4.2 graphic: the DAG Diff → Grade ← Intel, Intel → SAT, Grade → Letter, with a CPT beside each node. Legible entries: p(Grade | Intel, Diff) over (A, B, C) is [0.05, 0.25, 0.70] for (Low, Hard) and [0.90, 0.08, 0.02] for (High, Easy); p(Letter | Grade) over (Bad, Good) is [0.10, 0.90], [0.40, 0.60] and [0.99, 0.01] for grades A, B and C.]

Figure 4.2: The (simplified) student network. “Diff” is the difficulty of the class. “Intel” is the intelligence of the student. “Grade” is the grade of the student in this class. “SAT” is the score of the student on the SAT exam. “Letter” is whether the teacher writes a good or bad letter of recommendation. The circles (nodes) represent random variables, the edges represent direct probabilistic dependencies. The tables inside each node represent the conditional probability distribution of the node given its parents. Generated by student_pgm.ipynb.


The number of parameters in a CPT is $O(K^{p+1})$, where K is the number of states per node, and p is the number of parents. Later we will consider more parsimonious representations, with fewer learnable parameters. (We discuss parameter learning in Section 4.2.7.)


Once we have specified the model, we can use it to answer probabilistic queries, as we discuss in Section 4.2.6. As an example, suppose we observe that the student gets a grade of C. The posterior probability that the student is intelligent is just p(I = High|G = C) = 0.08, and the posterior probability that the class is hard is p(D = Hard|G = C) = 0.63; both low intelligence and a hard class help explain the low grade. However, now suppose we also observe that the student gets a good SAT score. Now the posterior probability that the student is intelligent has jumped to p(I = High|G = C, SAT = Good) = 0.58, and the probability that the class is hard has changed to p(D = Hard|G = C, SAT = Good) = 0.76, as shown in Figure 4.8. This negative mutual interaction between multiple causes of some observations is called the explaining away effect, also known as Berkson’s paradox (see Section 4.2.4.2 for details).


Figure 4.3: (a) Hierarchical latent variable model with 2 layers. (b) Same as (a) but with autoregressive 
connections within each layer. The observed x variables are the shaded leaf nodes at the bottom. The unshaded 
nodes are the hidden z variables. 


4.2.2.3 Sigmoid belief nets 


In this section, we consider a deep generative model of the form shown in Figure 4.3a. This 
corresponds to the following joint distribution: 


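$$p(\boldsymbol{x}, \boldsymbol{z}) = p(\boldsymbol{z}_2) p(\boldsymbol{z}_1|\boldsymbol{z}_2) p(\boldsymbol{x}|\boldsymbol{z}_1) = \prod_{k=1}^{K_2} p(z_{2,k}) \prod_{k=1}^{K_1} p(z_{1,k}|\boldsymbol{z}_2) \prod_{d=1}^{D} p(x_d|\boldsymbol{z}_1) \tag{4.8}$$

where $\boldsymbol{x}$ denotes the visible leaf nodes, and $\boldsymbol{z}_\ell$ denotes the hidden internal nodes. (We assume there are $K_\ell$ hidden nodes at level $\ell$, and D visible leaf nodes.)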

Now consider the special case where all the latent variables are binary, and all the latent CPDs are 
logistic regression models. That is, 


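$$p(\boldsymbol{z}_\ell|\boldsymbol{z}_{\ell+1}, \boldsymbol{\theta}) = \prod_{k=1}^{K_\ell} \text{Ber}(z_{\ell,k}|\sigma(\boldsymbol{w}_{\ell,k}^\top \boldsymbol{z}_{\ell+1})) \tag{4.9}$$

where $\sigma(u) = 1/(1 + e^{-u})$ is the sigmoid (logistic) function. The result is called a sigmoid belief net [Nea92].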

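At the bottom layer, $p(\boldsymbol{x}|\boldsymbol{z}_1, \boldsymbol{\theta})$, we use whatever observation model is appropriate for the type of data we are dealing with. For example, for real valued data, we might use

$$p(\boldsymbol{x}|\boldsymbol{z}_1, \boldsymbol{\theta}) = \prod_{d=1}^{D} \mathcal{N}(x_d|\boldsymbol{w}_{1,d,\mu}^\top \boldsymbol{z}_1, \exp(\boldsymbol{w}_{1,d,\sigma}^\top \boldsymbol{z}_1)) \tag{4.10}$$

where $\boldsymbol{w}_{1,d,\mu}$ are the weights that control the mean of the d’th output, and $\boldsymbol{w}_{1,d,\sigma}$ are the weights that control the variance of the d’th output.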
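Sampling from the two-layer sigmoid belief net of Equations (4.8)-(4.10) amounts to sampling the top layer, then each lower layer given the one above. The sketch below assumes Ber(0.5) priors on the top-layer units, small random weights, and no bias terms (Equation (4.9) shows none); all sizes, weights and the function name `sample_sbn` are illustrative.

```python
import math
import random


def sigmoid(u):
    return 1.0 / (1.0 + math.exp(-u))


def sample_sbn(w2_to_1, w_mu, w_sigma, k2, rng):
    """Draw one (x, z1, z2) from a 2-layer sigmoid belief net with Gaussian leaves.

    w2_to_1[k] : weights for unit k of layer 1 given z2          (Eq. 4.9)
    w_mu[d]    : weights for the mean of x_d given z1            (Eq. 4.10)
    w_sigma[d] : weights for the log-variance of x_d given z1    (Eq. 4.10)
    """
    z2 = [1 if rng.random() < 0.5 else 0 for _ in range(k2)]  # assumed Ber(0.5) prior
    z1 = [1 if rng.random() < sigmoid(sum(w * z for w, z in zip(wk, z2))) else 0
          for wk in w2_to_1]
    x = []
    for wm, ws in zip(w_mu, w_sigma):
        mean = sum(w * z for w, z in zip(wm, z1))
        var = math.exp(sum(w * z for w, z in zip(ws, z1)))   # exp keeps the variance positive
        x.append(rng.gauss(mean, math.sqrt(var)))
    return x, z1, z2


rng = random.Random(0)
K2, K1, D = 3, 4, 2
w2_to_1 = [[rng.gauss(0.0, 1.0) for _ in range(K2)] for _ in range(K1)]
w_mu = [[rng.gauss(0.0, 1.0) for _ in range(K1)] for _ in range(D)]
w_sigma = [[rng.gauss(0.0, 0.1) for _ in range(K1)] for _ in range(D)]
x, z1, z2 = sample_sbn(w2_to_1, w_mu, w_sigma, K2, rng)
```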

We can also add directed connections between the hidden variables within a layer, as shown in 
Figure 4.3b. This is called a deep autoregressive network or DARN model [Gre+14], which 
combines ideas from latent variable modeling and autoregressive modeling. 

We discuss other forms of hierarchical generative models in Chapter 21. 
4.2.3 Gaussian Bayes nets


Consider a DPGM where all the variables are real-valued, and all the CPDs have the following form, 
known as a linear Gaussian CPD: 


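$$p(x_i|\boldsymbol{x}_{\text{pa}(i)}) = \mathcal{N}(x_i|\mu_i + \boldsymbol{w}_i^\top \boldsymbol{x}_{\text{pa}(i)}, \sigma_i^2) \tag{4.11}$$

As we show below, multiplying all these CPDs together results in a large joint Gaussian distribution of the form $p(\boldsymbol{x}) = \mathcal{N}(\boldsymbol{x}|\boldsymbol{\mu}, \boldsymbol{\Sigma})$, where $\boldsymbol{x} \in \mathbb{R}^{N_G}$. This is called a directed Gaussian graphical model or a Gaussian Bayes net.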

We now explain how to derive $\boldsymbol{\mu}$ and $\boldsymbol{\Sigma}$, following [SK89, App. B]. For convenience, we will rewrite the CPDs in the following form:

$$x_i = \mu_i + \sum_{j \in \text{pa}(i)} w_{i,j}(x_j - \mu_j) + \sigma_i z_i \tag{4.12}$$

where $z_i \sim \mathcal{N}(0, 1)$, $\sigma_i$ is the conditional standard deviation of $x_i$ given its parents, $w_{i,j}$ is the strength of the $j \rightarrow i$ edge, and $\mu_i$ is the local mean.¹

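Taking expectations of Equation (4.12) node by node in topological order, we see that the global mean is just the concatenation of the local means, $\boldsymbol{\mu} = (\mu_1, \ldots, \mu_{N_G})$. We now derive the global covariance, $\boldsymbol{\Sigma}$. Let $\mathbf{S} \triangleq \text{diag}(\boldsymbol{\sigma})$ be a diagonal matrix containing the standard deviations. We can rewrite Equation (4.12) in matrix-vector form as follows:

$$(\boldsymbol{x} - \boldsymbol{\mu}) = \mathbf{W}(\boldsymbol{x} - \boldsymbol{\mu}) + \mathbf{S}\boldsymbol{z} \tag{4.13}$$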


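where $\mathbf{W}$ is the matrix of regression weights. Now let $\boldsymbol{e}$ be a vector of noise terms: $\boldsymbol{e} = \mathbf{S}\boldsymbol{z}$. We can rearrange this to get $\boldsymbol{e} = (\mathbf{I} - \mathbf{W})(\boldsymbol{x} - \boldsymbol{\mu})$. Since $\mathbf{W}$ is strictly lower triangular (because $w_{j,i} = 0$ if $j \le i$ in the topological ordering), we have that $\mathbf{I} - \mathbf{W}$ is lower triangular with 1s on the diagonal. Hence

$$\begin{pmatrix} e_1 \\ e_2 \\ e_3 \\ \vdots \\ e_{N_G} \end{pmatrix} = \begin{pmatrix} 1 & & & & \\ -w_{2,1} & 1 & & & \\ -w_{3,1} & -w_{3,2} & 1 & & \\ \vdots & & & \ddots & \\ -w_{N_G,1} & -w_{N_G,2} & \cdots & -w_{N_G,N_G-1} & 1 \end{pmatrix} \begin{pmatrix} x_1 - \mu_1 \\ x_2 - \mu_2 \\ x_3 - \mu_3 \\ \vdots \\ x_{N_G} - \mu_{N_G} \end{pmatrix} \tag{4.14}$$

Since $\mathbf{I} - \mathbf{W}$ is unit lower triangular, its determinant is 1, so it is always invertible, and we can write

$$\boldsymbol{x} - \boldsymbol{\mu} = (\mathbf{I} - \mathbf{W})^{-1}\boldsymbol{e} \triangleq \mathbf{U}\boldsymbol{e} = \mathbf{U}\mathbf{S}\boldsymbol{z} \tag{4.15}$$

where we defined $\mathbf{U} = (\mathbf{I} - \mathbf{W})^{-1}$. Hence the covariance is given by

$$\boldsymbol{\Sigma} = \text{Cov}[\boldsymbol{x}] = \text{Cov}[\boldsymbol{x} - \boldsymbol{\mu}] = \text{Cov}[\mathbf{U}\mathbf{S}\boldsymbol{z}] = \mathbf{U}\mathbf{S}\,\text{Cov}[\boldsymbol{z}]\,\mathbf{S}\mathbf{U}^\top = \mathbf{U}\mathbf{S}^2\mathbf{U}^\top \tag{4.16}$$

since $\text{Cov}[\boldsymbol{z}] = \mathbf{I}$.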
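Equations (4.13)-(4.16) translate into a short computation: invert $\mathbf{I} - \mathbf{W}$ by forward substitution (possible because it is unit lower triangular), then form $\boldsymbol{\Sigma} = \mathbf{U}\mathbf{S}^2\mathbf{U}^\top$. The sketch below is plain Python; the three-node chain $x_1 \rightarrow x_2 \rightarrow x_3$ and its weights are hypothetical, chosen so the entries can be checked by hand.

```python
def gaussian_bn_covariance(weights, sigmas):
    """Global covariance of a linear Gaussian Bayes net (Eq. 4.16).

    weights[i][j] = w_{i,j}, nonzero only if j is a parent of i (so j < i in topological order).
    sigmas[i]     = conditional standard deviation sigma_i.
    """
    n = len(sigmas)
    # U = (I - W)^{-1}: solve (I - W) U = I column by column with forward substitution.
    u = [[0.0] * n for _ in range(n)]
    for c in range(n):
        for r in range(n):
            rhs = 1.0 if r == c else 0.0
            # Row r of (I - W) has 1 on the diagonal and -w_{r,j} for j < r.
            u[r][c] = rhs + sum(weights[r][j] * u[j][c] for j in range(r))
    # Sigma = U S^2 U^T, i.e. Sigma[a][b] = sum_k U[a][k] * sigma_k^2 * U[b][k].
    return [[sum(u[a][k] * sigmas[k] ** 2 * u[b][k] for k in range(n)) for b in range(n)]
            for a in range(n)]


# Chain x1 -> x2 -> x3 with w_{2,1} = 0.5, w_{3,2} = 2 and unit noise (0-based indices below).
W = [[0.0, 0.0, 0.0],
     [0.5, 0.0, 0.0],
     [0.0, 2.0, 0.0]]
Sigma = gaussian_bn_covariance(W, [1.0, 1.0, 1.0])
# Var(x2) = 0.5^2 + 1 = 1.25, Cov(x2, x3) = 2 * 1.25 = 2.5, Var(x3) = 4 * 1.25 + 1 = 6
```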


1. If we do not subtract off the parent’s mean (i.e., if we use $x_i = \mu_i + \sum_{j \in \text{pa}(i)} w_{i,j} x_j + \sigma_i z_i$), the derivation of $\boldsymbol{\Sigma}$ is much messier, as can be seen by looking at [Bis06, p370].


4.2.4 Conditional independence properties 


We will write $\boldsymbol{x}_A \perp_G \boldsymbol{x}_B \mid \boldsymbol{x}_C$ if A is conditionally independent of B given C in the graph G. (We discuss how to determine whether such a CI property is implied by a given graph in the sections below.) Let I(G) be the set of all such CI statements encoded by the graph, and I(p) be the set of all such CI statements that hold true in some distribution p. We say that G is an I-map (independence map) for p, or that p is Markov wrt G, iff $I(G) \subseteq I(p)$. In other words, the graph is an I-map if it does not make any assertions of CI that are not true of the distribution. This allows us to use the graph as a safe proxy for p when reasoning about p’s CI properties. This is helpful for designing algorithms that work for large classes of distributions, regardless of their specific numerical parameters. Note that the fully connected graph is an I-map of all distributions, since it makes no CI assertions at all, as we show below. We therefore say G is a minimal I-map of p if G is an I-map of p, and if there is no $G' \subset G$ which is an I-map of p.

We now turn to the question of how to derive I(G), i.e., which CI properties are entailed by a 
DAG. 


4.2.4.1 Global Markov properties (d-separation) 


We say an undirected path P is d-separated by a set of nodes C (containing the evidence) iff at least one of the following conditions holds:

1. P contains a chain or pipe, $s \rightarrow m \rightarrow t$ or $s \leftarrow m \leftarrow t$, where $m \in C$.
2. P contains a tent or fork, $s \leftarrow m \rightarrow t$, where $m \in C$.
3. P contains a collider or v-structure, $s \rightarrow m \leftarrow t$, where m is not in C and neither is any descendant of m.


Next, we say that a set of nodes A is d-separated from a different set of nodes B given a third observed set C iff each undirected path from every node $a \in A$ to every node $b \in B$ is d-separated by C. Finally, we define the CI properties of a DAG as follows:

$$\boldsymbol{x}_A \perp_G \boldsymbol{x}_B \mid \boldsymbol{x}_C \iff A \text{ is d-separated from } B \text{ given } C \tag{4.17}$$


This is called the (directed) global Markov property. 

The Bayes ball algorithm [Sha98] is a simple way to see if A is d-separated from B given C, 
based on the above definition. The idea is this. We “shade” all nodes in C, indicating that they are 
observed. We then place “balls” at each node in A, let them “bounce around” according to some 
rules, and then ask if any of the balls reach any of the nodes in B. The three main rules are shown 
in Figure 4.4. Notice that balls can travel opposite to edge directions. We see that a ball can pass 
through a chain, but not if it is shaded in the middle. Similarly, a ball can pass through a fork, but 
not if it is shaded in the middle. However, a ball cannot pass through a v-structure, unless it is 
shaded in the middle. 

We can justify the 3 rules of Bayes ball as follows. First consider a chain structure $X \rightarrow Y \rightarrow Z$, which encodes

$$p(x, y, z) = p(x) p(y|x) p(z|y) \tag{4.18}$$
Figure 4.4: Bayes ball rules. A shaded node is one we condition on. If there is an arrow hitting a bar, it means the ball cannot pass through; otherwise the ball can pass through.

Figure 4.5: (a-b) Bayes ball boundary conditions. (c) Example of why we need boundary conditions. Y′ is an observed child of Y, rendering Y “effectively observed”, so the ball bounces back up on its way from X to Z.


When we condition on y, are x and z independent? We have

$$p(x, z|y) = \frac{p(x, y, z)}{p(y)} = \frac{p(x) p(y|x) p(z|y)}{p(y)} = \frac{p(x, y) p(z|y)}{p(y)} = p(x|y) p(z|y) \tag{4.19}$$

and therefore $X \perp Z \mid Y$. So observing the middle node of a chain breaks it in two (as in a Markov chain).

Now consider the tent structure $X \leftarrow Y \rightarrow Z$. The joint is

$$p(x, y, z) = p(y) p(x|y) p(z|y) \tag{4.20}$$

When we condition on y, are x and z independent? We have

$$p(x, z|y) = \frac{p(x, y, z)}{p(y)} = \frac{p(y) p(x|y) p(z|y)}{p(y)} = p(x|y) p(z|y) \tag{4.21}$$

and therefore $X \perp Z \mid Y$. So observing a root node separates its children (as in a naive Bayes classifier: see Section 4.2.8.2).
X   Y   Z
D   I   ∅
D   I   S
D   S   ∅
D   S   I
D   S   L, I
D   S   G, I
D   S   G, L, I
D   L   G
D   L   G, S
D   L   G, I
D   L   I, G, S


Table 4.1: Conditional independence relationships implied by the student DAG (Figure 4.2). Each line has the form $X \perp Y \mid Z$. Generated by student_pgm.ipynb.


Finally consider a v-structure $X \rightarrow Y \leftarrow Z$. The joint is

$$p(x, y, z) = p(x) p(z) p(y|x, z) \tag{4.22}$$

When we condition on y, are x and z independent? We have

$$p(x, z|y) = \frac{p(x) p(z) p(y|x, z)}{p(y)} \tag{4.23}$$

which in general does not factorize into a function of x times a function of z, because the term $p(y|x, z)$ couples them; so $X \not\perp Z \mid Y$. However, in the unconditional distribution, summing Equation (4.22) over y gives

$$p(x, z) = p(x) p(z) \tag{4.24}$$

so we see that X and Z are marginally independent. Thus conditioning on a common child at the bottom of a v-structure makes its parents become dependent. This important effect is called explaining away, inter-causal reasoning, or Berkson’s paradox (see Section 4.2.4.2 for a discussion).

In addition, Bayes ball needs the “boundary conditions” shown in Figure 4.5(a-b). These rules say that a ball hitting a hidden leaf stops, but a ball hitting an observed leaf “bounces back”. To understand where this rule comes from, consider Figure 4.5(c). Suppose Y′ is a (possibly noisy) copy of Y. If we observe Y′, we effectively observe Y as well, so the parents X and Z have to compete to explain this. So if we send a ball down $X \rightarrow Y \rightarrow Y'$, it should “bounce back” up from Y′ through Y to Z, in order to pass information between the parents. However, if Y and all its children are hidden, the ball does not bounce back.

As an example of the CI statements encoded by a DAG, Table 4.1 shows some properties that 
follow from the student network in Figure 4.2. 
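The d-separation test can be sketched as a reachability search in the spirit of Bayes ball, along the lines of the reachability procedure in [KF09a]: first mark the evidence nodes and their ancestors (a collider is active exactly when it is in this set), then search over (node, direction) pairs. This is a sketch rather than the book's notebook code; the function name `d_separated` is hypothetical. Run on the student network, it reproduces several rows of Table 4.1 and the explaining-away effect.

```python
def d_separated(parents, x, y, observed):
    """Return True if node x is d-separated from node y given the set `observed`.

    parents: dict node -> list of parents (the DAG).
    """
    children = {n: [] for n in parents}
    for n, ps in parents.items():
        for p in ps:
            children[p].append(n)

    # Phase 1: observed nodes and all their ancestors. A v-structure at m is
    # active iff m is in this set (m or one of its descendants is observed).
    anc = set()
    stack = list(observed)
    while stack:
        n = stack.pop()
        if n not in anc:
            anc.add(n)
            stack.extend(parents[n])

    # Phase 2: search over (node, direction); "up" = reached from a child,
    # "down" = reached from a parent.
    visited, reachable = set(), set()
    frontier = [(x, "up")]
    while frontier:
        node, direction = frontier.pop()
        if (node, direction) in visited:
            continue
        visited.add((node, direction))
        if node not in observed:
            reachable.add(node)
        if direction == "up" and node not in observed:
            # Chain or fork through an unobserved node: continue both ways.
            frontier.extend((p, "up") for p in parents[node])
            frontier.extend((c, "down") for c in children[node])
        elif direction == "down":
            if node not in observed:
                frontier.extend((c, "down") for c in children[node])
            if node in anc:
                # Active collider (node or a descendant observed): bounce back up.
                frontier.extend((p, "up") for p in parents[node])
    return y not in reachable


student = {"D": [], "I": [], "G": ["D", "I"], "S": ["I"], "L": ["G"]}
print(d_separated(student, "D", "I", set()))   # True
print(d_separated(student, "D", "I", {"G"}))   # False (explaining away)
```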


4.2.4.2 Explaining away (Berkson’s paradox) 


In this section, we give some examples of the explaining away phenomenon, also called Berkson’s 
paradox. 
Figure 4.6: Samples from a jointly Gaussian DPGM, $p(x, y, z) = \mathcal{N}(x|-5, 1)\mathcal{N}(y|5, 1)\mathcal{N}(z|x + y, 1)$. (a) Unconditional marginal distributions, p(x), p(y), p(z). (b) Unconditional joint distribution, p(x, y). (c) Conditional marginal distributions, p(x|z > 2.5), p(y|z > 2.5), p(z|z > 2.5). (d) Conditional joint distribution, p(x, y|z > 2.5). Adapted from [Clo20]. Generated by berksons_gaussian.ipynb.


As a simple example (from [PM18b, p198]), consider tossing two coins 100 times. Suppose you only record the outcome of the experiment if at least one coin shows up heads. You should expect to record about 75 entries. You will see that every time coin 1 is recorded as tails, coin 2 will be recorded as heads. If we ignore the way in which the data was collected, we might infer from the fact that coins 1 and 2 are correlated that there is a hidden common cause. However, the correct explanation is that the correlation is due to conditioning on a hidden common effect (namely the decision of whether to record the outcome or not, which censors tail-tail events). This is called selection bias.
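The coin example is easy to simulate. The sketch below (plain Python, fixed seed, two fair coins assumed) keeps only the tosses in which at least one coin shows heads, then checks the two facts stated above: roughly three quarters of the tosses are recorded, and coin 2 is heads whenever coin 1 is tails, even though the coins were generated independently.

```python
import random

rng = random.Random(0)
recorded = []
for _ in range(100):
    coin1, coin2 = rng.random() < 0.5, rng.random() < 0.5  # True = heads
    if coin1 or coin2:                                      # selection: drop tail-tail events
        recorded.append((coin1, coin2))

num_recorded = len(recorded)                                # about 75 on average
coin2_when_coin1_tails = [c2 for c1, c2 in recorded if not c1]
coin2_when_coin1_heads = [c2 for c1, c2 in recorded if c1]
# Among recorded entries, p(coin2 = H | coin1 = T) is exactly 1,
# while p(coin2 = H | coin1 = H) stays near 0.5: the selection induced the dependence.
frac_heads_given_heads = sum(coin2_when_coin1_heads) / len(coin2_when_coin1_heads)
```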
As another example of this, consider a Gaussian DPGM of the form

$$p(x, y, z) = \mathcal{N}(x|-5, 1)\mathcal{N}(y|5, 1)\mathcal{N}(z|x + y, 1) \tag{4.25}$$

The graph structure is $X \rightarrow Z \leftarrow Y$, where Z is the child node. Some samples from the unconditional joint distribution p(x, y, z) are shown in Figure 4.6(a-b); we see that X and Y are uncorrelated. Now suppose we only select samples where z > 2.5. Some samples from the conditional joint distribution p(x, y|z > 2.5) are shown in Figure 4.6(d); we see that now X and Y are correlated. This could cause us to erroneously conclude that there is a causal relationship, but in fact the dependency is caused by selection bias.
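The Gaussian example of Equation (4.25) can be reproduced along the same lines: sample from the DPGM, then compare the correlation between X and Y before and after keeping only the samples with z > 2.5. Sample size and seed are arbitrary choices, and the correlation helper is written out by hand so the sketch does not depend on a particular Python version.

```python
import random


def pearson(a, b):
    """Sample Pearson correlation of two equal-length lists."""
    n = len(a)
    ma, mb = sum(a) / n, sum(b) / n
    cov = sum((u - ma) * (v - mb) for u, v in zip(a, b))
    va = sum((u - ma) ** 2 for u in a)
    vb = sum((v - mb) ** 2 for v in b)
    return cov / (va * vb) ** 0.5


rng = random.Random(0)
samples = []
for _ in range(20000):
    x = rng.gauss(-5.0, 1.0)      # x ~ N(-5, 1)
    y = rng.gauss(5.0, 1.0)       # y ~ N(5, 1)
    z = rng.gauss(x + y, 1.0)     # z ~ N(x + y, 1)
    samples.append((x, y, z))

corr_all = pearson([s[0] for s in samples], [s[1] for s in samples])    # near 0
selected = [s for s in samples if s[2] > 2.5]                           # select on the child
corr_sel = pearson([s[0] for s in selected], [s[1] for s in selected])  # clearly negative
```

Selecting on large values of the common child makes a low x compatible only with a high y (and vice versa), which is why the selected samples show a negative correlation.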


4.2.4.3 Markov blankets 


The smallest set of nodes that renders a node i conditionally independent of all the other nodes in the graph is called i’s Markov blanket; we will denote this by mb(i). Below we show that the Markov blanket of a node in a DPGM is equal to the parents, the children, and the co-parents, i.e., other nodes that are also parents of its children:


$$\text{mb}(i) = \text{ch}(i) \cup \text{pa}(i) \cup \text{copa}(i) \tag{4.26}$$


See Figure 4.7 for an illustration. 

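To see why this is true, let us partition all the nodes into the target node $X_i$, its parents $\boldsymbol{U}$, its children $\boldsymbol{Y}$, its coparents $\boldsymbol{Z}$, and the other variables $\boldsymbol{O}$. Let $\boldsymbol{X}_{-i}$ be all the nodes except $X_i$, and let $\boldsymbol{Z}_j$ denote the parents of child $Y_j$ other than $X_i$. Then we have

$$p(X_i|\boldsymbol{X}_{-i}) = \frac{p(X_i, \boldsymbol{X}_{-i})}{\sum_x p(X_i = x, \boldsymbol{X}_{-i})} \tag{4.27}$$

$$= \frac{p(X_i, \boldsymbol{U}, \boldsymbol{Y}, \boldsymbol{Z}, \boldsymbol{O})}{\sum_x p(X_i = x, \boldsymbol{U}, \boldsymbol{Y}, \boldsymbol{Z}, \boldsymbol{O})} \tag{4.28}$$

$$= \frac{p(X_i|\boldsymbol{U}) \left[\prod_j p(Y_j|X_i, \boldsymbol{Z}_j)\right] \phi(\boldsymbol{U}, \boldsymbol{Y}, \boldsymbol{Z}, \boldsymbol{O})}{\sum_x p(X_i = x|\boldsymbol{U}) \left[\prod_j p(Y_j|X_i = x, \boldsymbol{Z}_j)\right] \phi(\boldsymbol{U}, \boldsymbol{Y}, \boldsymbol{Z}, \boldsymbol{O})} \tag{4.29}$$

$$= \frac{p(X_i|\boldsymbol{U}) \prod_j p(Y_j|X_i, \boldsymbol{Z}_j)}{\sum_x p(X_i = x|\boldsymbol{U}) \prod_j p(Y_j|X_i = x, \boldsymbol{Z}_j)} \tag{4.30}$$

$$\propto p(X_i|\text{pa}(X_i)) \prod_{Y_j \in \text{ch}(X_i)} p(Y_j|\text{pa}(Y_j)) \tag{4.31}$$

where $\phi(\cdot)$ is the product of all the remaining CPDs (none of which involve $X_i$), ch($X_i$) are the children of $X_i$ and pa($Y_j$) are the parents of $Y_j$. We see that the terms that do not involve $X_i$ cancel out from the numerator and denominator, so we are left with a product of terms that include $X_i$ in their “scope”. Hence the full conditional for node i becomes

$$p(x_i|\boldsymbol{x}_{-i}) = p(x_i|\boldsymbol{x}_{\text{mb}(i)}) \propto p(x_i|\boldsymbol{x}_{\text{pa}(i)}) \prod_{k \in \text{ch}(i)} p(x_k|\boldsymbol{x}_{\text{pa}(k)}) \tag{4.32}$$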


We will see applications of this in Gibbs sampling (Equation (12.19)), and mean field variational 
inference (Equation (10.28)). 
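Equation (4.26) is easy to compute from a parent list: collect the parents, the children, and the other parents of those children. The sketch below applies it to the student network of Figure 4.2; the function name is illustrative.

```python
def markov_blanket(parents, node):
    """mb(i) = union of pa(i), ch(i) and copa(i) for a DAG given as node -> list of parents."""
    children = [c for c, ps in parents.items() if node in ps]
    coparents = {p for c in children for p in parents[c]} - {node}
    return set(parents[node]) | set(children) | coparents


student = {"D": [], "I": [], "G": ["D", "I"], "S": ["I"], "L": ["G"]}
print(sorted(markov_blanket(student, "I")))   # ['D', 'G', 'S']
print(sorted(markov_blanket(student, "G")))   # ['D', 'I', 'L']
```

For example, D enters the blanket of I only as a co-parent (through the shared child G), which is exactly the term that the explaining-away effect depends on.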


4.2.4.4 Other Markov properties 


From the d-separation criterion, one can conclude that 


$$i \perp \text{nd}(i) \setminus \text{pa}(i) \mid \text{pa}(i) \tag{4.33}$$
Figure 4.7: Illustration of the Markov blanket of a node in a directed graphical model. The target node $X_i$ is shown in gray, its parents $U_{1:m}$ are shown in green, its children $Y_{1:n}$ are shown in blue, and its coparents $Z_{1:n,j}$ are shown in red. $X_i$ is conditionally independent of all the other variables in the model given these variables. Adapted from Figure 13.4b of [RN19].


where the non-descendants of a node, nd(i), are all the nodes except for its descendants, $\text{nd}(i) = \{1, \ldots, N_G\} \setminus (\{i\} \cup \text{desc}(i))$. Equation (4.33) is called the (directed) local Markov property. For example, in Figure 4.23(a), we have nd(3) = {1, 2, 4}, and pa(3) = 1, so $3 \perp 2, 4 \mid 1$.


A special case of this property is when we only look at predecessors of a node according to some topological ordering. We have

$$i \perp \text{pred}(i) \setminus \text{pa}(i) \mid \text{pa}(i) \tag{4.34}$$

which follows since $\text{pred}(i) \subseteq \text{nd}(i)$. This is called the ordered Markov property, which justifies Equation (4.2). For example, in Figure 4.23(a), if we use the ordering 1, 2, ..., 7, we find pred(3) = {1, 2} and pa(3) = 1, so $3 \perp 2 \mid 1$.


We have now described three Markov properties for DAGs: the directed global Markov property G in Equation (4.17), the directed local Markov property L in Equation (4.33), and the ordered Markov property O in Equation (4.34). It is obvious that $G \Rightarrow L \Rightarrow O$. What is less obvious, but nevertheless true, is that $O \Rightarrow L \Rightarrow G$ (see e.g., [KF09a] for the proof). Hence all these properties are equivalent.

Furthermore, any distribution p that is Markov wrt a graph can be factorized as in Equation (4.2); this is called the factorization property F. It is obvious that $O \Rightarrow F$, but one can show that the converse also holds (see e.g., [KF09a] for the proof).


4.2.5 Generation (sampling)


It is easy to generate prior samples from a DPGM: we simply visit the nodes in topological order, parents before children, and then sample a value for each node given the value of its parents. This will generate independent samples from the joint, $(x_1, \ldots, x_{N_G}) \sim p(\boldsymbol{x}|\boldsymbol{\theta})$. This is called ancestral sampling.
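Ancestral sampling can be sketched as follows: compute a topological order (here with Kahn's algorithm), then draw each node from the CPT row selected by its already-sampled parents. The sketch uses the student network. Only some of its CPT entries are legible in this text, so the priors on D and I, the SAT CPT and the (Low, Easy) grade row are assumptions (values from the standard Koller-Friedman student example, marked "assumed" in the code). Under these values the exact marginal is p(G = C) = 0.3496, which the empirical frequency should approach.

```python
import random

card = {"D": 2, "I": 2, "G": 3, "S": 2, "L": 2}
parents = {"D": [], "I": [], "G": ["I", "D"], "S": ["I"], "L": ["G"]}
# CPT rows keyed by the tuple of parent values; entries marked "assumed" are not legible in this text.
cpt = {
    "D": {(): [0.6, 0.4]},                                  # assumed: (Easy, Hard)
    "I": {(): [0.7, 0.3]},                                  # assumed: (Low, High)
    "G": {(0, 0): [0.3, 0.4, 0.3],                          # assumed: (Low, Easy)
          (0, 1): [0.05, 0.25, 0.7],                        # (Low, Hard)
          (1, 0): [0.9, 0.08, 0.02],                        # (High, Easy)
          (1, 1): [0.5, 0.3, 0.2]},                         # (High, Hard)
    "S": {(0,): [0.95, 0.05], (1,): [0.2, 0.8]},            # assumed: (Bad, Good) given I
    "L": {(0,): [0.1, 0.9], (1,): [0.4, 0.6], (2,): [0.99, 0.01]},
}


def topological_order(parents):
    """Kahn's algorithm: emit a node once all of its parents have been emitted."""
    missing = {n: len(ps) for n, ps in parents.items()}
    ready = [n for n, k in missing.items() if k == 0]
    order = []
    while ready:
        n = ready.pop()
        order.append(n)
        for c, ps in parents.items():
            if n in ps:
                missing[c] -= 1
                if missing[c] == 0:
                    ready.append(c)
    return order


def ancestral_sample(order, rng):
    """Sample every node given its already-sampled parents, in topological order."""
    x = {}
    for node in order:
        row = cpt[node][tuple(x[p] for p in parents[node])]
        x[node] = rng.choices(range(card[node]), weights=row)[0]
    return x


order = topological_order(parents)
rng = random.Random(0)
draws = [ancestral_sample(order, rng) for _ in range(20000)]
freq_c = sum(d["G"] == 2 for d in draws) / len(draws)   # exact p(G = C) is 0.3496 here
```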


4.2.6 Inference 


In the context of PGMs, the term “inference” refers to the task of computing the posterior over a set of query nodes Q given the observed values for a set of visible nodes V, while marginalizing over the irrelevant nuisance variables, $R = \{1, \ldots, N_G\} \setminus \{Q, V\}$:

$$p_{\boldsymbol{\theta}}(Q|V) = \frac{p_{\boldsymbol{\theta}}(Q, V)}{p_{\boldsymbol{\theta}}(V)} = \frac{\sum_R p_{\boldsymbol{\theta}}(Q, V, R)}{p_{\boldsymbol{\theta}}(V)} \tag{4.35}$$

(If the variables are continuous, we should replace sums with integrals.) If Q is a single node, then $p_{\boldsymbol{\theta}}(Q|V)$ is called the posterior marginal for node Q.

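As an example, suppose $V = \boldsymbol{x}$ is a sequence of observed sound waves, $Q = \boldsymbol{z}$ is the corresponding set of unknown spoken words, and $R = \boldsymbol{r}$ are random “non-semantic” factors associated with the signal, such as prosody or background noise. Our goal is to compute the posterior over the words given the sounds, while being invariant to the irrelevant factors:

$$p_{\boldsymbol{\theta}}(\boldsymbol{z}|\boldsymbol{x}) = \sum_{\boldsymbol{r}} p_{\boldsymbol{\theta}}(\boldsymbol{z}, \boldsymbol{r}|\boldsymbol{x}) = \sum_{\boldsymbol{r}} \frac{p_{\boldsymbol{\theta}}(\boldsymbol{z}, \boldsymbol{r}, \boldsymbol{x})}{\sum_{\boldsymbol{z}', \boldsymbol{r}'} p_{\boldsymbol{\theta}}(\boldsymbol{z}', \boldsymbol{r}', \boldsymbol{x})} \tag{4.36}$$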


As a simplification, we can “lump” the random factors R into the query set Q to define the complete set of hidden variables $H = Q \cup R$. In this case, the task simplifies to

$$p_{\boldsymbol{\theta}}(\boldsymbol{h}|\boldsymbol{x}) = \frac{p_{\boldsymbol{\theta}}(\boldsymbol{h}, \boldsymbol{x})}{p_{\boldsymbol{\theta}}(\boldsymbol{x})} = \frac{p_{\boldsymbol{\theta}}(\boldsymbol{h}, \boldsymbol{x})}{\sum_{\boldsymbol{h}'} p_{\boldsymbol{\theta}}(\boldsymbol{h}', \boldsymbol{x})} \tag{4.37}$$

The computational complexity of the inference task depends on the CI properties of the graph, as we discuss in Chapter 9. In general it is NP-hard (see Section 9.4.4), but for certain graph structures (such as chains, trees and other sparse graphs), it can be solved efficiently (in polynomial time) using dynamic programming (see Chapter 9). For cases where it is intractable, we can use standard methods for approximate Bayesian inference, which we review in Chapter 7.


4.2.6.1 Example: inference in the Student network 


As an example of inference in PGMs, consider the Student network from Section 4.2.2.2. Suppose we 
observe that the student gets a grade of C. The posterior marginals are shown in Figure 4.8a. We see 
that the low grade could be explained by the class being hard (since p(D = Hard|G = C) = 0.63), 
but is more likely explained by the student having low intelligence (since p(I = High|G = C) = 0.08). 

However, now suppose we also observe that the student gets a good SAT score. The new posterior 
marginals are shown in Figure 4.8b. Now the posterior probability that the student is intelligent has 
jumped to p(I = High|G = C,SAT = Good) = 0.58, since otherwise it would be difficult to explain 
the good SAT score. Once we believe the student has high intelligence, we have to explain the C 
grade by assuming the class is hard, and indeed we find that the probability that the class is hard 
has increased to p(D = Hard|G = C, SAT = Good) = 0.76. (This negative mutual interaction between multiple
causes of some observations is called the explaining away effect, and is discussed in Section 4.2.4.2.) 
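For a network this small, the posterior marginals in this example can be reproduced by brute-force enumeration of Equation (4.35): sum the joint of Equation (4.6) over all configurations consistent with the evidence, then normalize. As in the other sketches, some CPT entries are assumptions because they are not legible in this text (the values of the standard Koller-Friedman student network for the priors on D and I, the SAT CPT, and the (Low, Easy) grade row); with them, the code reproduces the numbers quoted above. Enumeration is exponential in the number of nodes, so it only serves as a reference for the efficient algorithms of Chapter 9.

```python
from itertools import product

card = {"D": 2, "I": 2, "G": 3, "S": 2, "L": 2}
parents = {"D": [], "I": [], "G": ["I", "D"], "S": ["I"], "L": ["G"]}
cpt = {
    "D": {(): [0.6, 0.4]},                                  # assumed: (Easy, Hard)
    "I": {(): [0.7, 0.3]},                                  # assumed: (Low, High)
    "G": {(0, 0): [0.3, 0.4, 0.3],                          # assumed: (Low, Easy)
          (0, 1): [0.05, 0.25, 0.7], (1, 0): [0.9, 0.08, 0.02], (1, 1): [0.5, 0.3, 0.2]},
    "S": {(0,): [0.95, 0.05], (1,): [0.2, 0.8]},            # assumed
    "L": {(0,): [0.1, 0.9], (1,): [0.4, 0.6], (2,): [0.99, 0.01]},
}


def joint(x):
    """p(D, I, G, S, L) as the product of the local CPDs (Eq. 4.6)."""
    p = 1.0
    for node in card:
        p *= cpt[node][tuple(x[q] for q in parents[node])][x[node]]
    return p


def posterior(query, evidence):
    """p(query | evidence) by summing the joint over all consistent configurations (Eq. 4.35)."""
    nodes = list(card)
    unnorm = [0.0] * card[query]
    for values in product(*(range(card[n]) for n in nodes)):
        x = dict(zip(nodes, values))
        if all(x[v] == val for v, val in evidence.items()):
            unnorm[x[query]] += joint(x)
    total = sum(unnorm)
    return [u / total for u in unnorm]


C, GOOD, HIGH, HARD = 2, 1, 1, 1
print(round(posterior("I", {"G": C})[HIGH], 2))              # 0.08
print(round(posterior("D", {"G": C})[HARD], 2))              # 0.63
print(round(posterior("I", {"G": C, "S": GOOD})[HIGH], 2))   # 0.58
print(round(posterior("D", {"G": C, "S": GOOD})[HARD], 2))   # 0.76
```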


4.2.7 Learning 


So far, we have assumed that the structure G and parameters $\boldsymbol{\theta}$ of the PGM are known. However, it
is possible to learn both of these from data. For details on how to learn G from data, see Section 30.3. 
[Figure 4.8 graphic: marginal histograms for each node; the Intel node shows (Low, High) = (0.92, 0.08) in panel (a) and (0.42, 0.58) in panel (b).]


Figure 4.8: Illustration of belief updating in the “Student” PGM. The histograms show the marginal distribution of each node. Nodes with shaded titles are clamped to an observed value. (a) Posterior after conditioning on Grade=C. (b) Posterior after also conditioning on SAT=Good. Generated by student_pgm.ipynb.

Figure 4.9: A DPGM representing the joint distribution $p(\boldsymbol{y}_{1:N}, \boldsymbol{x}_{1:N}, \boldsymbol{\theta}_y, \boldsymbol{\theta}_x)$. Here $\boldsymbol{\theta}_x$ and $\boldsymbol{\theta}_y$ are global parameter nodes that are shared across the examples, whereas $\boldsymbol{x}_n$ and $y_n$ are local variables.


Here we focus on parameter learning, i.e., computing the posterior $p(\boldsymbol{\theta}|\mathcal{D}, G)$. (Henceforth we will drop the conditioning on G, since we assume the graph structure is fixed.)

We can compute the parameter posterior $p(\boldsymbol{\theta}|\mathcal{D})$ by treating $\boldsymbol{\theta}$ as “just another hidden variable”, and then performing inference. However, in the machine learning community, it is more common to just compute a point estimate of the parameters, such as the posterior mode, $\hat{\boldsymbol{\theta}} = \operatorname{argmax}_{\boldsymbol{\theta}} p(\boldsymbol{\theta}|\mathcal{D})$. This approximation is often reasonable, since the parameters depend on all the data, rather than just a single data point, and are therefore less uncertain than other hidden variables.
4.2.7.1 Learning from complete data


Figure 4.9 represents a graphical model for a typical supervised learning problem. We have N local variables, $\boldsymbol{x}_n$ and $y_n$, and 2 global variables, corresponding to the parameters, which are shared across data samples. The local variables are observed (in the training set), so they are represented by solid (shaded) nodes. The global variables are not observed, and hence are represented by empty (unshaded) nodes. (The model represents a generative classifier, so the edge is from $y_n$ to $\boldsymbol{x}_n$; if we are fitting a discriminative classifier, the edge would be from $\boldsymbol{x}_n$ to $y_n$, and there would be no $\boldsymbol{\theta}_y$ prior node.)

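From the CI properties of Figure 4.9, it follows that the joint distribution factorizes into a product of terms, one per node:

$$p(\boldsymbol{\theta}, \mathcal{D}) = p(\boldsymbol{\theta}_x) p(\boldsymbol{\theta}_y) \left[\prod_{n=1}^{N} p(y_n|\boldsymbol{\theta}_y) p(\boldsymbol{x}_n|y_n, \boldsymbol{\theta}_x)\right] \tag{4.38}$$

$$= \left[p(\boldsymbol{\theta}_y) \prod_{n=1}^{N} p(y_n|\boldsymbol{\theta}_y)\right] \left[p(\boldsymbol{\theta}_x) \prod_{n=1}^{N} p(\boldsymbol{x}_n|y_n, \boldsymbol{\theta}_x)\right] \tag{4.39}$$

$$= \left[p(\boldsymbol{\theta}_y) p(\mathcal{D}_y|\boldsymbol{\theta}_y)\right] \left[p(\boldsymbol{\theta}_x) p(\mathcal{D}_x|\boldsymbol{\theta}_x)\right] \tag{4.40}$$

where $\mathcal{D}_y = \{y_n\}_{n=1}^N$ is the data that is sufficient for estimating $\boldsymbol{\theta}_y$, and $\mathcal{D}_x = \{\boldsymbol{x}_n, y_n\}_{n=1}^N$ is the data that is sufficient for $\boldsymbol{\theta}_x$.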

From Equation (4.40), we see that the prior, likelihood and posterior all decompose or factorize 
according to the graph structure. Thus we can compute the posterior for each parameter independently. 
In general, we have 


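$$p(\boldsymbol{\theta}, \mathcal{D}) = \prod_{i=1}^{N_G} p(\boldsymbol{\theta}_i) p(\mathcal{D}_i|\boldsymbol{\theta}_i) \tag{4.41}$$

Hence the likelihood and prior factorize, and thus so does the posterior. If we just want to compute the MLE, we can compute

$$\hat{\boldsymbol{\theta}} = \operatorname*{argmax}_{\boldsymbol{\theta}} \prod_{i=1}^{N_G} p(\mathcal{D}_i|\boldsymbol{\theta}_i) \tag{4.42}$$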


We can solve this for each node independently, as we illustrate in Section 4.2.7.2. 
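To make the factorization in Equation (4.40) concrete, here is a minimal plain-Python sketch for a generative classifier with a single binary feature. The toy dataset and the names `theta_y` and `theta_x` are hypothetical. Each parameter is estimated only from its own data, $\mathcal{D}_y$ or $\mathcal{D}_x$, and the joint log likelihood equals the sum of the two per-node terms.

```python
import math
from collections import Counter

# Hypothetical fully observed data: (x_n, y_n) with binary x and y in {0, 1, 2}.
data = [(1, 0), (0, 0), (1, 1), (1, 1), (0, 2), (0, 2), (1, 2)]
C = 3

# theta_y: MLE of the class prior, computed from D_y = {y_n} only.
y_counts = Counter(y for _, y in data)
theta_y = [y_counts[c] / len(data) for c in range(C)]

# theta_x: MLE of p(x = 1 | y = c), computed from D_x = {(x_n, y_n)} only.
theta_x = []
for c in range(C):
    xs = [x for x, y in data if y == c]
    theta_x.append(sum(xs) / len(xs))

def log_bernoulli(x, p):
    """log p(x) for a Bernoulli variable with p(x = 1) = p."""
    return math.log(p) if x == 1 else math.log(1.0 - p)

# The joint log likelihood splits into one term per parameter node.
ll_y = sum(math.log(theta_y[y]) for _, y in data)
ll_x = sum(log_bernoulli(x, theta_x[y]) for x, y in data)
ll_joint = sum(math.log(theta_y[y]) + log_bernoulli(x, theta_x[y]) for x, y in data)
print(theta_y, theta_x, ll_joint, ll_y + ll_x)
```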


4.2.7.2 Example: computing the MLE for CPTs 
In this section, we illustrate how to compute the MLE for tabular CPDs. The likelihood is given by 
the following product of multinomials: 


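$$p(\mathcal{D}|\boldsymbol{\theta}) = \prod_{n=1}^N \prod_{i=1}^{N_G} p(x_{ni}|\boldsymbol{x}_{n,\text{pa}(i)}, \boldsymbol{\theta}_i) \tag{4.43}$$

$$= \prod_{n=1}^N \prod_{i=1}^{N_G} \prod_{j=1}^{J_i} \prod_{k=1}^{K_i} \theta_{ijk}^{\mathbb{I}(x_{ni} = k, \boldsymbol{x}_{n,\text{pa}(i)} = j)} \tag{4.44}$$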
Table 4.2: Some fully observed training data for the student network.


I | D | $N_{ijk}$ | $\hat{\theta}_{ijk}$ | $\bar{\theta}_{ijk}$
0 TL} (0,0,4) ) [5 z3] pai] 

tE ti 
1 | 0 | [1, 0, 0] | [1, 0, 0] | [2/4, 1/4, 1/4]
1 | 1 | [0, 1, 1] | [0, 1/2, 1/2] | [1/5, 2/5, 2/5]


Table 4.3: Sufficient statistics $N_{ijk}$ and corresponding MLE $\hat{\theta}_{ijk}$ and posterior mean $\bar{\theta}_{ijk}$ for node $i = G$ in the student network. Each row corresponds to a different joint configuration of its parent nodes, corresponding to state $j$. The index $k$ refers to the 3 possible values of the child node $G$.


where

$$\theta_{ijk} = p(x_i = k|\boldsymbol{x}_{\text{pa}(i)} = j) \tag{4.45}$$


Let us define the sufficient statistics for node $i$ to be $N_{ijk}$, which is the number of times that node $i$ is in state $k$ while its parents are in joint state $j$:


$$N_{ijk} = \sum_{n=1}^N \mathbb{I}(x_{ni} = k, \boldsymbol{x}_{n,\text{pa}(i)} = j) \tag{4.46}$$


The MLE for a multinomial is given by the normalized empirical frequencies:

$$\hat{\theta}_{ijk} = \frac{N_{ijk}}{\sum_{k'} N_{ijk'}} \tag{4.47}$$


For example, consider the student network from Section 4.2.2.2. In Table 4.2, we show some sample training data. For example, the last line in the table encodes a student who is smart ($I = 1$), who takes a hard class ($D = 1$), gets a C ($G = 2$), but who does well on the SAT ($S = 1$) and gets a good letter of recommendation ($L = 1$).


In Table 4.3, we list the sufficient statistics $N_{ijk}$ and the MLE $\hat{\theta}_{ijk}$ for node $i = G$, with parents $(I, D)$. A similar process can be used for the other nodes. Thus we see that fitting a DPGM with tabular CPDs reduces to a simple counting problem.
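The counting procedure can be sketched in plain Python as follows. The records are hypothetical (they are not the data in Table 4.2), and encoding the parent configuration $(I, D)$ as a single index $j = 2I + D$ is our own choice. A row of counts that is all zero corresponds to a parent configuration that never occurs, for which the MLE is undefined.

```python
# Hypothetical fully observed records (i, d, g) for the family (I, D) -> G,
# with i, d in {0, 1} and g in {0, 1, 2}. These are not the values in Table 4.2.
records = [(0, 0, 0), (0, 0, 1), (0, 1, 2), (1, 0, 0), (1, 1, 1), (1, 1, 1), (1, 1, 2)]
K = 3  # number of states of the child node G
J = 4  # number of joint parent states (I, D)

def parent_state(i, d):
    """Hypothetical encoding of the parent configuration (I, D) as j in {0, ..., 3}."""
    return 2 * i + d

def sufficient_stats(records):
    """N[j][k]: number of records with parents in state j and child in state k (Eq. 4.46)."""
    N = [[0] * K for _ in range(J)]
    for i, d, g in records:
        N[parent_state(i, d)][g] += 1
    return N

def mle_cpt(N):
    """Normalized empirical frequencies (Eq. 4.47); None if a parent state never occurs."""
    return [[n / sum(row) for n in row] if sum(row) > 0 else None for row in N]

N = sufficient_stats(records)
theta_hat = mle_cpt(N)
print(N)
print(theta_hat)
```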
Figure 4.10: A DPGM representing the joint distribution $p(\boldsymbol{z}_{1:N}, \boldsymbol{x}_{1:N}, \boldsymbol{\theta}_z, \boldsymbol{\theta}_x)$. The local variables $\boldsymbol{z}_n$ are hidden, whereas $\boldsymbol{x}_n$ are observed. This is typical for learning unsupervised latent variable models.


However, we notice there are a lot of zeros in the sufficient statistics, due to the small sample size, resulting in extreme estimates for some of the probabilities $\hat{\theta}_{ijk}$. We discuss a (Bayesian) solution to this in Section 4.2.7.3.


4.2.7.3 Example: Computing the posterior for CPTs 


In Section 4.2.7.2 we discussed how to compute the MLE for the CPTs in a discrete Bayes net. We 
also observed that this can suffer from the zero-count problem. In this section, we show how a 
Bayesian approach can solve this problem. 

Let us put a separate Dirichlet prior on each row of each CPT, i.e., $\boldsymbol{\theta}_{ij} \sim \text{Dir}(\boldsymbol{\alpha}_{ij})$. Then we can compute the posterior by simply adding the pseudo counts to the empirical counts to get $\boldsymbol{\theta}_{ij}|\mathcal{D} \sim \text{Dir}(\boldsymbol{N}_{ij} + \boldsymbol{\alpha}_{ij})$, where $\boldsymbol{N}_{ij} = \{N_{ijk} : k = 1:K_i\}$, and $N_{ijk}$ is the number of times that node $i$ is in state $k$ while its parents are in state $j$. Hence the posterior mean estimate is given by


$$\bar{\theta}_{ijk} = \frac{N_{ijk} + \alpha_{ijk}}{\sum_{k'} (N_{ijk'} + \alpha_{ijk'})} \tag{4.48}$$

The MAP estimate has the same form, except we use $\alpha_{ijk} - 1$ instead of $\alpha_{ijk}$.

In Table 4.3, we illustrate this approach applied to the $G$ node in the student network, where we use a uniform Dirichlet prior, $\alpha_{ijk} = 1$.
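Continuing in the same spirit, this sketch computes the posterior mean in Equation (4.48) and the corresponding MAP estimate for a table of hypothetical counts. With a uniform prior, $\alpha_{ijk} = 1$, every posterior-mean entry is strictly positive, whereas the MAP estimate reduces to the MLE. The MAP helper assumes every $\alpha_{ijk} \geq 1$, so that no numerator is negative.

```python
def posterior_mean_cpt(N, alpha):
    """Equation (4.48): (N_ijk + alpha_ijk) / sum_k' (N_ijk' + alpha_ijk'), row by row."""
    out = []
    for row, a_row in zip(N, alpha):
        total = sum(n + a for n, a in zip(row, a_row))
        out.append([(n + a) / total for n, a in zip(row, a_row)])
    return out

def map_cpt(N, alpha):
    """MAP estimate: same form with alpha_ijk - 1; assumes every alpha_ijk >= 1."""
    out = []
    for row, a_row in zip(N, alpha):
        total = sum(n + a - 1 for n, a in zip(row, a_row))
        out.append([(n + a - 1) / total for n, a in zip(row, a_row)] if total > 0 else None)
    return out

# Hypothetical counts N[j][k] for a node with 4 parent states and 3 child states.
N = [[1, 1, 0], [0, 0, 1], [1, 0, 0], [0, 2, 1]]
alpha = [[1.0] * 3 for _ in N]      # uniform Dirichlet prior, alpha_ijk = 1
theta_bar = posterior_mean_cpt(N, alpha)
theta_map = map_cpt(N, alpha)       # with alpha_ijk = 1 this equals the MLE
print(theta_bar)
print(theta_map)
```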


4.2.7.4 Learning from incomplete data 


In Section 4.2.7.1, we explained that when we have complete data, the likelihood (and posterior) 
factorizes over CPDs, so we can estimate each CPD independently. Unfortunately, this is no longer 
the case when we have incomplete or missing data. To see this, consider Figure 4.10. The likelihood 
of the observed data can be written as follows: 


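$$p(\mathcal{D}|\boldsymbol{\theta}) = \sum_{\boldsymbol{z}_{1:N}} \prod_{n=1}^N p(\boldsymbol{z}_n|\boldsymbol{\theta}_z) p(\boldsymbol{x}_n|\boldsymbol{z}_n, \boldsymbol{\theta}_x) \tag{4.49}$$

$$= \prod_{n=1}^N \sum_{\boldsymbol{z}_n} p(\boldsymbol{z}_n|\boldsymbol{\theta}_z) p(\boldsymbol{x}_n|\boldsymbol{z}_n, \boldsymbol{\theta}_x) \tag{4.50}$$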


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 




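Thus the log likelihood is given by

$$\ell(\boldsymbol{\theta}) = \sum_{n=1}^N \log \sum_{\boldsymbol{z}_n} p(\boldsymbol{z}_n|\boldsymbol{\theta}_z) p(\boldsymbol{x}_n|\boldsymbol{z}_n, \boldsymbol{\theta}_x) \tag{4.51}$$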


The log function does not distribute over the $\sum_{\boldsymbol{z}_n}$ operation, so the objective does not decompose over nodes.² Consequently, we can no longer compute the MLE or the posterior by solving separate problems per node.

To solve this, we will resort to optimization methods. (We focus on the MLE case, and leave 
discussion of Bayesian inference for latent variable models to Part II.) In the sections below, we 
discuss how to use EM and SGD to find a local optimum of the (non-convex) log likelihood objective.
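Although Equation (4.51) does not decompose, it is easy to evaluate. The following sketch assumes a hypothetical mixture of two Bernoulli-product components, so $z_n$ is a single discrete variable and $\boldsymbol{x}_n$ is a binary vector, and computes $\ell(\boldsymbol{\theta})$ with the log-sum-exp trick for numerical stability. The mixing weights `pi` (standing in for $\boldsymbol{\theta}_z$) and the component parameters `mu` (standing in for $\boldsymbol{\theta}_x$) appear together inside each logarithm, which is why the per-node decomposition is lost.

```python
import math

def logsumexp(vals):
    """Numerically stable log(sum(exp(v) for v in vals))."""
    top = max(vals)
    return top + math.log(sum(math.exp(v - top) for v in vals))

def log_joint(x, z, pi, mu):
    """log p(z | theta_z) + log p(x | z, theta_x) for a mixture of Bernoulli products."""
    lp = math.log(pi[z])
    for xd, p in zip(x, mu[z]):
        lp += math.log(p) if xd == 1 else math.log(1.0 - p)
    return lp

def log_likelihood(X, pi, mu):
    """Equation (4.51): sum over cases of the log of a sum over the hidden z_n."""
    return sum(logsumexp([log_joint(x, z, pi, mu) for z in range(len(pi))]) for x in X)

pi = [0.3, 0.7]                 # hypothetical theta_z
mu = [[0.9, 0.2], [0.1, 0.6]]   # hypothetical theta_x: p(x_d = 1 | z)
X = [(1, 0), (0, 1), (1, 1)]
print(log_likelihood(X, pi, mu))
```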


4.2.7.5 Using EM to fit CPTs in the incomplete data case 


A popular method for estimating the parameters of a DPGM in the presence of missing data is to use the expectation maximization (EM) algorithm, as proposed in [Lau95]. We describe EM in detail in Section 6.6.3, but the basic idea is to alternate between inferring the latent variables $\boldsymbol{z}_n$ (the E or expectation step), and estimating the parameters given this completed dataset (the M or maximization step). Rather than returning the full posterior $p(\boldsymbol{z}_n|\boldsymbol{x}_n, \boldsymbol{\theta}^{(t)})$ in the E step, we instead return the expected sufficient statistics (ESS), which take much less space. In the M step, we maximize the expected value of the log likelihood of the fully observed data using these ESS.


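As an example, suppose all the CPDs are tabular, as in the example in Section 4.2.7.2. The log-likelihood of the complete data is given by

$$\log p(\mathcal{D}|\boldsymbol{\theta}) = \sum_{i=1}^{N_G} \sum_{j=1}^{J_i} \sum_{k=1}^{K_i} N_{ijk} \log \theta_{ijk} \tag{4.52}$$

and hence the expected complete data log-likelihood has the form

$$\mathbb{E}\left[\log p(\mathcal{D}|\boldsymbol{\theta})\right] = \sum_{i} \sum_{j} \sum_{k} \overline{N}_{ijk} \log \theta_{ijk} \tag{4.53}$$

where

$$\overline{N}_{ijk} = \sum_{n=1}^N \mathbb{E}\left[\mathbb{I}(x_{ni} = k, \boldsymbol{x}_{n,\text{pa}(i)} = j)\right] = \sum_{n=1}^N p(x_{ni} = k, \boldsymbol{x}_{n,\text{pa}(i)} = j|\mathcal{D}_n, \boldsymbol{\theta}^{\text{old}}) \tag{4.54}$$

where $\mathcal{D}_n$ are all the visible variables in case $n$, and $\boldsymbol{\theta}^{\text{old}}$ are the parameters from the previous iteration.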


The quantity $p(x_{ni}, \boldsymbol{x}_{n,\text{pa}(i)}|\mathcal{D}_n, \boldsymbol{\theta}^{\text{old}})$ is known as a family marginal, and can be computed using any GM inference algorithm. The $\overline{N}_{ijk}$ are the expected sufficient statistics (ESS), and constitute the output of the E step.


2. We can also see this from the graphical model: $\boldsymbol{\theta}_x$ is no longer independent of $\boldsymbol{\theta}_z$, because there is a path that connects them via the hidden nodes $\boldsymbol{z}_n$. (See Section 4.2.4 for an explanation of how to “read off” such CI properties from a DPGM.)
Given these ESS, the M step has the simple form

$$\hat{\theta}_{ijk} = \frac{\overline{N}_{ijk}}{\sum_{k'} \overline{N}_{ijk'}} \tag{4.55}$$

We can modify this to perform MAP estimation with a Dirichlet prior by simply adding pseudo counts to the expected counts.

The famous Baum-Welch algorithm is a special case of the above equations which arises when the DPGM is an HMM (Section 29.4.1).
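The following plain-Python sketch applies these E and M steps to a hypothetical latent class model: a single hidden node $z_n \in \{0, \ldots, K-1\}$ with binary children $x_{nd}$, i.e., the structure of Figure 4.10 with tabular CPDs. The E step computes the family marginals $p(z_n = k|\boldsymbol{x}_n, \boldsymbol{\theta}^{\text{old}})$ and accumulates the ESS; the M step normalizes them as in Equation (4.55). The `pseudo` argument adds pseudo counts (the MAP variant mentioned above, equivalent to Dirichlet/Beta priors with parameters `pseudo + 1`); we use it so that no estimate hits exactly 0 or 1. The function, its arguments, and the random initialization are assumptions made for illustration.

```python
import math
import random

def log_bern(x, p):
    """log p(x) for a Bernoulli variable with p(x = 1) = p."""
    return math.log(p) if x == 1 else math.log(1.0 - p)

def em_latent_class(X, K, n_iter=30, pseudo=1.0, seed=0):
    """EM (MAP variant) for a hypothetical latent class model with tabular CPDs.

    Model: z_n ~ Cat(pi), x_nd | z_n = k ~ Ber(mu[k][d]).
    Returns (pi, mu, history), where history holds log p(X|theta) + log p(theta)
    (up to a constant) at the parameters used in each E step.
    """
    rng = random.Random(seed)
    D = len(X[0])
    pi = [1.0 / K] * K
    mu = [[rng.uniform(0.25, 0.75) for _ in range(D)] for _ in range(K)]
    history = []
    for _ in range(n_iter):
        # E step: family marginals p(z_n = k | x_n, theta_old), accumulated into ESS.
        Nz = [pseudo] * K
        N1 = [[pseudo] * D for _ in range(K)]  # expected count of (z = k, x_d = 1)
        N0 = [[pseudo] * D for _ in range(K)]  # expected count of (z = k, x_d = 0)
        loglik = 0.0
        for x in X:
            logp = [math.log(pi[k]) + sum(log_bern(xd, mu[k][d]) for d, xd in enumerate(x))
                    for k in range(K)]
            top = max(logp)
            log_px = top + math.log(sum(math.exp(v - top) for v in logp))
            loglik += log_px
            for k in range(K):
                r = math.exp(logp[k] - log_px)
                Nz[k] += r
                for d, xd in enumerate(x):
                    if xd == 1:
                        N1[k][d] += r
                    else:
                        N0[k][d] += r
        log_prior = pseudo * (sum(math.log(p) for p in pi)
                              + sum(math.log(q) + math.log(1.0 - q) for row in mu for q in row))
        history.append(loglik + log_prior)
        # M step: normalized expected counts, as in Equation (4.55).
        total = sum(Nz)
        pi = [n / total for n in Nz]
        mu = [[N1[k][d] / (N1[k][d] + N0[k][d]) for d in range(D)] for k in range(K)]
    return pi, mu, history

X = [(1, 1, 0, 0), (1, 1, 0, 0), (1, 0, 0, 0), (0, 0, 1, 1), (0, 0, 1, 1), (0, 1, 1, 1)]
pi, mu, history = em_latent_class(X, K=2)
print(pi, history[0], history[-1])
```

Because this is EM for MAP estimation, the recorded objective (log likelihood plus log prior) should never decrease from one iteration to the next.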


4.2.7.6 Using SGD to fit CPTs in the incomplete data case 


The EM algorithm is a batch algorithm. To scale up to large datasets, it is more common to use 
stochastic gradient descent or SGD (see e.g., [BC94; Bin+97]). To apply this, we need to compute 
the marginal likelihood of the observed data for each example: 


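$$p(\boldsymbol{x}_n|\boldsymbol{\theta}) = \sum_{\boldsymbol{z}_n} p(\boldsymbol{z}_n|\boldsymbol{\theta}_z) p(\boldsymbol{x}_n|\boldsymbol{z}_n, \boldsymbol{\theta}_x) \tag{4.56}$$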


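where $\boldsymbol{\theta} = (\boldsymbol{\theta}_z, \boldsymbol{\theta}_x)$. (We say that we have “collapsed” the model by marginalizing out $\boldsymbol{z}_n$.) We can then compute the log likelihood using

$$\ell(\boldsymbol{\theta}) = \log p(\mathcal{D}|\boldsymbol{\theta}) = \log \prod_{n=1}^N p(\boldsymbol{x}_n|\boldsymbol{\theta}) = \sum_{n=1}^N \log p(\boldsymbol{x}_n|\boldsymbol{\theta}) \tag{4.57}$$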


The gradient of this objective can be computed as follows: 


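$$\nabla_{\boldsymbol{\theta}} \ell(\boldsymbol{\theta}) = \sum_{n} \nabla_{\boldsymbol{\theta}} \log p(\boldsymbol{x}_n|\boldsymbol{\theta}) \tag{4.58}$$

$$= \sum_{n} \frac{1}{p(\boldsymbol{x}_n|\boldsymbol{\theta})} \nabla_{\boldsymbol{\theta}} p(\boldsymbol{x}_n|\boldsymbol{\theta}) \tag{4.59}$$

$$= \sum_{n} \frac{1}{p(\boldsymbol{x}_n|\boldsymbol{\theta})} \nabla_{\boldsymbol{\theta}} \left[\sum_{\boldsymbol{z}_n} p(\boldsymbol{z}_n, \boldsymbol{x}_n|\boldsymbol{\theta})\right] \tag{4.60}$$

$$= \sum_{n} \sum_{\boldsymbol{z}_n} \frac{p(\boldsymbol{z}_n, \boldsymbol{x}_n|\boldsymbol{\theta})}{p(\boldsymbol{x}_n|\boldsymbol{\theta})} \nabla_{\boldsymbol{\theta}} \log p(\boldsymbol{z}_n, \boldsymbol{x}_n|\boldsymbol{\theta}) \tag{4.61}$$

$$= \sum_{n} \sum_{\boldsymbol{z}_n} p(\boldsymbol{z}_n|\boldsymbol{x}_n, \boldsymbol{\theta}) \nabla_{\boldsymbol{\theta}} \log p(\boldsymbol{z}_n, \boldsymbol{x}_n|\boldsymbol{\theta}) \tag{4.62}$$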


We can now apply a minibatch approximation to this in the usual way. 
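As a numerical sanity check of Equation (4.62), the sketch below uses a hypothetical model with one binary latent variable $z$ and one binary observation $x$, parameterized by logits $\boldsymbol{\theta} = (a, b_0, b_1)$ with $p(z = 1) = \sigma(a)$ and $p(x = 1|z) = \sigma(b_z)$. It computes the gradient of $\log p(x|\boldsymbol{\theta})$ as the posterior-weighted complete-data gradient, compares it with central finite differences, and takes one minibatch gradient-ascent step on the average log likelihood (equivalently, a descent step on the negative log likelihood).

```python
import math

def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))

def log_joint(z, x, theta):
    """log p(z, x | theta) with p(z = 1) = sigmoid(a) and p(x = 1 | z) = sigmoid(b_z)."""
    a, b0, b1 = theta
    pz = sigmoid(a) if z == 1 else 1.0 - sigmoid(a)
    px1 = sigmoid(b1) if z == 1 else sigmoid(b0)
    return math.log(pz) + math.log(px1 if x == 1 else 1.0 - px1)

def log_marginal(x, theta):
    """Equation (4.56) in log space: log sum_z p(z, x | theta)."""
    return math.log(sum(math.exp(log_joint(z, x, theta)) for z in (0, 1)))

def grad_log_joint(z, x, theta):
    """Gradient of log p(z, x | theta) with respect to (a, b0, b1)."""
    a, b0, b1 = theta
    return [z - sigmoid(a),
            (x - sigmoid(b0)) if z == 0 else 0.0,
            (x - sigmoid(b1)) if z == 1 else 0.0]

def grad_log_marginal(x, theta):
    """Equation (4.62): sum_z p(z | x, theta) * grad log p(z, x | theta)."""
    log_px = log_marginal(x, theta)
    grad = [0.0, 0.0, 0.0]
    for z in (0, 1):
        post = math.exp(log_joint(z, x, theta) - log_px)
        for i, g in enumerate(grad_log_joint(z, x, theta)):
            grad[i] += post * g
    return grad

def finite_diff(x, theta, eps=1e-6):
    """Central-difference approximation to the gradient of log p(x | theta)."""
    grad = []
    for i in range(len(theta)):
        up, down = list(theta), list(theta)
        up[i] += eps
        down[i] -= eps
        grad.append((log_marginal(x, up) - log_marginal(x, down)) / (2 * eps))
    return grad

def sgd_step(batch, theta, lr=0.01):
    """One minibatch gradient-ascent step on the average log likelihood."""
    g = [0.0] * len(theta)
    for x in batch:
        for i, gi in enumerate(grad_log_marginal(x, theta)):
            g[i] += gi / len(batch)
    return [t + lr * gi for t, gi in zip(theta, g)]

theta = [0.3, -0.5, 1.2]
print(grad_log_marginal(1, theta), finite_diff(1, theta))
theta = sgd_step([1, 1, 0, 1], theta)
```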


4.2.8 Plate notation 


To make the parameters of a PGM explicit, we can add them as nodes to the graph, and treat them 
as hidden variables to be inferred. Figure 4.11(a) shows a simple example, in which we have N iid 


Figure 4.11: Left: data points $\boldsymbol{x}_n$ are conditionally independent given $\boldsymbol{\theta}$. Right: Same model, using plate notation. This represents the same model as the one on the left, except the repeated $\boldsymbol{x}_n$ nodes are inside a box, known as a plate; the number in the lower right hand corner, $N$, specifies the number of repetitions of the $\boldsymbol{x}_n$ node.


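random variables, $\boldsymbol{x}_n$, all drawn from the same distribution with common parameter $\boldsymbol{\theta}$. We denote this by

$$\boldsymbol{x}_n \sim p(\boldsymbol{x}|\boldsymbol{\theta}) \tag{4.63}$$

The corresponding joint distribution over the parameters and data $\mathcal{D} = \{\boldsymbol{x}_1, \ldots, \boldsymbol{x}_N\}$ has the form

$$p(\mathcal{D}, \boldsymbol{\theta}) = p(\boldsymbol{\theta}) p(\mathcal{D}|\boldsymbol{\theta}) \tag{4.64}$$

where $p(\boldsymbol{\theta})$ is the prior distribution for the parameters, and $p(\mathcal{D}|\boldsymbol{\theta})$ is the likelihood. By virtue of the iid assumption, the likelihood can be rewritten as follows:

$$p(\mathcal{D}|\boldsymbol{\theta}) = \prod_{n=1}^N p(\boldsymbol{x}_n|\boldsymbol{\theta}) \tag{4.65}$$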


Notice that the order of the data vectors is not important for defining this model, i.e., we can permute the leaves of the DPGM. When this property holds, we say that the data is exchangeable.

In Figure 4.11(a), we see that the $\boldsymbol{x}$ nodes are repeated $N$ times. (The shaded nodes represent observed values, whereas the unshaded (hollow) nodes represent latent variables or parameters.) To avoid visual clutter, it is common to use a form of syntactic sugar called plates. This is a notational convention in which we draw a little box around the repeated variables, with the understanding that nodes within the box will get repeated when the model is unrolled. We often write the number of copies or repetitions in the bottom right corner of the box. This is illustrated in Figure 4.11(b).


Figure 4.12: (a) Factor analysis model illustrated as a DPGM. We show the components of z (top row) and x 
(bottom row) as individual scalar nodes. (b) Equivalent model, where z and x are collapsed to vector-valued 
nodes, and parameters are added, using plate notation. 


4.2.8.1 Example: factor analysis 


In Section 28.3.1, we discuss the factor analysis model, which has the form 


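$$p(\boldsymbol{z}) = \mathcal{N}(\boldsymbol{z}|\boldsymbol{\mu}_0, \boldsymbol{\Sigma}_0) \tag{4.66}$$

$$p(\boldsymbol{x}|\boldsymbol{z}) = \mathcal{N}(\boldsymbol{x}|\mathbf{W}\boldsymbol{z} + \boldsymbol{\mu}, \boldsymbol{\Psi}) \tag{4.67}$$

where $\mathbf{W}$ is a $D \times L$ matrix, known as the factor loading matrix, and $\boldsymbol{\Psi}$ is a diagonal $D \times D$ covariance matrix.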

Note that $\boldsymbol{z}$ and $\boldsymbol{x}$ are both vectors. We can explicitly represent their components as scalar nodes as in Figure 4.12a. Here the directed edges correspond to non-zero entries in the $\mathbf{W}$ matrix.

We can also explicitly show the parameters of the model, using plate notation, as shown in 
Figure 4.12b. 
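Here is a minimal ancestral-sampling sketch of Equations (4.66) and (4.67). For simplicity it assumes that $\boldsymbol{\Sigma}_0$ and $\boldsymbol{\Psi}$ are diagonal (passed as lists of variances), so that only `random.gauss` is needed; the function name and the toy parameter values are hypothetical. We first draw $\boldsymbol{z}$ from its prior, then draw $\boldsymbol{x}$ given $\boldsymbol{z}$, following the direction of the edges in Figure 4.12. The empirical mean of the $\boldsymbol{x}$ samples should approach $\mathbf{W}\boldsymbol{\mu}_0 + \boldsymbol{\mu}$.

```python
import random

def sample_fa(W, mu0, sigma0_diag, mu, psi_diag, n, seed=0):
    """Ancestral sampling from the factor analysis model (Eqs. 4.66-4.67).

    Assumes diagonal Sigma_0 and Psi, given as lists of variances; W is a
    D x L list of lists. Returns a list of (z, x) pairs.
    """
    rng = random.Random(seed)
    D, L = len(W), len(W[0])
    samples = []
    for _ in range(n):
        # z ~ N(mu0, Sigma_0)
        z = [rng.gauss(mu0[l], sigma0_diag[l] ** 0.5) for l in range(L)]
        # x | z ~ N(W z + mu, Psi)
        x = [sum(W[d][l] * z[l] for l in range(L)) + mu[d] + rng.gauss(0.0, psi_diag[d] ** 0.5)
             for d in range(D)]
        samples.append((z, x))
    return samples

W = [[1.0], [2.0], [-1.0]]                    # hypothetical 3 x 1 loading matrix
samples = sample_fa(W, mu0=[0.0], sigma0_diag=[1.0], mu=[0.5, -1.0, 2.0],
                    psi_diag=[0.1, 0.2, 0.3], n=20000)
mean_x = [sum(x[d] for _, x in samples) / len(samples) for d in range(3)]
print(mean_x)
```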


4.2.8.2 Example: Naive Bayes classifier 


In some models, we have doubly indexed variables. For example, consider a naive Bayes classifier. 
This is a simple generative classifier, defined as follows: 


$$p(\boldsymbol{x}, y|\boldsymbol{\theta}) = p(y|\boldsymbol{\pi}) \prod_{d=1}^D p(x_d|y, \boldsymbol{\theta}_d) \tag{4.68}$$

The fact that the features $x_{1:D}$ are considered conditionally independent given the class label $y$ is where the term “naive” comes from. Nevertheless, this model often works surprisingly well, and is extremely easy to fit.
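Fitting Equation (4.68) reduces to counting, as in Section 4.2.7.2. The sketch below assumes binary features and a categorical class label, fits a Bernoulli naive Bayes model on hypothetical data, and computes $p(y|\boldsymbol{x})$ with Bayes rule. The add-$\alpha$ pseudo count ($\alpha = 1$ by default, i.e., a uniform Dirichlet/Beta prior as in Section 4.2.7.3) is our choice to avoid zero probabilities; $\alpha = 0$ recovers the MLE.

```python
import math

def fit_nb(X, y, C, alpha=1.0):
    """Fit a Bernoulli naive Bayes classifier by (smoothed) counting.

    pi[c] = p(y = c); theta[c][d] = p(x_d = 1 | y = c).
    """
    D = len(X[0])
    counts = [0] * C
    ones = [[0] * D for _ in range(C)]
    for x, c in zip(X, y):
        counts[c] += 1
        for d, xd in enumerate(x):
            ones[c][d] += xd
    N = len(y)
    pi = [(counts[c] + alpha) / (N + C * alpha) for c in range(C)]
    theta = [[(ones[c][d] + alpha) / (counts[c] + 2 * alpha) for d in range(D)]
             for c in range(C)]
    return pi, theta

def predict_proba(x, pi, theta):
    """p(y | x) via Bayes rule, using the factorization in Equation (4.68)."""
    logp = []
    for c in range(len(pi)):
        lp = math.log(pi[c])
        for d, xd in enumerate(x):
            lp += math.log(theta[c][d] if xd == 1 else 1.0 - theta[c][d])
        logp.append(lp)
    top = max(logp)
    w = [math.exp(v - top) for v in logp]
    return [v / sum(w) for v in w]

# Hypothetical training data: 3 binary features, 2 classes.
X = [(1, 1, 0), (1, 0, 0), (1, 1, 1), (0, 0, 1), (0, 1, 1), (0, 0, 0)]
y = [0, 0, 0, 1, 1, 1]
pi, theta = fit_nb(X, y, C=2)
print(pi, theta, predict_proba((1, 1, 0), pi, theta))
```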

We can represent the conditional independence assumption as shown in Figure 4.13a. We can represent the repetition over the dimension $d$ with a plate. When we turn to inferring the parameters


Figure 4.13: (a) Naive Bayes classifier as a DPGM. (b) Model augmented with plate notation. 




Figure 4.14: Tree-augmented naive Bayes classifier for $D = 4$ features. The tree topology can change depending on the value of $y$, as illustrated.


$\boldsymbol{\theta} = (\boldsymbol{\pi}, \boldsymbol{\theta}_{1:D,1:C})$, we also need to represent the repetition over data cases $n$. This is shown in Figure 4.13b. Note that the parameter $\boldsymbol{\theta}_{dc}$ depends on $d$ and $c$, whereas the feature $x_{nd}$ depends on $n$ and $d$. This is shown using nested plates to represent the shared $d$ index.


4.2.8.3 Example: relaxing the naive Bayes assumption


We see from Figure 4.13a that the observed features are conditionally independent given the class 
label. We can of course allow for dependencies between the features, as illustrated in Figure 4.14. 


(We omit parameter nodes for simplicity.) If we enforce that the edges between the features form a tree, the model is known as a tree-augmented naive Bayes classifier [FGG97], or TAN model. (Trees are a restricted form of graphical model that have various computational advantages that we discuss later.) Note that the topology of the tree can change depending on the value of the class node $y$; in this case, the model is known as a Bayesian multinet, and can be thought of as a supervised mixture of trees.


Figure 4.15: (a) A 2d lattice represented as a DAG. The dotted red node $X_8$ is independent of all other nodes (black) given its Markov blanket, which includes its parents (blue), children (green) and co-parents (orange). (b) The same model represented as a UPGM. The red node $X_8$ is independent of the other black nodes given its neighbors (blue nodes).


4.3 Undirected graphical models (Markov random fields) 


Directed graphical models (Section 4.2) are very useful. However, for some domains, being forced to 
choose a direction for the edges, as required by a DAG, is rather awkward. For example, consider 
modeling an image. It is reasonable to assume that the intensity values of neighboring pixels are 
correlated. We can model this using a DAG with a 2d lattice topology as shown in Figure 4.15(a). 
This is known as a Markov mesh [AHK65]. However, its conditional independence properties are 
rather unnatural. 

An alternative is to use an undirected probabilistic graphical model (UPGM), also called a 
Markov random field (MRF) or Markov network. These do not require us to specify edge 
orientations, and are much more natural for some problems such as image analysis and spatial 
statistics. For example, an undirected 2d lattice is shown in Figure 4.15(b); now the Markov blanket 
of each node is just its nearest neighbors, as we show in Section 4.3.6. 

Roughly speaking, the main advantages of UPGMs over DPGMs are: (1) they are symmetric and 
therefore more “natural” for certain domains, such as spatial or relational data; and (2) discriminative 
UPGMs (aka conditional random fields, or CRFs), which define conditional densities of the form $p(\boldsymbol{y}|\boldsymbol{x})$, work better than discriminative DGMs, for reasons we explain in Section 4.5.3. The main
disadvantages of UPGMs compared to DPGMs are: (1) the parameters are less interpretable and 
less modular, for reasons we explain in Section 4.3.1; and (2) it is more computationally expensive to 
estimate the parameters, for reasons we explain in Section 4.3.9.1. 


4.3.1 Representing the joint distribution 


Since there is no topological ordering associated with an undirected graph, we can’t use the chain rule to represent $p(\boldsymbol{x}_{1:N_G})$. So instead of associating CPDs with each node, we associate potential functions or factors with each maximal clique in the graph.³ We will denote the potential


3. A clique is a set of nodes that are all neighbors of each other. A maximal clique is a clique which cannot be 
made any larger without losing the clique property. 
function for clique $c$ by $\psi_c(\boldsymbol{x}_c; \boldsymbol{\theta}_c)$, where $\boldsymbol{\theta}_c$ are its parameters. A potential function can be any
non-negative function of its arguments (we give some examples below). We can use these functions 
to define the joint distribution as we explain in Section 4.3.1.1. 


4.3.1.1 Hammersley-Clifford theorem 


Suppose a joint distribution p satisfies the CI properties implied by the undirected graph G. (We 
discuss how to derive these properties in Section 4.3.6.) Then the Hammersley-Clifford theorem 
tells us that p can be written as follows: 


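$$p(\boldsymbol{x}|\boldsymbol{\theta}) = \frac{1}{Z(\boldsymbol{\theta})} \prod_{c \in \mathcal{C}} \psi_c(\boldsymbol{x}_c; \boldsymbol{\theta}_c) \tag{4.69}$$

where $\mathcal{C}$ is the set of all the (maximal) cliques of the graph $G$, and $Z(\boldsymbol{\theta})$ is the partition function given by

$$Z(\boldsymbol{\theta}) = \sum_{\boldsymbol{x}} \prod_{c \in \mathcal{C}} \psi_c(\boldsymbol{x}_c; \boldsymbol{\theta}_c) \tag{4.70}$$

Note that the partition function is what ensures the overall distribution sums to 1.⁴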

The Hammersley-Clifford theorem was never published, but a proof can be found in [KF09a]. 
(Note that the theorem only holds for positive distributions, i.e., ones where $p(\boldsymbol{x}|\boldsymbol{\theta}) > 0$ for all configurations $\boldsymbol{x}$, which rules out some models with hard constraints.)
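For a very small model, Equations (4.69) and (4.70) can be checked by brute-force enumeration. The sketch below uses a hypothetical chain of three binary variables with edges $(x_1, x_2)$ and $(x_2, x_3)$, so the maximal cliques are $\{x_1, x_2\}$ and $\{x_2, x_3\}$, with made-up potential tables. Its cost grows exponentially with the number of variables, so this is only a check, not a practical way to compute $Z$.

```python
import itertools

# Hypothetical potential tables; any non-negative values are allowed.
psi_12 = {(0, 0): 3.0, (0, 1): 1.0, (1, 0): 1.0, (1, 1): 3.0}
psi_23 = {(0, 0): 2.0, (0, 1): 0.5, (1, 0): 0.5, (1, 1): 2.0}

def unnormalized(x):
    """Product of clique potentials for one joint configuration x = (x1, x2, x3)."""
    x1, x2, x3 = x
    return psi_12[(x1, x2)] * psi_23[(x2, x3)]

states = list(itertools.product((0, 1), repeat=3))
Z = sum(unnormalized(x) for x in states)        # Equation (4.70), by enumeration
p = {x: unnormalized(x) / Z for x in states}    # Equation (4.69)
print(Z, sum(p.values()))
```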


4.3.1.2 Gibbs distribution 


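The distribution in Equation (4.69) can be rewritten as follows:

$$p(\boldsymbol{x}|\boldsymbol{\theta}) = \frac{1}{Z(\boldsymbol{\theta})} \exp(-\mathcal{E}(\boldsymbol{x}; \boldsymbol{\theta})) \tag{4.71}$$

where $\mathcal{E}(\boldsymbol{x}; \boldsymbol{\theta})$ is the energy of state $\boldsymbol{x}$, defined by

$$\mathcal{E}(\boldsymbol{x}; \boldsymbol{\theta}) = \sum_{c} \mathcal{E}(\boldsymbol{x}_c; \boldsymbol{\theta}_c) \tag{4.72}$$

where $\boldsymbol{x}_c$ are the variables in clique $c$. We can see the equivalence by defining the clique potentials as

$$\psi_c(\boldsymbol{x}_c; \boldsymbol{\theta}_c) = \exp(-\mathcal{E}(\boldsymbol{x}_c; \boldsymbol{\theta}_c)) \tag{4.73}$$

We see that low energy is associated with high probability states.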


Equation (4.71) is known as the Gibbs distribution. This kind of probability model is also called an energy-based model. These are commonly used in physics and biochemistry. They are also used in ML to define generative models, as we discuss in Chapter 24. (See also Section 4.4, where we discuss conditional random fields (CRFs), which are models of the form $p(\boldsymbol{y}|\boldsymbol{x}, \boldsymbol{\theta})$, where the potential functions are conditioned on input features $\boldsymbol{x}$.)


4. The partition function is denoted by $Z$ because of the German word Zustandssumme, which means “sum over states”. This reflects the fact that a lot of pioneering work on MRFs was done by German (and Austrian) physicists, such as Boltzmann.


4.3.2 Fully visible MRFs (Ising, Potts, Hopfield, etc) 


In this section, we discuss some UPGMs for 2d grids, that are used in statistical physics and computer 
vision. We then discuss extensions to other graph structures, which are useful for biological modeling 
and pattern completion. 


4.3.2.1 Ising models 


Consider the 2d lattice in Figure 4.15(b). We can represent the joint distribution as follows: 


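$$p(\boldsymbol{x}|\boldsymbol{\theta}) = \frac{1}{Z(\boldsymbol{\theta})} \prod_{i \sim j} \psi_{ij}(x_i, x_j; \boldsymbol{\theta}) \tag{4.74}$$

where $i \sim j$ means $i$ and $j$ are neighbors in the graph. This is called a 2d lattice model.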
An Ising model is a special case of the above, where the variables $x_i$ are binary. Such models are often used to represent magnetic materials. In particular, each node represents an atom, which can have a magnetic dipole, or spin, which is in one of two states, $+1$ and $-1$. In some magnetic systems, neighboring spins like to be similar; in other systems, they like to be dissimilar. We can capture this interaction by defining the clique potentials as follows:

$$\psi_{ij}(x_i, x_j; \boldsymbol{\theta}) = \begin{cases} e^{J_{ij}} & \text{if } x_i = x_j \\ e^{-J_{ij}} & \text{if } x_i \neq x_j \end{cases} \tag{4.75}$$

where $J_{ij}$ is the coupling strength between nodes $i$ and $j$. This is known as the Ising model. If two nodes are not connected in the graph, we set $J_{ij} = 0$. We assume that the weight matrix is symmetric, so $J_{ij} = J_{ji}$. Often we also assume all edges have the same strength, so $J_{ij} = J$ for each $(i, j)$ edge. Thus


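$$\psi_{ij}(x_i, x_j; J) = \begin{cases} e^{J} & \text{if } x_i = x_j \\ e^{-J} & \text{if } x_i \neq x_j \end{cases} \tag{4.76}$$

It is more common to define the Ising model as an energy-based model, as follows:

$$p(\boldsymbol{x}|J) = \frac{1}{Z(J)} \exp(-\mathcal{E}(\boldsymbol{x}; J)) \tag{4.77}$$

$$\mathcal{E}(\boldsymbol{x}; J) = -J \sum_{i \sim j} x_i x_j \tag{4.78}$$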


where $\mathcal{E}(\boldsymbol{x}; J)$ is the energy, and where we exploited the fact that $x_i x_j = -1$ if $x_i \neq x_j$, and $x_i x_j = +1$ if $x_i = x_j$. The magnitude of $J$ controls the degree of coupling strength between neighboring sites, which depends on the (inverse) temperature of the system (colder = more tightly coupled = larger magnitude $J$).
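As a small check of Equation (4.78), and of the claim below that for $J > 0$ the most probable states are those in which neighboring spins agree, the sketch enumerates every configuration of a hypothetical $3 \times 3$ lattice with free boundaries, computes the energies and the partition function by brute force, and collects the minimum-energy (ground) states.

```python
import itertools
import math

def grid_edges(rows, cols):
    """Nearest-neighbour edges i ~ j of a rows x cols lattice (free boundaries)."""
    def idx(r, c):
        return r * cols + c
    edges = []
    for r in range(rows):
        for c in range(cols):
            if c + 1 < cols:
                edges.append((idx(r, c), idx(r, c + 1)))
            if r + 1 < rows:
                edges.append((idx(r, c), idx(r + 1, c)))
    return edges

def ising_energy(x, edges, J):
    """Equation (4.78): E(x; J) = -J * sum_{i~j} x_i x_j, with x_i in {-1, +1}."""
    return -J * sum(x[i] * x[j] for i, j in edges)

rows, cols, J = 3, 3, 1.0
edges = grid_edges(rows, cols)
states = list(itertools.product((-1, 1), repeat=rows * cols))
energies = [ising_energy(x, edges, J) for x in states]
Z = sum(math.exp(-e) for e in energies)          # partition function, by brute force
e_min = min(energies)
ground_states = [x for x, e in zip(states, energies) if e == e_min]
print(len(edges), e_min, ground_states, math.exp(-e_min) / Z)
```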

If all the edge weights are positive, $J > 0$, then neighboring spins are likely to be in the same state, since if $x_i = x_j$, the energy term gets a contribution of $-J < 0$, and lower energy corresponds to higher probability. In the machine learning literature, this is called an associative Markov network. In the physics literature, this is called a ferromagnetic model. If the weights are


Figure 4.16: (a) The two ground states for a small ferromagnetic Ising model where $J = 1$. (b) Two different states for a small Ising model which have the same energy. Left: $J = 1$, so neighboring pixels have similar values. Right: $J = -1$, so neighboring pixels have different values. From Figures 31.7 and 31.8 of [Mac03].


Figure 4.17: Samples from an associative Ising model with varying $J > 0$. Generated by gibbs_demo_ising.ipynb.


sufficiently strong, the corresponding probability distribution will have two modes, corresponding to the two configurations in Figure 4.16a in which all spins are aligned (all $+1$ or all $-1$). These are called the ground states of the system.

If all of the weights are negative, $J < 0$, then the spins want to be different from their neighbors (see Figure 4.16b). This is called an antiferromagnetic system. When it is not possible for all neighbors to be different from each other, as on lattices that contain odd cycles (such as a triangular lattice), the result is a frustrated system, and the corresponding probability distribution will have multiple modes, corresponding to different “solutions” to the problem.

Figure 4.17 shows some samples from the Ising model for varying J > 0. (The samples were 
created using the Gibbs sampling method discussed in Section 12.3.3.) As the temperature reduces, 
the distribution becomes less entropic, and the “clumpiness” of the samples increases. One can show 
that, as the lattice size goes to infinity, there is a critical (inverse) temperature $J_c$: when the system is colder than this ($J > J_c$), many large clusters occur, and when it is hotter ($J < J_c$), many small clusters occur. In the case of an isotropic square lattice model, one can show [Geo88] that

$$J_c = \frac{1}{2} \log(1 + \sqrt{2}) \approx 0.44 \tag{4.79}$$


This rapid change in global behavior as we vary a parameter of the system is called a phase transition. This can be used to explain how natural systems, such as water, can suddenly go from




Figure 4.18: Visualizing a sample from a 10-state Potts model of size 128 × 128 for different association strengths: (a) $J = 1.40$, (b) $J = 1.43$, (c) $J = 1.46$. The critical value is $J_c = \log(1 + \sqrt{10}) = 1.426$. Generated by gibbs_demo_potts.ipynb.


solid to liquid, or from liquid to gas, when the temperature changes slightly. See e.g., [Mac03, ch 31] 
for further details on the statistical mechanics of Ising models. 
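The samples in Figure 4.17 were created with Gibbs sampling (Section 12.3.3). A minimal single-site Gibbs sampler for the model in Equation (4.78) is sketched below; it is not the gibbs_demo_ising.ipynb notebook. We assume periodic (wrap-around) boundaries, and we use the fact that, for $\mathcal{E}(\boldsymbol{x}; J) = -J \sum_{i \sim j} x_i x_j$, the full conditional of spin $i$ depends only on the sum $s_i$ of its four neighbors: $p(x_i = +1|\boldsymbol{x}_{-i}) = \sigma(2 J s_i)$. The constant `J_CRITICAL` is Equation (4.79); running the sampler with $J$ below and above it is one way to observe the change in clumpiness described above.

```python
import math
import random

J_CRITICAL = 0.5 * math.log(1.0 + math.sqrt(2.0))   # Equation (4.79)

def gibbs_ising(n, J, n_sweeps, seed=0):
    """Single-site Gibbs sampling for an n x n Ising model with periodic boundaries."""
    rng = random.Random(seed)
    x = [[rng.choice((-1, 1)) for _ in range(n)] for _ in range(n)]
    for _ in range(n_sweeps):
        for r in range(n):
            for c in range(n):
                # Sum of the four neighbouring spins (wrap-around at the edges).
                s = (x[(r - 1) % n][c] + x[(r + 1) % n][c]
                     + x[r][(c - 1) % n] + x[r][(c + 1) % n])
                p_plus = 1.0 / (1.0 + math.exp(-2.0 * J * s))   # p(x_rc = +1 | rest)
                x[r][c] = 1 if rng.random() < p_plus else -1
    return x

def magnetization(x):
    """Absolute average spin, a simple summary of how ordered a sample is."""
    n = len(x)
    return abs(sum(sum(row) for row in x)) / (n * n)

sample = gibbs_ising(32, J=0.6, n_sweeps=50)
print(J_CRITICAL, magnetization(sample))
```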

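In addition to pairwise terms, it is standard to add unary terms, $\psi_i(x_i)$. In statistical physics, this is called an external field. The resulting model is as follows:

$$p(\boldsymbol{x}|\boldsymbol{\theta}) = \frac{1}{Z(\boldsymbol{\theta})} \prod_i \psi_i(x_i; \boldsymbol{\theta}) \prod_{i \sim j} \psi_{ij}(x_i, x_j; \boldsymbol{\theta}) \tag{4.80}$$

The $\psi_i$ terms can be thought of as local bias terms that are independent of the contributions of the neighboring nodes. For binary nodes, we can define this as follows:

$$\psi_i(x_i) = \begin{cases} e^{\alpha} & \text{if } x_i = +1 \\ e^{-\alpha} & \text{if } x_i = -1 \end{cases} \tag{4.81}$$

If we write this as an energy-based model, we have

$$\mathcal{E}(\boldsymbol{x}|\boldsymbol{\theta}) = -\alpha \sum_i x_i - J \sum_{i \sim j} x_i x_j \tag{4.82}$$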


4.3.2.2 Potts models 


In Section 4.3.2.1, we discussed the Ising model, which is a simple 2d MRF for defining distributions over binary variables. It is easy to generalize the Ising model to multiple discrete states, $x_i \in \{1, 2, \ldots, K\}$. If we use the same potential function for every edge, we can write

$$\psi_{ij}(x_i = k, x_j = k') = e^{J_{ij}(k, k')} \tag{4.83}$$

where $J_{ij}(k, k')$ is the energy if one node has state $k$ and its neighbor has state $k'$. A common special case is

$$\psi_{ij}(x_i = k, x_j = k') = \begin{cases} e^{J} & \text{if } k = k' \\ e^{0} & \text{if } k \neq k' \end{cases} \tag{4.84}$$
This is called the Potts model. The Potts model reduces to the Ising model if we define $J_{\text{Potts}} = 2 J_{\text{Ising}}$.

If $J > 0$, then neighboring nodes are encouraged to have the same label; this is an example of an associative Markov model. Some samples from this model are shown in Figure 4.18. The phase transition for a 2d Potts model occurs at the following value (see [MS96]):

$$J_c = \log(1 + \sqrt{K}) \tag{4.85}$$


We can extend this model to have local evidence for each node. If we write this as an energy-based model, we have

$$\mathcal{E}(\boldsymbol{x}|\boldsymbol{\theta}) = -\sum_i \sum_{k=1}^K \alpha_k \mathbb{I}(x_i = k) - J \sum_{i \sim j} \mathbb{I}(x_i = x_j) \tag{4.86}$$
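A short sketch evaluating Equations (4.85) and (4.86). It checks two numbers quoted in this section: for $K = 10$ states, $J_c = \log(1 + \sqrt{10}) \approx 1.426$ (the value in Figure 4.18), and for $K = 2$ the Potts critical value is exactly twice the Ising value in Equation (4.79), consistent with $J_{\text{Potts}} = 2 J_{\text{Ising}}$. States are coded $0, \ldots, K-1$ rather than $1, \ldots, K$, and the per-state bias list `alpha` is a hypothetical stand-in for the local evidence terms $\alpha_k$.

```python
import math

def potts_critical(K):
    """Equation (4.85): critical coupling of the 2d Potts model with K states."""
    return math.log(1.0 + math.sqrt(K))

def ising_critical():
    """Equation (4.79): critical coupling of the 2d square-lattice Ising model."""
    return 0.5 * math.log(1.0 + math.sqrt(2.0))

def potts_energy(x, edges, J, alpha):
    """Equation (4.86) with states 0..K-1: -sum_i alpha[x_i] - J * sum_{i~j} I(x_i = x_j)."""
    unary = sum(alpha[xi] for xi in x)
    pairwise = sum(1 for i, j in edges if x[i] == x[j])
    return -unary - J * pairwise

print(potts_critical(10), potts_critical(2), 2 * ising_critical())
print(potts_energy([0, 0, 1], [(0, 1), (1, 2)], J=1.0, alpha=[0.5, 0.0]))
```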


4.3.2.3 Potts models for protein structure prediction 


One interesting application of Potts models arises in the area of protein structure prediction. 
The goal is to predict the 3d shape of a protein from its 1d sequence of amino acids. A common 
approach to this is known as direct coupling analysis (DCA). We give a brief summary below; for 
details, see [Mor+11]. 

First we compute a multiple sequence alignment (MSA) from a set of related amino acid sequences from the same protein family; this can be done using HMMs, as explained in Section 29.3.2. The MSA can be represented by an $N \times T$ matrix $\mathbf{X}$, where $N$ is the number of sequences, $T$ is the length of each sequence, and $X_{ni} \in \{1, \ldots, V\}$ is the identity of the letter at location $i$ in sequence $n$. For protein sequences, $V = 21$, representing the 20 amino acids plus the gap character.

Once we have the MSA matrix X, we fit the Potts model using maximum likelihood estimation, or 


some approximation, such as pseudo likelihood [Eke+13]; see Section 4.3.9 for details.⁵ After fitting the model, we select the edges with the highest $J_{ij}$ coefficients, where $i, j \in \{1, \ldots, T\}$ are locations or residues in the protein. Since these locations are highly coupled, they are likely to be in physical contact, since interacting residues must coevolve to avoid destroying the function of the protein (see e.g., [LHF17] for a review). This graph is called a contact map.


Once the contact map is established, it can be used as input to a 3d structural prediction algorithm, such as [Xu18] or the alphafold system [Eva+18], which won the 2018 CASP competition. Such methods use neural networks to learn functions of the form $p(d(i, j)|\{c(i, j)\})$, where $d(i, j)$ is the 3d distance between residues $i$ and $j$, and $c(i, j)$ is the contact map.


4.3.2.4 Hopfield networks


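A Hopfield network [Hop82] is a fully connected Ising model (Section 4.3.2.1) with a symmetric weight matrix, $\mathbf{W} = \mathbf{W}^\top$. The corresponding energy function has the form

$$\mathcal{E}(\boldsymbol{x}) = -\frac{1}{2} \boldsymbol{x}^\top \mathbf{W} \boldsymbol{x} \tag{4.87}$$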


5. To encourage the model to learn sparse connectivity, we can also compute a MAP estimate with a sparsity-promoting prior, as discussed in [IM17].
where $x_i \in \{-1, +1\}$.

The main application of Hopfield networks is as an associative memory or content addressable 
memory. The idea is this: suppose we train on a set of fully observed bit vectors, corresponding to 
patterns we want to memorize. (We discuss how to do this below). Then, at test time, we present 
a partial pattern to the network. We would like to estimate the missing variables; this is called 
pattern completion. That is, we want to compute 


$$\boldsymbol{x}^* = \operatorname*{argmin}_{\boldsymbol{x}} \mathcal{E}(\boldsymbol{x}) \tag{4.88}$$


We can solve this optimization problem using iterative conditional modes (ICM), in which we 
set each hidden variable to its most likely state given its neighbors. Picking the most probable state 
amounts to using the rule 


$$\boldsymbol{x}^{t+1} = \text{sgn}(\mathbf{W} \boldsymbol{x}^t) \tag{4.89}$$


This can be seen as a deterministic version of Gibbs sampling (see Section 12.3.3). 

We illustrate this process in Figure 4.19. In the top row, we show some training examples. In the middle row, we show a corrupted input, corresponding to the initial state $\boldsymbol{x}^0$. In the bottom row, we show the final state after 30 iterations of ICM. The overall process can be thought of as retrieving a complete example from memory based on a piece of the example.

To learn the weights $\mathbf{W}$, we could use the maximum likelihood estimate method described in Section 4.3.9.1. (See also [HSDK12].) However, a simpler heuristic method, proposed in [Hop82], is to use the following outer product method:

$$\mathbf{W} = \left(\frac{1}{N} \sum_{n=1}^N \boldsymbol{x}_n \boldsymbol{x}_n^\top\right) - \mathbf{I} \tag{4.90}$$

This normalizes the outer product matrix by $N$, and then sets the diagonal to 0. This ensures the energy is low for patterns that match any of the examples in the training set. This is the technique we used in Figure 4.19. Note, however, that this method not only stores the original patterns but also their inverses, and other linear combinations. Consequently there is a limit to how many examples the model can store before they start to “collide” in the memory. Hopfield showed that, for random patterns, the network capacity is roughly $0.14D$ patterns, where $D$ is the number of units.
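A minimal sketch of Equation (4.90) and of recall by iterated conditional modes, assuming $\pm 1$ patterns. The three 8-bit patterns are hypothetical and were chosen to be mutually orthogonal, which keeps recall clean; random patterns interfere with each other, which is the source of the capacity limit just mentioned. `recall` updates one unit at a time (the ICM reading of Equation (4.89)), which, for a symmetric $\mathbf{W}$ with zero diagonal, never increases the energy in Equation (4.87).

```python
def hopfield_weights(patterns):
    """Equation (4.90): W = (1/N) sum_n x_n x_n^T with the diagonal set to zero.

    For +/-1 patterns, zeroing the diagonal is the same as subtracting I.
    """
    N, D = len(patterns), len(patterns[0])
    W = [[0.0] * D for _ in range(D)]
    for p in patterns:
        for i in range(D):
            for j in range(D):
                if i != j:
                    W[i][j] += p[i] * p[j] / N
    return W

def energy(W, x):
    """Equation (4.87): E(x) = -1/2 x^T W x."""
    D = len(x)
    return -0.5 * sum(W[i][j] * x[i] * x[j] for i in range(D) for j in range(D))

def recall(W, x, max_sweeps=10):
    """ICM: set each unit in turn to sgn(sum_j W_ij x_j), keeping its state on ties."""
    x = list(x)
    D = len(x)
    for _ in range(max_sweeps):
        changed = False
        for i in range(D):
            h = sum(W[i][j] * x[j] for j in range(D))
            new = x[i] if h == 0 else (1 if h > 0 else -1)
            if new != x[i]:
                x[i], changed = new, True
        if not changed:
            break
    return x

patterns = [
    [1, 1, 1, 1, -1, -1, -1, -1],
    [1, -1, 1, -1, 1, -1, 1, -1],
    [1, 1, -1, -1, 1, 1, -1, -1],
]
W = hopfield_weights(patterns)
noisy = list(patterns[0])
noisy[0] = -noisy[0]                      # corrupt one bit
restored = recall(W, noisy)
print(restored == patterns[0], energy(W, noisy), energy(W, restored))
```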


4.3.3 MRFs with latent variables (Boltzmann machines, etc)


In this section, we discuss MRFs which contain latent variables, as a way to represent high dimensional 
joint distributions in discrete spaces. 


4.3.3.1 Vanilla Boltzmann machines 


MRFs in which all the variables are visible are limited in their expressive power, since the only way to 
model correlation between the variables is by directly adding an edge. An alternative approach is to 
introduce latent variables. A Boltzmann machine [AHS85] is like an Ising model (Section 4.3.2.1) 
with latent variables. In addition, the graph structure can be arbitrary (not just a lattice), and the 
binary states are $x_i \in \{0, 1\}$ instead of $x_i \in \{-1, +1\}$. We usually partition the nodes into hidden
nodes z and visible nodes x, as shown in Figure 4.20(a). 


Figure 4.19: Examples of how an associative memory can reconstruct images. These are binary images of size 150 × 150 pixels. Top: training images. Middle row: partially visible test images. Bottom row: final state estimate. Adapted from Figure 2.1 of [HKP91]. Generated by hopfield_demo.ipynb.




Figure 4.20: (a) A general Boltzmann machine, with an arbitrary graph structure. The shaded (visible) nodes are partitioned into input and output, although the model is actually symmetric and defines a joint distribution on all the nodes. (b) A restricted Boltzmann machine with a bipartite structure. Note the lack of intra-layer connections.
Figure 4.21: Some reconstructed images generated by a binary RBM fit to MNIST. Generated by rbm_contrastive_divergence.ipynb.


4.3.3.2 Restricted Boltzmann machines (RBMs) 


Unfortunately, exact inference (and hence learning) in Boltzmann machines is intractable, and even 
approximate inference (e.g., Gibbs sampling, Section 12.3) can be slow. However, suppose we restrict 
the architecture so that the nodes are arranged in two layers, and so that there are no connections 
between nodes within the same layer (see Figure 4.20(b)). This model is known as a restricted 
Boltzmann machine (RBM) [HT01; HS06a], or a harmonium [Smo86]. The RBM supports 
efficient approximate inference, since the hidden nodes are conditionally independent given the visible 
nodes, i.e., $p(\boldsymbol{z}|\boldsymbol{x}) = \prod_{k=1}^K p(z_k|\boldsymbol{x})$. Note this is in contrast to a directed two-layer model, where the explaining away effect causes the latent variables to become “entangled” in the posterior even if they are independent in the prior.

Typically the hidden and visible nodes in an RBM are binary, so the energy terms have the form $w_{dk} x_d z_k$. If $z_k = 1$, then the $k$'th hidden unit adds a term of the form $\boldsymbol{w}_k^\top \boldsymbol{x}$ to the energy; this can be thought of as a “soft constraint”. If $z_k = 0$, the hidden unit is not active, and does not have an opinion about this data example. By turning on different combinations of constraints, we can create complex distributions on the visible data. This is an example of a product of experts (Section 24.1.1), since $p(\boldsymbol{x}|\boldsymbol{z}) \propto \prod_{k: z_k = 1} \exp(\boldsymbol{w}_k^\top \boldsymbol{x})$.

This can be thought of as a mixture model with an exponential number of hidden components, corresponding to $2^K$ settings of $\mathbf{z}$. That is, $\mathbf{z}$ is a distributed representation, whereas a standard mixture model uses a localist representation, where $z \in \{1, \ldots, K\}$, and each setting of $z$ corresponds to a complete prototype or exemplar $\mathbf{w}_k$ to which $\mathbf{x}$ is compared, giving rise to a model of the form $p(\mathbf{x}|z = k) \propto \exp(\mathbf{w}_k^\top \mathbf{x})$.
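The “exponential mixture” view can also be made concrete. The sketch assumes a binary RBM with unnormalized log-probability $\mathbf{b}^\top\mathbf{x} + \mathbf{c}^\top\mathbf{z} + \mathbf{x}^\top\mathbf{W}\mathbf{z}$, where the biases $\mathbf{b}$, $\mathbf{c}$ are our own addition. It first computes the unnormalized marginal $\tilde{p}(\mathbf{x}) = \sum_{\mathbf{z}} \tilde{p}(\mathbf{x}, \mathbf{z})$ as a sum over all $2^K$ components. It then computes the same quantity in $O(DK)$ time: because the hidden units do not interact, the sum factorizes as $\exp(\mathbf{b}^\top\mathbf{x}) \prod_k (1 + \exp(c_k + \mathbf{w}_k^\top \mathbf{x}))$.

```python
import itertools
import numpy as np

def log_marginal_brute(x, W, b, c):
    """log sum_z exp(b^T x + c^T z + x^T W z): one term for each of the 2^K components."""
    K = W.shape[1]
    configs = np.array(list(itertools.product([0.0, 1.0], repeat=K)))
    terms = [b @ x + c @ z + x @ W @ z for z in configs]
    return np.logaddexp.reduce(terms)

def log_marginal_factorized(x, W, b, c):
    """Same quantity in O(DK) time: b^T x + sum_k log(1 + exp(c_k + w_k^T x))."""
    return b @ x + np.sum(np.logaddexp(0.0, c + x @ W))

rng = np.random.default_rng(1)
D, K = 5, 10  # 2^10 = 1024 mixture components
W = rng.normal(size=(D, K))
b = rng.normal(size=D)
c = rng.normal(size=K)
x = rng.integers(0, 2, size=D).astype(float)
print(log_marginal_brute(x, W, b, c), log_marginal_factorized(x, W, b, c))
```

The negative of the factorized expression is often called the free energy of $\mathbf{x}$. It is still only known up to the intractable $\log Z$.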

Many different kinds of RBMs have been defined, which use different pairwise potential functions. 
See Table 4.4 for a summary. (Figure 4.21 gives an example of some images generated from an 
RBM fit to the binarized MNIST dataset.) All of these are special cases of the exponential family 
harmonium [WRZH04]. See Supplementary Section 4.3 for more details. 
Visible               Hidden     Name                                   Reference
Binary                Binary     Binary RBM                             [HS06a]
Gaussian              Binary     Gaussian RBM                           [WS05]
Categorical           Binary     Categorical RBM                        [SMH07]
Multiple categorical  Binary     Replicated softmax / undirected LDA    [SH10]
Gaussian              Gaussian   Undirected PCA                         [MM01]
Binary                Gaussian   Undirected binary PCA                  [WS05]

Table 4.4: Summary of different kinds of RBM.




Figure 4.22: (a) Deep Boltzmann machine. (b) Deep belief network. The top two layers define the prior in terms of an RBM. The remaining layers are a directed graphical model that “decodes” the prior into observable data.


4.3.3.3 Deep Boltzmann machines 
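We can make a “deep” version of an RBM by stacking multiple layers; this is called a deep Boltzmann machine [SH09]. For example, the two layer model in Figure 4.22(a) has the form

$$p(\mathbf{x}, \mathbf{z}_1, \mathbf{z}_2|\boldsymbol{\theta}) = \frac{1}{Z(\mathbf{W}_1, \mathbf{W}_2)} \exp\left(\mathbf{x}^\top \mathbf{W}_1 \mathbf{z}_1 + \mathbf{z}_1^\top \mathbf{W}_2 \mathbf{z}_2\right) \tag{4.91}$$

where $\mathbf{x}$ are the visible nodes at the bottom, and we have dropped bias terms for brevity.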


4.3.3.4 Deep belief networks (DBNs)

We can use an RBM as a prior over a latent distributed code, and then use a DPGM “decoder” to convert this into the observed data, as shown in Figure 4.22(b). The corresponding joint distribution has the form

$$p(\mathbf{x}, \mathbf{z}_1, \mathbf{z}_2|\boldsymbol{\theta}) = p(\mathbf{x}|\mathbf{z}_1, \mathbf{W}_1) \frac{1}{Z(\mathbf{W}_2)} \exp\left(\mathbf{z}_1^\top \mathbf{W}_2 \mathbf{z}_2\right) \tag{4.92}$$

In other words, it is an RBM on top of a DPGM. This combination has been called a deep belief network (DBN) [HOT06a]. However, this name is confusing, since it is not actually a belief net. We will therefore call it a deep Boltzmann network (which conveniently has the same DBN abbreviation).
DBNs can be trained in a simple greedy fashion, and support fast bottom-up inference (see
[HOT06a] for details). DBNs played an important role in the history of deep learning, since they 
were one of the first deep models that could be successfully trained. However, they are no longer 
widely used, since the advent of better ways to train fully supervised DNNs (such as using ReLU 
units and the Adam optimizer), and the advent of efficient ways to train deep DPGMs, such as the 
VAE (Section 21.2). 


4.3.4 Maximum entropy models 


In Section 2.3.7, we show that the exponential family is the distribution with maximum entropy, subject to the constraints that the expected value of the features (sufficient statistics) $\boldsymbol{\phi}(\mathbf{x})$ matches the empirical expectations. Thus the model has the form

$$p(\mathbf{x}|\boldsymbol{\theta}) = \frac{1}{Z(\boldsymbol{\theta})} \exp\left(\boldsymbol{\theta}^\top \boldsymbol{\phi}(\mathbf{x})\right) \tag{4.93}$$

If the features $\boldsymbol{\phi}(\mathbf{x})$ decompose according to a graph structure, we get a kind of MRF known as a maximum entropy model. We give some examples below.


4.3.4.1 Log-linear models 


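Suppose the potential functions have the following log-linear form:

$$\psi_c(\mathbf{x}_c; \boldsymbol{\theta}_c) = \exp\left(\boldsymbol{\theta}_c^\top \boldsymbol{\phi}(\mathbf{x}_c)\right) \tag{4.94}$$

where $\boldsymbol{\phi}(\mathbf{x}_c)$ is a feature vector derived from the variables in clique $c$. Then the overall model is given by

$$p(\mathbf{x}|\boldsymbol{\theta}) = \frac{1}{Z(\boldsymbol{\theta})} \exp\left(\sum_c \boldsymbol{\theta}_c^\top \boldsymbol{\phi}(\mathbf{x}_c)\right) \tag{4.95}$$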

For example, in a Gaussian graphical model (GGM), we have

$$\boldsymbol{\phi}([x_i, x_j]) = [x_i, x_j, x_i x_j] \tag{4.96}$$

for $x_i \in \mathbb{R}$. And in an Ising model, we have

$$\boldsymbol{\phi}([x_i, x_j]) = [x_i, x_j, x_i x_j] \tag{4.97}$$

for $x_i \in \{-1, +1\}$. Thus both of these are maxent models. However, there are two key differences: first, in a GGM, the variables are real-valued, not binary; second, in a GGM, the partition function $Z(\boldsymbol{\theta})$ can be computed in $O(D^3)$ time, whereas in a Boltzmann machine, computing the partition function can take $O(2^D)$ time (see Section 9.4.4 for details).

If the features $\boldsymbol{\phi}$ are structured in a hierarchical way (capturing first-order interactions, second-order interactions, etc.), and all the variables $\mathbf{x}$ are categorical, the resulting model is known in statistics as a log-linear model. However, in the ML community, the term “log-linear model” is often used to describe any model of the form Equation (4.95).
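The cost gap between the two cases can be seen in a small sketch. We write the Ising model as $\tilde{p}(\mathbf{x}) = \exp(\mathbf{h}^\top \mathbf{x} + \frac{1}{2}\mathbf{x}^\top \mathbf{J} \mathbf{x})$, where the names $\mathbf{h}$, $\mathbf{J}$ are our own. Its brute-force partition function sums $2^D$ terms. For a GGM in the information form $\exp(\boldsymbol{\eta}^\top \mathbf{x} - \frac{1}{2}\mathbf{x}^\top \boldsymbol{\Lambda} \mathbf{x})$, the standard Gaussian integral gives $\log Z = \frac{D}{2}\log 2\pi - \frac{1}{2}\log|\boldsymbol{\Lambda}| + \frac{1}{2}\boldsymbol{\eta}^\top \boldsymbol{\Lambda}^{-1} \boldsymbol{\eta}$. Its cost is dominated by an $O(D^3)$ log-determinant and linear solve.

```python
import itertools
import numpy as np

def ising_log_Z(h, J):
    """Brute-force log Z for p(x) ∝ exp(h^T x + 0.5 x^T J x), x in {-1,+1}^D (2^D terms).
    J is assumed symmetric with zero diagonal, so 0.5 x^T J x counts each pair once."""
    D = len(h)
    configs = np.array(list(itertools.product([-1.0, 1.0], repeat=D)))
    terms = [h @ x + 0.5 * x @ J @ x for x in configs]
    return np.logaddexp.reduce(terms)

def ggm_log_Z(eta, Lam):
    """Closed-form log Z for p(x) ∝ exp(eta^T x - 0.5 x^T Lam x), Lam positive definite."""
    D = len(eta)
    _, logdet = np.linalg.slogdet(Lam)  # O(D^3)
    return 0.5 * D * np.log(2 * np.pi) - 0.5 * logdet + 0.5 * eta @ np.linalg.solve(Lam, eta)

# Two coupled spins: Z = 2 exp(j) + 2 exp(-j).
j = 0.7
print(ising_log_Z(np.zeros(2), np.array([[0.0, j], [j, 0.0]])))
# A 1d Gaussian with precision 1 and eta = m has log Z = 0.5 log(2 pi) + 0.5 m^2.
print(ggm_log_Z(np.array([1.5]), np.array([[1.0]])))
```

Doubling $D$ squares the number of Ising terms, but multiplies the Gaussian cost by only about 8.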
4.3.4.2 Feature induction for a maxent spelling model


In some applications, we assume the features $\boldsymbol{\phi}(\mathbf{x})$ are known. However, it is possible to learn the features in a maxent model in an unsupervised way; this is known as feature induction.

A common approach to feature induction, first proposed in [DDL97; ZWM97], is to start with a 
base set of features, and then to continually create new feature combinations out of old ones, greedily 
adding the best ones to the model. 

As an example of this approach, [DDL97] describe how to build models to represent English spelling. This can be formalized as a probability distribution over variable length strings, $p(\mathbf{x}|\boldsymbol{\theta})$, where $x_i$ is a letter in the English alphabet. Initially the model has no features, which represents the uniform distribution. The algorithm starts by choosing to add the feature

$$\phi_1(\mathbf{x}) = \sum_i \mathbb{I}(x_i \in \{a, \ldots, z\}) \tag{4.98}$$

which counts the lowercase letters in the string. After the feature is added, the parameters are (re)-fit by maximum likelihood (a computationally difficult problem, which we discuss in Section 4.3.9.1). For this feature, it turns out that $\theta_1 = 1.944$, which means that a word with a lowercase letter in any position is about $e^{1.944} \approx 7$ times more likely than the same word without a lowercase letter in that position. Some samples from this model, generated using (annealed) Gibbs sampling (described in Section 12.3), are shown below.⁶


m, r, xevo, ijjiir, b, to, jz, gsr, wq, vf, x, ga, msmGh, pcp, d, oziVlal, hzagh, yzop, io, 
advzmxnv, ijv_bolft, x, emx, kayerf, mlj, rawzyb, jp, ag, ctdnnnbg, wgdw, t, kguv, cy, 
spxcq, uzflbbf, dxtkkn, cxwx, jpd, ztzh, lv, zhpkvnu, 1^, r, qee, nynrx, atze4n, ik, se, W, 
lrh, hpt+, yrqyka’h, zcngotcnx, igcump, zjcjs, lqpWiqu, cefmfhc, o, lb, fdcY, tzby, yopxmvk, 
by, fz,, t, govyccm, ijyiduwfzo, 6xr, duh, ejv, pk, pjw, 1, fl, w 


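The second feature added by the algorithm checks if two adjacent characters are lower case:

$$\phi_2(\mathbf{x}) = \sum_{i \sim j} \mathbb{I}(x_i \in \{a, \ldots, z\}, x_j \in \{a, \ldots, z\}) \tag{4.99}$$

Now the model has the form

$$p(\mathbf{x}) = \frac{1}{Z} \exp\left(\theta_1 \phi_1(\mathbf{x}) + \theta_2 \phi_2(\mathbf{x})\right) \tag{4.100}$$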
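The two features and the “about 7 times” arithmetic can be sketched in a few lines of plain Python. The sketch assumes that “adjacent” in $\phi_2$ means consecutive positions $i, i+1$ of the string. Only $\theta_1 = 1.944$ comes from the text; the value used for $\theta_2$ is a placeholder.

```python
import math
import string

LOWER = set(string.ascii_lowercase)

def phi1(word):
    """Equation (4.98): number of positions holding a lowercase letter."""
    return sum(ch in LOWER for ch in word)

def phi2(word):
    """Equation (4.99), assuming i ~ j means consecutive positions (i, i + 1)."""
    return sum(a in LOWER and b in LOWER for a, b in zip(word, word[1:]))

def log_unnorm(word, theta1, theta2):
    """theta1 * phi1 + theta2 * phi2, i.e. log p(x) up to the constant log Z."""
    return theta1 * phi1(word) + theta2 * phi2(word)

# One-letter words differing only in case: phi1 differs by 1 and phi2 = 0 for both,
# so the placeholder theta2 = 0.5 has no effect and the ratio is exp(theta1).
theta1, theta2 = 1.944, 0.5
ratio = math.exp(log_unnorm("a", theta1, theta2) - log_unnorm("A", theta1, theta2))
print(round(ratio, 2))  # 6.99
```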


Continuing in this way, the algorithm adds features for the strings s> and ing>, where > represents the end of a word, and for various regular expressions such as [0-9], etc. Some samples from the model with 1000 features, generated using (annealed) Gibbs sampling, are shown below.


was, reaser, in, there, to, will, ,, was, by, homes, thing, be, reloverated, ther, which, 
conists, at, fores, anditing, with, Mr., proveral, the, ,, ***, on’t, prolling, prothere, ,, 
mento, at, yaou, 1, chestraing, for, have, to, intrally, of, qut, ., best, compers, ***, 
cluseliment, uster, of, is, deveral, this, thise, of, offect, inatever, thifer, 
constranded, stater, vill, in, thase, in, youse, menttering, and, ., of, in, verate, of, 
to 


6. We thank John Lafferty for sharing this example.
If we define a feature for every possible combination of letters, we can represent any probability distribution. However, this will overfit. The power of the maxent approach is that we can choose which features matter for the domain.

An alternative approach is to introduce latent variables that implicitly model correlations amongst
the visible nodes, rather than explicitly having to learn feature functions. See Section 4.3.3 for an 
example of such a model. 


4.3.5 Gaussian MRFs 

In Section 4.2.3, we showed how to represent a multivariate Gaussian using a DPGM. In this section, we show how to represent a multivariate Gaussian using a UPGM. (For further details on GMRFs, see e.g., [RH05].)

4.3.5.1 Standard GMRFs 


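A Gaussian graphical model (or GGM), also called a Gaussian MRF, is a pairwise MRF of the following form:

$$p(\mathbf{x}) = \frac{1}{Z(\boldsymbol{\theta})} \prod_{i \sim j} \psi_{ij}(x_i, x_j) \prod_i \psi_i(x_i) \tag{4.101}$$
$$\psi_{ij}(x_i, x_j) = \exp\left(-\tfrac{1}{2} x_i \Lambda_{ij} x_j\right) \tag{4.102}$$
$$\psi_i(x_i) = \exp\left(-\tfrac{1}{2} \Lambda_{ii} x_i^2 + \eta_i x_i\right) \tag{4.103}$$
$$Z(\boldsymbol{\theta}) = (2\pi)^{D/2} |\boldsymbol{\Lambda}|^{-1/2} \exp\left(\tfrac{1}{2} \boldsymbol{\eta}^\top \boldsymbol{\Lambda}^{-1} \boldsymbol{\eta}\right) \tag{4.104}$$

The $\psi_{ij}$ are edge potentials (pairwise terms), and the $\psi_i$ are node potentials or unary terms. (We could absorb the unary terms into the pairwise terms, but we have kept them separate for clarity.)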

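The joint distribution can be rewritten in a more familiar form as follows:

$$p(\mathbf{x}) \propto \exp\left[\boldsymbol{\eta}^\top \mathbf{x} - \tfrac{1}{2} \mathbf{x}^\top \boldsymbol{\Lambda} \mathbf{x}\right] \tag{4.105}$$

This is called the information form of a Gaussian; $\boldsymbol{\Lambda} = \boldsymbol{\Sigma}^{-1}$ and $\boldsymbol{\eta} = \boldsymbol{\Lambda}\boldsymbol{\mu}$ are called the canonical parameters.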

If $\Lambda_{ij} = 0$, there is no pairwise term connecting $x_i$ and $x_j$, and hence $x_i \perp x_j | \mathbf{x}_{-ij}$, where $\mathbf{x}_{-ij}$ are all the nodes except for $x_i$ and $x_j$. Hence the zero entries in $\boldsymbol{\Lambda}$ are called structural zeros. This means we can use $\ell_1$ regularization on the weights to learn a sparse graph, a method known as graphical lasso (see Supplementary Section 30.3.2).

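Note that the covariance matrix $\boldsymbol{\Sigma} = \boldsymbol{\Lambda}^{-1}$ can be dense even if the precision matrix $\boldsymbol{\Lambda}$ is sparse. For example, consider an AR(1) process with correlation parameter $\rho$.⁷ Assuming unit-variance innovations and a stationary start, the precision matrix (for a graph with $T = 7$ nodes) looks like this:

$$\boldsymbol{\Lambda} = \begin{pmatrix}
1 & -\rho & & & & & \\
-\rho & 1+\rho^2 & -\rho & & & & \\
& -\rho & 1+\rho^2 & -\rho & & & \\
& & -\rho & 1+\rho^2 & -\rho & & \\
& & & -\rho & 1+\rho^2 & -\rho & \\
& & & & -\rho & 1+\rho^2 & -\rho \\
& & & & & -\rho & 1
\end{pmatrix} \tag{4.106}$$

whereas the covariance matrix is dense:

$$\boldsymbol{\Sigma} = \frac{1}{1-\rho^2}\begin{pmatrix}
1 & \rho & \rho^2 & \rho^3 & \rho^4 & \rho^5 & \rho^6 \\
\rho & 1 & \rho & \rho^2 & \rho^3 & \rho^4 & \rho^5 \\
\rho^2 & \rho & 1 & \rho & \rho^2 & \rho^3 & \rho^4 \\
\rho^3 & \rho^2 & \rho & 1 & \rho & \rho^2 & \rho^3 \\
\rho^4 & \rho^3 & \rho^2 & \rho & 1 & \rho & \rho^2 \\
\rho^5 & \rho^4 & \rho^3 & \rho^2 & \rho & 1 & \rho \\
\rho^6 & \rho^5 & \rho^4 & \rho^3 & \rho^2 & \rho & 1
\end{pmatrix} \tag{4.107}$$

This follows because, in a chain-structured UPGM, every pair of nodes is marginally correlated, even if they are conditionally independent given a separator.

7. This example is from https://dansblog.netlify.app/posts/2022-03-22-a-linear-mixed-effects-model1/.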
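The AR(1) example is easy to reproduce numerically. The sketch assumes the process $x_t = \rho x_{t-1} + \epsilon_t$ with $\epsilon_t \sim \mathcal{N}(0, 1)$ and a stationary start; this scaling is our choice. It builds the tridiagonal precision matrix and checks that its inverse is the dense matrix with entries $\rho^{|i-j|}/(1-\rho^2)$. It also illustrates a structural zero: the conditional covariance of the first and third nodes (indices 0 and 2 below) given all other nodes has a zero off-diagonal entry.

```python
import numpy as np

def ar1_precision(T, rho):
    """Tridiagonal precision of x_t = rho x_{t-1} + e_t, e_t ~ N(0, 1), with a stationary start."""
    Lam = np.diag(np.full(T, 1.0 + rho**2))
    Lam[0, 0] = Lam[-1, -1] = 1.0
    idx = np.arange(T - 1)
    Lam[idx, idx + 1] = -rho
    Lam[idx + 1, idx] = -rho
    return Lam

T, rho = 7, 0.5
Lam = ar1_precision(T, rho)
Sigma = np.linalg.inv(Lam)
print(np.count_nonzero(Lam), np.count_nonzero(np.abs(Sigma) > 1e-12))  # 19 49

# Conditioning on all other nodes leaves the 2x2 block of Lam as the conditional
# precision of (x_0, x_2); its off-diagonal entry Lam[0, 2] is a structural zero.
cond_cov = np.linalg.inv(Lam[np.ix_([0, 2], [0, 2])])
print(cond_cov[0, 1])
```

Storing and solving with the sparse $\boldsymbol{\Lambda}$ rather than the dense $\boldsymbol{\Sigma}$ is what makes GMRFs attractive for long chains and large grids.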


4.3.5.2 Nonlinear Gaussian MRFs 


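In this section, we consider a generalization of GGMs to handle the case of nonlinear models. Suppose the joint is given by a product of local factors, or clique potentials, $\psi_c$, each of which is defined on a set of clique variables $\mathbf{x}_c$, as follows:

$$p(\mathbf{x}) = \frac{1}{Z} \prod_c \psi_c(\mathbf{x}_c) \tag{4.108}$$
$$\psi_c(\mathbf{x}_c) = \exp(-E_c(\mathbf{x}_c)) \tag{4.109}$$
$$E_c(\mathbf{x}_c) = \frac{1}{2} (\mathbf{f}_c(\mathbf{x}_c) - \mathbf{d}_c)^\top \boldsymbol{\Sigma}_c^{-1} (\mathbf{f}_c(\mathbf{x}_c) - \mathbf{d}_c) \tag{4.110}$$

where $\mathbf{d}_c$ is an optional local evidence term for the $c$'th clique, and $\mathbf{f}_c$ is some measurement function.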


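Suppose the measurement function $\mathbf{f}_c$ is linear, i.e.,

$$\mathbf{f}_c(\mathbf{x}_c) = \mathbf{J}_c \mathbf{x}_c + \mathbf{b}_c \tag{4.111}$$

In this case, the energy for clique $c$ becomes

$$E_c(\mathbf{x}_c) = \frac{1}{2} \mathbf{x}_c^\top \underbrace{\mathbf{J}_c^\top \boldsymbol{\Sigma}_c^{-1} \mathbf{J}_c}_{\boldsymbol{\Lambda}_c} \mathbf{x}_c + \mathbf{x}_c^\top \underbrace{\mathbf{J}_c^\top \boldsymbol{\Sigma}_c^{-1} (\mathbf{b}_c - \mathbf{d}_c)}_{-\boldsymbol{\eta}_c} + \underbrace{\frac{1}{2} (\mathbf{b}_c - \mathbf{d}_c)^\top \boldsymbol{\Sigma}_c^{-1} (\mathbf{b}_c - \mathbf{d}_c)}_{k_c} \tag{4.112}$$
$$= \frac{1}{2} \mathbf{x}_c^\top \boldsymbol{\Lambda}_c \mathbf{x}_c - \boldsymbol{\eta}_c^\top \mathbf{x}_c + k_c \tag{4.113}$$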
Figure 4.23: (a) A DPGM. (b) Its moralized version, represented as a UPGM.


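which is a standard Gaussian factor. If $\mathbf{f}_c$ is nonlinear, it is common to linearize the model around the current estimate $\mathbf{x}_c^0$ to get

$$\mathbf{f}_c(\mathbf{x}_c) \approx \mathbf{f}_c(\mathbf{x}_c^0) + \mathbf{J}_c (\mathbf{x}_c - \mathbf{x}_c^0) = \mathbf{J}_c \mathbf{x}_c + \underbrace{(\mathbf{f}_c(\mathbf{x}_c^0) - \mathbf{J}_c \mathbf{x}_c^0)}_{\mathbf{b}_c} \tag{4.114}$$

where $\mathbf{J}_c$ is the Jacobian of $\mathbf{f}_c(\mathbf{x}_c)$ wrt $\mathbf{x}_c$. This gives us a “temporary” Gaussian factor that we can use for inference. This process can be iterated for improved accuracy.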
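Iterating the linearization is essentially a Gauss-Newton scheme. Below is a minimal sketch for a single clique with two variables. It assumes a hypothetical measurement function $\mathbf{f}_c(\mathbf{x}) = [x_0^2 + x_1, x_0 x_1]$, noise-free evidence $\mathbf{d}_c$, and $\boldsymbol{\Sigma}_c = 0.1\mathbf{I}$. Each step forms $\boldsymbol{\Lambda}_c = \mathbf{J}_c^\top \boldsymbol{\Sigma}_c^{-1} \mathbf{J}_c$ and $\boldsymbol{\eta}_c = \mathbf{J}_c^\top \boldsymbol{\Sigma}_c^{-1}(\mathbf{d}_c - \mathbf{b}_c)$ from the linearized model. It then moves to the mode $\boldsymbol{\Lambda}_c^{-1}\boldsymbol{\eta}_c$ of the temporary Gaussian factor.

```python
import numpy as np

def f(x):
    """Hypothetical nonlinear measurement function f_c(x_c)."""
    return np.array([x[0] ** 2 + x[1], x[0] * x[1]])

def jacobian(x):
    """Jacobian J_c of f at x (derived by hand for this example)."""
    return np.array([[2 * x[0], 1.0], [x[1], x[0]]])

def linearized_factor(x0, d, Sigma):
    """Canonical parameters (Lam_c, eta_c) of the Gaussian factor obtained by linearizing f at x0."""
    J = jacobian(x0)
    b = f(x0) - J @ x0  # offset b_c from Equation (4.114)
    Sinv = np.linalg.inv(Sigma)
    Lam = J.T @ Sinv @ J
    eta = J.T @ Sinv @ (d - b)
    return Lam, eta

x_true = np.array([1.5, -0.5])
d = f(x_true)               # noise-free local evidence, for illustration
Sigma = 0.1 * np.eye(2)
x = np.array([1.0, 0.0])    # initial estimate x_c^0
for _ in range(20):
    Lam, eta = linearized_factor(x, d, Sigma)
    x = np.linalg.solve(Lam, eta)  # mode of the temporary Gaussian factor
print(x)
```

Because this factor's Jacobian is square and invertible, each step coincides with Newton's method for solving $\mathbf{f}_c(\mathbf{x}) = \mathbf{d}_c$. In a full model, the information-form parameters of many such factors (and any prior) would be summed before solving. The starting point also matters here, since $\mathbf{f}_c(\mathbf{x}) = \mathbf{d}_c$ has other solutions.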


4.3.6 Conditional independence properties 


In this section, we explain how UPGMs encode conditional independence assumptions. 


4.3.6.1 Basic results 


UPGMs define CI relationships via simple graph separation as follows: given 3 sets of nodes A, B, and C, we say $\mathbf{X}_A \perp_G \mathbf{X}_B | \mathbf{X}_C$ iff C separates A from B in the graph G. This means that, when we remove all the nodes in C, if there are no paths connecting any node in A to any node in B, then the CI property holds. This is called the global Markov property for UPGMs. For example, in Figure 4.23(b), we have that $\{X_1, X_2\} \perp \{X_6, X_7\} | \{X_3, X_4, X_5\}$.
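Checking the global Markov property is a graph reachability problem: delete the nodes in C and search for a path from A to B. The plain-Python sketch below does this with breadth-first search. The edge list is our own reconstruction of the moralized graph in Figure 4.23(b). It was chosen to be consistent with the statements in the text, such as the example above and the Markov blanket of $X_5$ given below. Treat it as an assumption, not a copy of the figure.

```python
from collections import deque

# Assumed moralized graph for Figure 4.23(b): DAG edges plus moral edges 2-3, 4-5, 4-6, 5-6.
EDGES = [(1, 2), (1, 3), (2, 4), (2, 5), (3, 5), (3, 6), (4, 7), (5, 7), (6, 7),
         (2, 3), (4, 5), (4, 6), (5, 6)]

def adjacency(edges):
    """Undirected adjacency sets from an edge list."""
    adj = {}
    for u, v in edges:
        adj.setdefault(u, set()).add(v)
        adj.setdefault(v, set()).add(u)
    return adj

def separated(adj, A, B, C):
    """True iff removing C leaves no path from any node in A to any node in B."""
    A, B, C = set(A), set(B), set(C)
    frontier = deque(A - C)
    seen = set(frontier)
    while frontier:
        u = frontier.popleft()
        if u in B:
            return False
        for v in adj.get(u, ()):
            if v not in C and v not in seen:
                seen.add(v)
                frontier.append(v)
    return True

adj = adjacency(EDGES)
print(separated(adj, {1, 2}, {6, 7}, {3, 4, 5}))  # True (global Markov example)
print(sorted(adj[5]))                              # [2, 3, 4, 6, 7]
```

The search visits each node and edge at most once, so the test is linear in the size of the graph.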

The smallest set of nodes that renders a node $t$ conditionally independent of all the other nodes in the graph is called $t$'s Markov blanket; we will denote this by $\mathrm{mb}(t)$. Formally, the Markov blanket satisfies the following property:

$$t \perp \mathcal{V} \setminus \mathrm{cl}(t) \,|\, \mathrm{mb}(t) \tag{4.115}$$

where $\mathrm{cl}(t) = \mathrm{mb}(t) \cup \{t\}$ is the closure of node $t$, and $\mathcal{V} = \{1, \ldots, N_G\}$ is the set of all nodes. One can show that, in a UPGM, a node's Markov blanket is its set of immediate neighbors. This is called the undirected local Markov property. For example, in Figure 4.23(b), we have $\mathrm{mb}(X_5) = \{X_2, X_3, X_4, X_6, X_7\}$.

From the local Markov property, we can easily see that two nodes are conditionally independent given the rest if there is no direct edge between them. This is called the pairwise Markov property. In symbols, this is written as

$$s \perp t \,|\, \mathcal{V} \setminus \{s, t\} \iff G_{st} = 0 \tag{4.116}$$

where $G_{st} = 0$ means there is no edge between $s$ and $t$ (so there is a 0 in the adjacency matrix).
Figure 4.24: Relationship between Markov properties of UPGMs.

Figure 4.25: (a) The ancestral graph induced by the DAG in Figure 4.23(a) wrt U = {X2, X4, X5}. (b) The moralized version of (a).


Using the three Markov properties we have discussed, we can derive the following CI properties (amongst others) from the UPGM in Figure 4.23(b): $X_1 \perp X_7 | \text{rest}$ (pairwise); $X_1 \perp \text{rest} | X_2, X_3$ (local); $X_1, X_2 \perp X_6, X_7 | X_3, X_4, X_5$ (global).

It is obvious that global Markov implies local Markov which implies pairwise Markov. What is less obvious is that pairwise implies global, and hence that all these Markov properties are the same,⁸ as illustrated in Figure 4.24 (see e.g., [KF09a, p119] for a proof). The importance of this result is that it is usually easier to empirically assess pairwise conditional independence; such pairwise CI statements can be used to construct a graph from which global CI statements can be extracted.


4.3.6.2 An undirected alternative to d-separation

We have seen that determining CI relationships in UPGMs is much easier than in DPGMs, because we do not have to worry about the directionality of the edges. That is, we can use simple graph separation, instead of d-separation.

In this section, we show how to convert a DPGM to a UPGM, so that we can infer CI relationships for the DPGM using simple graph separation. It is tempting to simply convert the DPGM to a UPGM by dropping the orientation of the edges, but this is clearly incorrect, since a v-structure $A \rightarrow B \leftarrow C$ has quite different CI properties than the corresponding undirected chain $A - B - C$ (e.g., the latter graph incorrectly states that $A \perp C | B$). To avoid such incorrect CI statements, we can add edges between the “unmarried” parents A and C, and then drop the arrows from the edges, forming (in this case) a fully connected undirected graph. This process is called moralization. Figure 4.23 gives a larger example of moralization: we interconnect 2 and 3, since they have a common child 5, and we interconnect 4, 5 and 6, since they have a common child 7.

8. This assumes $p(\mathbf{x}) > 0$ for all $\mathbf{x}$, i.e., that $p$ is a positive density. The restriction to positive densities arises because deterministic constraints can result in independencies present in the distribution that are not explicitly represented in the graph. See e.g., [KF09a, p120] for some examples. Distributions with non-graphical CI properties are said to be unfaithful to the graph, so $I(p) \neq I(G)$.

Figure 4.26: A grid-structured MRF with hidden nodes $z_i$ and local evidence nodes $x_i$. The prior $p(\mathbf{z})$ is an undirected Ising model, and the likelihood $p(\mathbf{x}|\mathbf{z}) = \prod_i p(x_i|z_i)$ is a directed fully factored model.

Unfortunately, moralization loses some CI information, and therefore we cannot use the moralized UPGM to determine CI properties of the DPGM. For example, in Figure 4.23(a), using d-separation, we see that $X_4 \perp X_5 | X_2$. Adding a moralization arc $X_4 - X_5$ would lose this fact (see Figure 4.23(b)). However, notice that the 4-5 moralization edge, due to the common child 7, is not needed if we do not observe 7 or any of its descendants. This suggests the following approach to determining if $A \perp B | C$. First we form the ancestral graph of DAG G with respect to $U = A \cup B \cup C$. This means we remove all nodes from G that are neither in U nor ancestors of U. We then moralize this ancestral graph, and apply the simple graph separation rules for UPGMs. For example, in Figure 4.25(a), we show the ancestral graph for Figure 4.23(a) using $U = \{X_2, X_4, X_5\}$. In Figure 4.25(b), we show the moralized version of this graph. It is clear that we now correctly conclude that $X_4 \perp X_5 | X_2$.
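The ancestral-graph procedure combines three small steps: collect the ancestors of $U = A \cup B \cup C$, moralize the induced subgraph, and run ordinary graph separation. The parent lists below are a hypothetical DAG for Figure 4.23(a). They were chosen to match the moralization described in the text: 2 and 3 share child 5, and 4, 5 and 6 share child 7.

```python
from collections import deque

# Assumed DAG for Figure 4.23(a): each node mapped to its list of parents.
PARENTS = {1: [], 2: [1], 3: [1], 4: [2], 5: [2, 3], 6: [3], 7: [4, 5, 6]}

def ancestral_set(parents, U):
    """U together with all ancestors of nodes in U."""
    keep, stack = set(U), list(U)
    while stack:
        for p in parents[stack.pop()]:
            if p not in keep:
                keep.add(p)
                stack.append(p)
    return keep

def moralize(parents, nodes):
    """Undirected adjacency of the moral graph restricted to `nodes`."""
    adj = {v: set() for v in nodes}
    for child in nodes:
        pa = [p for p in parents[child] if p in nodes]
        for p in pa:                      # drop edge directions
            adj[p].add(child)
            adj[child].add(p)
        for i, p in enumerate(pa):        # marry co-parents
            for q in pa[i + 1:]:
                adj[p].add(q)
                adj[q].add(p)
    return adj

def d_separated(parents, A, B, C):
    """Test A ⊥ B | C by graph separation in the moralized ancestral graph."""
    A, B, C = set(A), set(B), set(C)
    adj = moralize(parents, ancestral_set(parents, A | B | C))
    frontier = deque(A - C)
    seen = set(frontier)
    while frontier:
        u = frontier.popleft()
        if u in B:
            return False
        for v in adj[u]:
            if v not in C and v not in seen:
                seen.add(v)
                frontier.append(v)
    return True

print(d_separated(PARENTS, {4}, {5}, {2}))     # True
print(d_separated(PARENTS, {4}, {5}, {2, 7}))  # False
```

The second query shows why the ancestral step matters. Once the common child 7 is in the conditioning set, node 7 is part of the ancestral graph, so the moral edge 4-5 is added and the independence no longer holds.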


4.3.7 Generation (sampling) 


Unlike with DPGMs, it can be quite slow to sample from a UPGM, even from the unconditional prior, because there is no topological ordering of the variables that would allow ancestral sampling. Furthermore, we cannot easily compute the probability of any configuration unless we know the value of Z. Consequently it is common to use MCMC methods for generating from a UPGM (see Chapter 12).
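As a concrete example of MCMC for a UPGM, here is a minimal single-site Gibbs sampler for an Ising model on a grid. It assumes $p(\mathbf{x}) \propto \exp(J \sum_{i \sim j} x_i x_j)$ with no external field and free boundaries; these settings are our own illustration. Each update needs only the neighbors, since $p(x_i = +1|\mathbf{x}_{-i}) = \sigma(2J\sum_{j \in \mathrm{nbr}(i)} x_j)$. However, many sweeps may be needed before the chain forgets its random initialization.

```python
import numpy as np

def gibbs_ising(shape, J, n_sweeps, rng):
    """Single-site Gibbs sampling for p(x) ∝ exp(J * sum_{i~j} x_i x_j), x_i in {-1, +1},
    on a 4-connected grid with free boundaries."""
    H, W = shape
    x = rng.choice([-1, 1], size=shape)
    for _ in range(n_sweeps):
        for i in range(H):
            for j in range(W):
                m = 0  # sum of the neighboring spins
                if i > 0:
                    m += x[i - 1, j]
                if i < H - 1:
                    m += x[i + 1, j]
                if j > 0:
                    m += x[i, j - 1]
                if j < W - 1:
                    m += x[i, j + 1]
                p_plus = 1.0 / (1.0 + np.exp(-2.0 * J * m))  # p(x_ij = +1 | neighbors)
                x[i, j] = 1 if rng.random() < p_plus else -1
    return x

rng = np.random.default_rng(0)
sample = gibbs_ising((16, 16), J=0.6, n_sweeps=50, rng=rng)
agree = np.mean(sample[:, :-1] == sample[:, 1:])  # fraction of equal horizontal neighbors
print(agree)
```

With strong coupling, the sample shows large same-sign patches, so most neighboring spins agree.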

In the special case of UPGMs with low treewidth and discrete or Gaussian potentials, it is 
possible to use the junction tree algorithm to draw samples using dynamic programming (see 
Supplementary Section 9.2.3). 


4.3.8 Inference 


We discuss inference in graphical models in detail in Chapter 9. In this section, we just give an 
example. 
Suppose we have an image composed of binary pixels, $z_i$, but we only observe noisy versions of the pixels, $x_i$. We assume the joint model has the form

$$p(\mathbf{x}, \mathbf{z}) = p(\mathbf{z})p(\mathbf{x}|\mathbf{z}) = \left[\frac{1}{Z} \prod_{i \sim j} \psi_{ij}(z_i, z_j)\right] \prod_i p(x_i|z_i) \tag{4.117}$$

where $p(\mathbf{z})$ is an Ising model prior, and $p(x_i|z_i) = \mathcal{N}(x_i|z_i, \sigma^2)$, for $z_i \in \{-1, +1\}$. This model uses a UPGM as a prior, and has directed edges for the likelihood, as shown in Figure 4.26; such a hybrid undirected-directed model is called a chain graph (even though it is not chain-structured).

Figure 4.27: Example of image denoising using mean field variational inference. We use an Ising prior with $W_{ij} = 1$ and a Gaussian noise model with $\sigma = 2$. (a) Noisy image. (b) Result of inference (mean after 15 sweeps of mean field). Generated by ising_image_denoise_demo.ipynb.

The inference task is to compute the posterior marginals $p(z_i|\mathbf{x})$, or the posterior MAP estimate, $\arg\max_{\mathbf{z}} p(\mathbf{z}|\mathbf{x})$. The exact computation is intractable for large grids (for reasons explained in Section 9.4.4), so we must use approximate methods. There are many algorithms that we can use, including mean field variational inference (Section 10.2.2), Gibbs sampling (Section 12.3.3), loopy belief propagation (Section 9.3), etc. In Figure 4.27, we show the results of variational inference.
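A minimal NumPy sketch of mean field for this model follows. It assumes coordinate-ascent (sequential) updates, pairwise potentials $\psi_{ij}(z_i, z_j) = \exp(W z_i z_j)$ with $W = 1$, and $\sigma = 2$ as in Figure 4.27. Each pixel's mean is $\mu_i = \tanh(W\sum_{j \in \mathrm{nbr}(i)} \mu_j + x_i/\sigma^2)$, where $x_i/\sigma^2$ is half the log-likelihood ratio between $z_i = +1$ and $z_i = -1$. The synthetic square image is hypothetical, and this is not the code of the notebook named in the figure caption.

```python
import numpy as np

def neighbor_sum(mu, i, j):
    """Sum of the 4-connected neighbors of pixel (i, j), with free boundaries."""
    H, W = mu.shape
    total = 0.0
    if i > 0:
        total += mu[i - 1, j]
    if i < H - 1:
        total += mu[i + 1, j]
    if j > 0:
        total += mu[i, j - 1]
    if j < W - 1:
        total += mu[i, j + 1]
    return total

def mean_field_denoise(x, coupling, sigma, n_sweeps=15):
    """Coordinate-ascent mean field for an Ising prior exp(coupling * sum_{i~j} z_i z_j)
    and likelihood N(x_i | z_i, sigma^2); returns mu_i = E_q[z_i]."""
    evidence = x / sigma**2  # 0.5 * (log N(x|+1, s^2) - log N(x|-1, s^2))
    mu = np.tanh(evidence)   # start from the local evidence alone
    H, W = x.shape
    for _ in range(n_sweeps):
        for i in range(H):
            for j in range(W):
                mu[i, j] = np.tanh(coupling * neighbor_sum(mu, i, j) + evidence[i, j])
    return mu

rng = np.random.default_rng(0)
z_true = -np.ones((20, 20))
z_true[5:15, 5:15] = 1.0  # hypothetical clean image: a bright square on a dark background
sigma = 2.0
x = z_true + sigma * rng.normal(size=z_true.shape)
mu = mean_field_denoise(x, coupling=1.0, sigma=sigma)
err_noisy = np.mean(np.sign(x) != z_true)
err_mf = np.mean(np.sign(mu) != z_true)
print(err_noisy, err_mf)
```

Sequential updates never decrease the mean field objective. Parallel (synchronous) updates can oscillate and are often damped in practice.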


4.3.9 Learning

In this section, we discuss how to estimate the parameters for an MRF. As we will see, computing the MLE can be computationally expensive, even in the fully observed case, because of the need to deal with the partition function $Z(\boldsymbol{\theta})$. And computing the posterior over the parameters, $p(\boldsymbol{\theta}|\mathcal{D})$, is even harder, because of the additional normalizing constant $p(\mathcal{D})$; this case has been called doubly intractable [MGM06]. Consequently we will focus on point estimation methods such as MLE and MAP. (For one approach to Bayesian parameter inference in an MRF, based on persistent variational inference, see [IM17].)


4.3.9.1 Learning from complete data 


We will start by assuming there are no hidden variables or missing data during training (this is known as the complete data setting). For simplicity of presentation, we restrict our discussion to the case of MRFs with log-linear potential functions. (See Section 24.2 for the general nonlinear case, where we discuss MLE for energy-based models.)
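In particular, we assume the distribution has the following form:

$$p(\mathbf{x}|\boldsymbol{\theta}) = \frac{1}{Z(\boldsymbol{\theta})} \exp\left(\sum_c \boldsymbol{\theta}_c^\top \boldsymbol{\phi}_c(\mathbf{x})\right) \tag{4.118}$$

where $c$ indexes the cliques. The (averaged) log-likelihood of the full dataset becomes

$$\ell(\boldsymbol{\theta}) \triangleq \frac{1}{N} \sum_n \log p(\mathbf{x}_n|\boldsymbol{\theta}) = \frac{1}{N} \sum_n \left[\sum_c \boldsymbol{\theta}_c^\top \boldsymbol{\phi}_c(\mathbf{x}_n)\right] - \log Z(\boldsymbol{\theta}) \tag{4.119}$$

Its gradient is given by

$$\frac{\partial \ell}{\partial \boldsymbol{\theta}_c} = \frac{1}{N} \sum_n \boldsymbol{\phi}_c(\mathbf{x}_n) - \frac{\partial}{\partial \boldsymbol{\theta}_c} \log Z(\boldsymbol{\theta}) \tag{4.120}$$

We know from Section 2.3.3 that the derivative of the log partition function wrt $\boldsymbol{\theta}_c$ is the expectation of the $c$'th feature vector under the model, i.e.,

$$\frac{\partial \log Z(\boldsymbol{\theta})}{\partial \boldsymbol{\theta}_c} = \mathbb{E}[\boldsymbol{\phi}_c(\mathbf{x})|\boldsymbol{\theta}] = \sum_{\mathbf{x}} p(\mathbf{x}|\boldsymbol{\theta}) \boldsymbol{\phi}_c(\mathbf{x}) \tag{4.121}$$

Hence the gradient of the log likelihood is

$$\frac{\partial \ell}{\partial \boldsymbol{\theta}_c} = \frac{1}{N} \sum_n \boldsymbol{\phi}_c(\mathbf{x}_n) - \mathbb{E}[\boldsymbol{\phi}_c(\mathbf{x})] \tag{4.122}$$

When the expected value of the features according to the data is equal to the expected value of the features according to the model, the gradient will be zero, so we get

$$\mathbb{E}_{p_{\mathcal{D}}}[\boldsymbol{\phi}_c(\mathbf{x})] = \mathbb{E}_{p(\mathbf{x}|\boldsymbol{\theta})}[\boldsymbol{\phi}_c(\mathbf{x})] \tag{4.123}$$

This is called moment matching. Evaluating the $\mathbb{E}_{p_{\mathcal{D}}}[\boldsymbol{\phi}_c(\mathbf{x})]$ term is called the clamped phase or positive phase, since $\mathbf{x}$ is set to the observed values $\mathbf{x}_n$; evaluating the $\mathbb{E}_{p(\mathbf{x}|\boldsymbol{\theta})}[\boldsymbol{\phi}_c(\mathbf{x})]$ term is called the unclamped phase or negative phase, since $\mathbf{x}$ is free to vary, and is generated by the model.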
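Moment matching can be seen directly on a toy model. The sketch assumes three $\pm 1$ variables with unary and pairwise features. It uses a made-up dataset in which every configuration appears at least once, so the MLE exists. It computes the negative-phase expectation exactly by enumeration, which is only feasible for tiny models, and runs plain gradient ascent using Equation (4.122).

```python
import itertools
import numpy as np

STATES = np.array(list(itertools.product([-1.0, 1.0], repeat=3)))  # all 2^3 configurations

def features(x):
    """Hypothetical phi(x): unary terms x_i and pairwise terms x_i x_j (fully connected)."""
    return np.concatenate([x, [x[0] * x[1], x[0] * x[2], x[1] * x[2]]])

PHI = np.array([features(s) for s in STATES])  # shape (8, 6)

def model_moments(theta):
    """Unclamped (negative) phase: E_{p(x|theta)}[phi(x)], by exact enumeration."""
    logits = PHI @ theta
    p = np.exp(logits - logits.max())
    return (p / p.sum()) @ PHI

# Made-up dataset, given as the count of each of the 8 configurations (all nonzero).
counts = np.array([5, 1, 2, 3, 4, 2, 1, 6])
empirical = counts @ PHI / counts.sum()  # clamped (positive) phase

theta = np.zeros(PHI.shape[1])
for _ in range(5000):
    theta += 0.25 * (empirical - model_moments(theta))  # gradient ascent on Equation (4.122)
print(np.max(np.abs(empirical - model_moments(theta))))
```

At convergence, the two phases agree, which is exactly the moment-matching condition. In realistic models, the negative phase must be estimated by sampling instead.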

In the case of MRFs with tabular potentials (i.e., one feature per entry in the clique table), we can use an algorithm called iterative proportional fitting or IPF [Fie70; BFH75; JP95] to solve these equations in an iterative fashion.⁹ But in general, we must use gradient methods to perform parameter estimation.


9. In the case of decomposable graphs, IPF converges in a single iteration. Intuitively, this is because a decomposable 
graph can be converted to a DAG without any loss of information, as explained in Section 4.5, and we know that we 
can compute the MLE for tabular CPDs in closed form, just by normalizing the counts. 
Figure 4.28: (a) A small 2d lattice. (b) The representation used by pseudo likelihood. Solid nodes are observed
neighbors. Adapted from Figure 2.2 of [Car03]. 


4.3.9.2 Computational issues 


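The biggest computational bottleneck in fitting MRFs and CRFs using MLE is the cost of computing the derivative of the log partition function, $\log Z(\boldsymbol{\theta})$, which is needed to compute the derivative of the log likelihood, as we saw in Section 4.3.9.1. To see why this is slow to compute, let $\tilde{p}(\mathbf{x}; \boldsymbol{\theta}) = Z(\boldsymbol{\theta}) p(\mathbf{x}; \boldsymbol{\theta})$ denote the unnormalized model, and note that

$$\nabla_{\boldsymbol{\theta}} \log Z(\boldsymbol{\theta}) = \frac{1}{Z(\boldsymbol{\theta})} \nabla_{\boldsymbol{\theta}} \int \tilde{p}(\mathbf{x}; \boldsymbol{\theta}) d\mathbf{x} = \frac{1}{Z(\boldsymbol{\theta})} \int \nabla_{\boldsymbol{\theta}} \tilde{p}(\mathbf{x}; \boldsymbol{\theta}) d\mathbf{x} \tag{4.124}$$
$$= \int \frac{\tilde{p}(\mathbf{x}; \boldsymbol{\theta})}{Z(\boldsymbol{\theta})} \nabla_{\boldsymbol{\theta}} \log \tilde{p}(\mathbf{x}; \boldsymbol{\theta}) d\mathbf{x} \tag{4.125}$$
$$= \mathbb{E}_{\mathbf{x} \sim p(\mathbf{x}; \boldsymbol{\theta})}\left[\nabla_{\boldsymbol{\theta}} \log \tilde{p}(\mathbf{x}; \boldsymbol{\theta})\right] \tag{4.126}$$

where in Equation (4.125) we used the fact that $\nabla_{\boldsymbol{\theta}} \log \tilde{p}(\mathbf{x}; \boldsymbol{\theta}) = \frac{1}{\tilde{p}(\mathbf{x}; \boldsymbol{\theta})} \nabla_{\boldsymbol{\theta}} \tilde{p}(\mathbf{x}; \boldsymbol{\theta})$ (this is known as the log-derivative trick). Thus we see that we need to draw samples from the model at each step of SGD training, just to estimate the gradient.

In Section 24.2.1, we discuss various efficient sampling methods. However, it is also possible to use alternative estimators which do not use the principle of maximum likelihood. For example, in Section 24.2.2 we discuss the technique of contrastive divergence. And in Section 4.3.9.3, we discuss the technique of pseudo likelihood. (See also [Sto17] for a review of many methods for parameter estimation in MRFs.)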
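Here is a quick numerical check of Equations (4.124) to (4.126) on a model small enough to enumerate. For a log-linear model, $\nabla_{\boldsymbol{\theta}} \log \tilde{p}(\mathbf{x}; \boldsymbol{\theta}) = \boldsymbol{\phi}(\mathbf{x})$, so the gradient of $\log Z$ is $\mathbb{E}[\boldsymbol{\phi}(\mathbf{x})]$. The sketch compares that expectation with central finite differences and with a Monte Carlo average over model samples. The features are hypothetical. The samples are drawn exactly by enumeration, standing in for the MCMC samplers a realistic model would require.

```python
import itertools
import numpy as np

rng = np.random.default_rng(0)
STATES = np.array(list(itertools.product([0.0, 1.0], repeat=4)))  # all 16 configurations
# Hypothetical features: the 4 bits plus two pairwise products.
PHI = np.column_stack([STATES, STATES[:, 0] * STATES[:, 1], STATES[:, 2] * STATES[:, 3]])

def log_Z(theta):
    return np.logaddexp.reduce(PHI @ theta)

def model_probs(theta):
    return np.exp(PHI @ theta - log_Z(theta))

theta = rng.normal(size=PHI.shape[1])

# Equation (4.126) for a log-linear model: E_{p(x; theta)}[phi(x)].
exact = model_probs(theta) @ PHI

# Central finite differences of log Z, one parameter at a time.
eps = 1e-5
fd = np.array([(log_Z(theta + eps * e) - log_Z(theta - eps * e)) / (2 * eps)
               for e in np.eye(len(theta))])

# Monte Carlo estimate from model samples.
idx = rng.choice(len(STATES), size=20000, p=model_probs(theta))
mc = PHI[idx].mean(axis=0)
print(exact, fd, mc, sep="\n")
```

The Monte Carlo estimate is what stochastic ML methods rely on. Its accuracy is limited by the number of samples and, with MCMC, by how well the chain has mixed.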


4.3.9.3 Maximum pseudo-likelihood estimation 


When fitting fully visible MRFs (or CRFs), a simple alternative to maximizing the likelihood is to maximize the pseudo likelihood [Bes75], defined as follows:

$$\ell_{PL}(\boldsymbol{\theta}) \triangleq \frac{1}{N} \sum_{n=1}^{N} \sum_{d=1}^{D} \log p(x_{nd}|\mathbf{x}_{n,-d}, \boldsymbol{\theta}) \tag{4.127}$$

That is, we optimize the product of the full conditionals, also known as the composite likelihood [Lin88a; DL10; VRF11]. Compare this to the objective for maximum likelihood:

$$\ell_{ML}(\boldsymbol{\theta}) = \frac{1}{N} \sum_{n=1}^{N} \log p(\mathbf{x}_n|\boldsymbol{\theta}) \tag{4.128}$$
4.3. UNDIRECTED GRAPHICAL MODELS (MARKOV RANDOM FIELDS)


In the case of Gaussian MRFs, PL is equivalent to ML [Bes75], although this is not true in general. 
Nevertheless, it is a consistent estimator in the large sample limit [LJ08]. 

The PL approach is illustrated in Figure 4.28 for a 2d grid. We learn to predict each node, given all of its neighbors. This objective is generally fast to compute since each full conditional $p(x_d|\mathbf{x}_{-d}, \boldsymbol{\theta})$ only requires summing over the states of a single node, $x_d$, in order to compute the local normalization constant. The PL approach is similar to fitting each full conditional separately, except that, in PL, the parameters are tied between adjacent nodes.

Experiments in [PW05; HT09] suggest that PL works as well as exact ML for fully observed Ising
models, but is much faster. In [Eke+13], they use PL to fit Potts models to (aligned) protein
sequence data. However, when fitting RBMs, [Mar+10] found that PL is worse than some of the 
stochastic ML methods we discuss in Section 24.2. 

Another more subtle problem is that each node assumes that its neighbors have known values during training. If node $j \in \mathrm{nbr}(i)$ is a perfect predictor for node $i$ (where $\mathrm{nbr}(i)$ is the set of neighbors), then node $i$ will learn to rely completely on node $j$, even at the expense of ignoring other potentially useful information, such as its local evidence, say $y_i$. At test time, the neighboring nodes will not be observed, and performance will suffer.¹⁰
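For a fully visible Ising model, each full conditional has a closed form. We write $\tilde{p}(\mathbf{x}) = \exp(\mathbf{h}^\top\mathbf{x} + \frac{1}{2}\mathbf{x}^\top \mathbf{J}\mathbf{x})$ with symmetric, zero-diagonal $\mathbf{J}$ (our notation). Flipping $x_d$ changes the log-potential by $2 x_d (h_d + \sum_j J_{dj} x_j)$, so $p(x_d|\mathbf{x}_{-d}) = \sigma(2 x_d (h_d + \sum_j J_{dj} x_j))$. The sketch evaluates Equation (4.127) this way, with no global partition function, and checks each term against a direct computation from the joint.

```python
import numpy as np

def log_unnorm(x, h, J):
    """log p~(x) = h^T x + 0.5 x^T J x for spins x in {-1, +1}^D."""
    return h @ x + 0.5 * x @ J @ x

def log_pseudo_likelihood(X, h, J):
    """Equation (4.127): (1/N) sum_n sum_d log p(x_nd | x_n,-d), with X of shape (N, D)."""
    field = h + X @ J  # local field h_d + sum_j J_dj x_j (J symmetric, zero diagonal)
    # log sigmoid(2 x_d field_d) = -log(1 + exp(-2 x_d field_d)); only 2 states per node
    return -np.sum(np.logaddexp(0.0, -2.0 * X * field)) / X.shape[0]

def conditional_from_joint(x, d, h, J):
    """p(x_d | x_-d) computed directly from the unnormalized joint, for checking."""
    x_flip = x.copy()
    x_flip[d] = -x[d]
    a, b = log_unnorm(x, h, J), log_unnorm(x_flip, h, J)
    return np.exp(a - np.logaddexp(a, b))

rng = np.random.default_rng(0)
D, N = 4, 5
A = np.triu(rng.normal(size=(D, D)), 1)
J = A + A.T  # symmetric coupling matrix with zero diagonal
h = rng.normal(size=D)
X = rng.choice([-1.0, 1.0], size=(N, D))
pl = log_pseudo_likelihood(X, h, J)
check = np.mean([sum(np.log(conditional_from_joint(x, d, h, J)) for d in range(D)) for x in X])
print(pl, check)
```

The cost is $O(ND^2)$ here and linear in the number of edges for a sparse graph, compared with $O(2^D)$ for the exact likelihood.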


4.3.9.4 Learning from incomplete data 


In this section, we consider parameter estimation for MRFs (and CRFs) with hidden variables. Such incomplete data can arise for several reasons. For example, we may want to learn a model of the form $p(\mathbf{z})p(\mathbf{x}|\mathbf{z})$ which lets us infer a “clean” image $\mathbf{z}$ from a noisy or corrupted version $\mathbf{x}$. If we only observe $\mathbf{x}$, the model is called a hidden Gibbs random field. See Section 10.2.2 for an example. As another example, we may have a CRF in which the hidden variables are used to encode an unknown alignment between the inputs and outputs [Qua+07], or to model missing parts of the input [SRS10].

We now discuss how to compute the MLE in such cases. For notational simplicity, we focus on unconditional models (MRFs, not CRFs), and we assume all the potentials are log-linear. In this case, the model has the following form:

$$p(\mathbf{x}, \mathbf{z}|\boldsymbol{\theta}) = \frac{\exp(\boldsymbol{\theta}^\top \boldsymbol{\phi}(\mathbf{x}, \mathbf{z}))}{Z(\boldsymbol{\theta})} = \frac{\tilde{p}(\mathbf{x}, \mathbf{z}|\boldsymbol{\theta})}{Z(\boldsymbol{\theta})} \tag{4.129}$$
$$Z(\boldsymbol{\theta}) = \sum_{\mathbf{x}, \mathbf{z}} \exp(\boldsymbol{\theta}^\top \boldsymbol{\phi}(\mathbf{x}, \mathbf{z})) \tag{4.130}$$

where $\tilde{p}(\mathbf{x}, \mathbf{z}|\boldsymbol{\theta})$ is the unnormalized distribution. We have dropped the sum over cliques $c$ for brevity.


10. Geoff Hinton has an analogy for this problem. Suppose we want to learn to denoise images of symmetric shapes, such as Greek vases. Each hidden pixel $x_i$ depends on its spatial neighbors, as well as the noisy observation $y_i$. Since its symmetric counterpart $x_j$ will perfectly predict $x_i$, the model will ignore $y_i$ and just rely on $x_j$, even though $x_j$ will not be available at test time.
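The log likelihood is now given by

$$\ell(\boldsymbol{\theta}) = \frac{1}{N} \sum_{n=1}^{N} \log\left(\sum_{\mathbf{z}_n} p(\mathbf{x}_n, \mathbf{z}_n|\boldsymbol{\theta})\right) \tag{4.131}$$
$$= \frac{1}{N} \sum_{n=1}^{N} \log\left(\frac{1}{Z(\boldsymbol{\theta})} \sum_{\mathbf{z}_n} \tilde{p}(\mathbf{x}_n, \mathbf{z}_n|\boldsymbol{\theta})\right) \tag{4.132}$$
$$= \frac{1}{N} \sum_{n=1}^{N} \left[\log \sum_{\mathbf{z}_n} \tilde{p}(\mathbf{x}_n, \mathbf{z}_n|\boldsymbol{\theta})\right] - \log Z(\boldsymbol{\theta}) \tag{4.133}$$

Note that

$$\log \sum_{\mathbf{z}_n} \tilde{p}(\mathbf{x}_n, \mathbf{z}_n|\boldsymbol{\theta}) = \log \sum_{\mathbf{z}_n} \exp(\boldsymbol{\theta}^\top \boldsymbol{\phi}(\mathbf{x}_n, \mathbf{z}_n)) \triangleq \log Z(\boldsymbol{\theta}, \mathbf{x}_n) \tag{4.134}$$

where $Z(\boldsymbol{\theta}, \mathbf{x}_n)$ is the same as the partition function for the whole model, except that $\mathbf{x}$ is fixed at $\mathbf{x}_n$. Thus the log likelihood is a difference of two partition functions, one where $\mathbf{x}$ is clamped to $\mathbf{x}_n$ and $\mathbf{z}$ is unclamped, and one where both $\mathbf{x}$ and $\mathbf{z}$ are unclamped. The gradient of these log partition functions corresponds to the expected features, where (in the clamped case) we condition on $\mathbf{x} = \mathbf{x}_n$. Hence

$$\frac{\partial \ell}{\partial \boldsymbol{\theta}} = \frac{1}{N} \sum_{n=1}^{N} \mathbb{E}_{\mathbf{z} \sim p(\mathbf{z}|\mathbf{x}_n, \boldsymbol{\theta})}\left[\boldsymbol{\phi}(\mathbf{x}_n, \mathbf{z})\right] - \mathbb{E}_{(\mathbf{x}, \mathbf{z}) \sim p(\mathbf{x}, \mathbf{z}|\boldsymbol{\theta})}\left[\boldsymbol{\phi}(\mathbf{x}, \mathbf{z})\right] \tag{4.135}$$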
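The clamped-minus-unclamped form of the gradient can be verified on a model small enough to enumerate. The sketch assumes two binary visible and two binary hidden units, hypothetical features (biases plus visible-hidden products), and a made-up dataset. It computes Equation (4.135) by exact enumeration and compares it with finite differences of the log likelihood in Equation (4.133).

```python
import itertools
import numpy as np

X_STATES = [np.array(s) for s in itertools.product([0.0, 1.0], repeat=2)]
Z_STATES = [np.array(s) for s in itertools.product([0.0, 1.0], repeat=2)]

def phi(x, z):
    """Hypothetical features: visible bits, hidden bits, and all visible-hidden products."""
    return np.concatenate([x, z, np.outer(x, z).ravel()])

def pairs(x=None):
    """(x, z) configurations to sum over: x clamped if given, otherwise free."""
    xs = X_STATES if x is None else [x]
    return [(xx, z) for xx in xs for z in Z_STATES]

def log_Z(theta, x=None):
    """log Z(theta) if x is None, else log Z(theta, x) as in Equation (4.134)."""
    return np.logaddexp.reduce([theta @ phi(xx, z) for xx, z in pairs(x)])

def expected_phi(theta, x=None):
    """E[phi] under p(z | x, theta) if x is given (clamped), else under p(x, z | theta)."""
    ps = pairs(x)
    logw = np.array([theta @ phi(xx, z) for xx, z in ps])
    w = np.exp(logw - np.logaddexp.reduce(logw))
    return sum(wi * phi(xx, z) for wi, (xx, z) in zip(w, ps))

def avg_log_lik(theta, data):
    return np.mean([log_Z(theta, x) for x in data]) - log_Z(theta)  # Equation (4.133)

def grad(theta, data):
    clamped = np.mean([expected_phi(theta, x) for x in data], axis=0)
    return clamped - expected_phi(theta)  # Equation (4.135)

rng = np.random.default_rng(0)
theta = rng.normal(size=8)
data = [np.array([1.0, 0.0]), np.array([1.0, 1.0]), np.array([0.0, 1.0])]
eps = 1e-5
fd = np.array([(avg_log_lik(theta + eps * e, data) - avg_log_lik(theta - eps * e, data)) / (2 * eps)
               for e in np.eye(8)])
print(grad(theta, data), fd, sep="\n")
```

In realistic models, both expectations are intractable. The clamped term needs posterior inference over $\mathbf{z}$ for each example, and the unclamped term needs samples from the joint.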


4.4 Conditional random fields (CRFs) 


A conditional random field or CRF [LMP01] is a Markov random field defined on a set of related label nodes $\mathbf{y}$, whose joint probability is predicted conditional on a fixed set of input nodes $\mathbf{x}$. More precisely, it corresponds to a model of the following form:

$$p(\mathbf{y}|\mathbf{x}, \boldsymbol{\theta}) = \frac{1}{Z(\mathbf{x}, \boldsymbol{\theta})} \prod_c \psi_c(\mathbf{y}_c; \mathbf{x}, \boldsymbol{\theta}) \tag{4.136}$$

(Note how the partition function now depends on the inputs $\mathbf{x}$ as well as the parameters $\boldsymbol{\theta}$.) Now suppose the potential functions are log-linear and have the form

$$\psi_c(\mathbf{y}_c; \mathbf{x}, \boldsymbol{\theta}) = \exp\left(\boldsymbol{\theta}_c^\top \boldsymbol{\phi}_c(\mathbf{x}, \mathbf{y}_c)\right) \tag{4.137}$$

This is a conditional version of the maxent models we discussed in Section 4.3.4. Of course, we can also use nonlinear potential functions, such as DNNs.


CRFs are useful because they capture dependencies amongst the output labels. They can therefore be used for structured prediction, where the output $\mathbf{y} \in \mathcal{Y}$ that we want to predict given the input $\mathbf{x}$ lives in some structured space, such as a sequence of labels, or labels associated with nodes on a graph. In such problems, there are often constraints on the set of valid values of the output $\mathbf{y}$. For example, if we want to perform sentence parsing, the output should satisfy the rules of grammar
4.4. CONDITIONAL RANDOM FIELDS (CRFS)


Figure 4.29: A 1d conditional random field (CRF) for sequence labeling.


(e.g., noun phrase must precede verb phrase). See Section 4.4.1 for details on the application of CRFs 
to NLP. In some cases, the “constraints” are “soft”, rather than “hard”. For example, if we want to 
associate a label with each pixel in an image (a task called semantic segmentation), we might want 
to “encourage” the label at one location to be the same as its neighbors, unless the visual input 
strongly suggests a change in semantic content at this location (e.g., at the edge of an object). See 
Section 4.4.2 for details on the applications of CRFs to computer vision tasks. 


4.4.1 1d CRFs


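In this section, we focus on 1d CRFs defined on chain-structured graphical models. The graphical model is shown in Figure 4.29. This defines a joint distribution over sequences, $\mathbf{y}_{1:T}$, given a set of inputs, $\mathbf{x}_{1:T}$, as follows:

$$p(\mathbf{y}_{1:T}|\mathbf{x}, \boldsymbol{\theta}) = \frac{1}{Z(\mathbf{x}, \boldsymbol{\theta})} \prod_{t=1}^{T} \psi(y_t, \mathbf{x}_t; \boldsymbol{\theta}) \prod_{t=2}^{T} \psi(y_{t-1}, y_t; \boldsymbol{\theta}) \tag{4.138}$$

where $\psi(y_t, \mathbf{x}_t; \boldsymbol{\theta})$ are the node potentials and $\psi(y_{t-1}, y_t; \boldsymbol{\theta})$ are the edge potentials. (We have assumed that the edge potentials are independent of the input $\mathbf{x}$, but this assumption is not required.)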
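The global normalizer $Z(\mathbf{x}, \boldsymbol{\theta})$ of a chain CRF does not require summing over all $K^T$ label sequences: the forward algorithm computes it in $O(TK^2)$ time. Below is a minimal sketch working with log-potentials. It assumes $K$ labels and time-invariant edge potentials as in Equation (4.138), and uses random numbers in place of the $\mathbf{x}$-dependent node scores. A brute-force sum checks the result.

```python
import itertools
import numpy as np

def crf_log_Z(node_logpot, edge_logpot):
    """log Z(x, theta) for a chain CRF via the forward algorithm, O(T K^2).

    node_logpot[t, k] = log psi(y_t = k, x_t; theta), shape (T, K)
    edge_logpot[j, k] = log psi(y_{t-1} = j, y_t = k; theta), shape (K, K)
    """
    log_alpha = node_logpot[0]
    for t in range(1, node_logpot.shape[0]):
        # log_alpha_t[k] = node[t, k] + logsumexp_j(log_alpha_{t-1}[j] + edge[j, k])
        log_alpha = node_logpot[t] + np.logaddexp.reduce(log_alpha[:, None] + edge_logpot, axis=0)
    return np.logaddexp.reduce(log_alpha)

def crf_log_score(y, node_logpot, edge_logpot):
    """Unnormalized log-probability of one label sequence y (a tuple of ints)."""
    node_part = sum(node_logpot[t, k] for t, k in enumerate(y))
    edge_part = sum(edge_logpot[a, b] for a, b in zip(y[:-1], y[1:]))
    return node_part + edge_part

def crf_log_Z_brute(node_logpot, edge_logpot):
    """Same quantity by summing over all K^T label sequences (tiny problems only)."""
    T, K = node_logpot.shape
    scores = [crf_log_score(y, node_logpot, edge_logpot)
              for y in itertools.product(range(K), repeat=T)]
    return np.logaddexp.reduce(scores)

rng = np.random.default_rng(0)
T, K = 5, 3
node = rng.normal(size=(T, K))  # stands in for scores computed from x_t and theta
edge = rng.normal(size=(K, K))
print(crf_log_Z(node, edge), crf_log_Z_brute(node, edge))
```

The log probability of a particular labeling is then `crf_log_score(y, node, edge) - crf_log_Z(node, edge)`. This global normalization is what distinguishes the CRF from the locally normalized model discussed next.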

Note that one could also consider an alternative way to define this conditional distribution, by using a discriminative directed Markov chain:

$$p(\mathbf{y}_{1:T}|\mathbf{x}, \boldsymbol{\theta}) = p(y_1|\mathbf{x}_1; \boldsymbol{\theta}) \prod_{t=2}^{T} p(y_t|y_{t-1}, \mathbf{x}_t; \boldsymbol{\theta}) \tag{4.139}$$

This is called a maximum entropy Markov model [MFP00]. However, it suffers from a subtle flaw compared to the CRF. In particular, in the directed model, each conditional $p(y_t|y_{t-1}, \mathbf{x}_t; \boldsymbol{\theta})$ is locally normalized, whereas in the CRF, the model is globally normalized due to the $Z(\mathbf{x}, \boldsymbol{\theta})$ term. The latter allows information to propagate through the entire sequence, as we discuss in more detail in Section 4.5.3.

CRFs were widely used in the natural language processing (NLP) community in the 2000s and 2010s (see e.g., [Smi11]), although recently they have been mostly replaced by RNNs and transformers (see e.g., [Gol17]). Fortunately, we can get the best of both worlds by combining CRFs with DNNs, which
Example sentence: British Airways rose after announcing its withdrawal from the UAL deal

KEY
B    Begin noun phrase        V    Verb
I    Within noun phrase       IN   Preposition
O    Not a noun phrase        PRP  Possessive pronoun
N    Noun                     DT   Determiner (e.g., a, an, the)
ADJ  Adjective

Figure 4.30: A CRF for joint part of speech tagging and noun phrase segmentation. From Figure 4.E.1 of [KF09a]. Used with kind permission of Daphne Koller.


allows us to combine data driven techniques with prior knowledge about constraints on the label 
space. We give some examples below. 


4.4.1.1 Noun phrase chunking 


A common task in NLP is information extraction, in which we try to parse a sentence into noun 
phrases (NP), such as names and addresses of people or businesses, as well as verb phrases, which 
describe who is doing what to whom (e.g., “British Airways rose”). In order to tackle this task, we 
can assign a part of speech tag to each word, where the tags correspond to Noun, Verb, Adjective, 
etc. In addition, to extract the span of each noun phrase, we can annotate words as being at the 
beginning (B) or inside (I) a noun phrase, or outside (O) of one. See Figure 4.30 for an example. 
The connections between adjacent labels can encode constraints such as the fact that B (begin)
must precede I (inside). For example, the sequences OBIIO and OBIOBIO are valid (corresponding to
one NP of 3 words, and two NPs of 2 words each), but OIBIO is not. This prior information can
be encoded by defining $\psi(y^{BIO}_{t-1} = O, y^{BIO}_t = I, x_t; \theta)$ to be 0, so that an I tag can only follow
a B or another I. We can encode similar grammatical rules for the POS tags.
Given this model, we can compute the MAP sequence of labels, and thereby extract the spans
that are labeled as noun phrases. This is called noun phrase chunking.
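As an illustration, here is a minimal plain-Python sketch (not code from the book's notebooks) of noun phrase chunking as MAP decoding: the Viterbi algorithm over the three BIO tags, with the hard constraint encoded as a log-potential of $-\infty$ for the transition O → I (and for I at the start of the sentence), followed by span extraction. The per-word scores in `unary` are hypothetical stand-ins for the outputs of a neural network, and the function names are our own.

```python
NEG_INF = float("-inf")
TAGS = ["B", "I", "O"]

def transition_logpot(prev, cur):
    """Hard BIO constraint: an I tag may only follow B or I."""
    return NEG_INF if (prev == "O" and cur == "I") else 0.0

def viterbi(unary):
    """MAP tag sequence of a chain CRF; unary[t][tag] is a local log-potential."""
    # An I tag is not allowed at the start of the sentence either.
    delta = [{y: unary[0][y] + (NEG_INF if y == "I" else 0.0) for y in TAGS}]
    back = []
    for t in range(1, len(unary)):
        scores, ptr = {}, {}
        for cur in TAGS:
            best_prev = max(TAGS, key=lambda p: delta[-1][p] + transition_logpot(p, cur))
            scores[cur] = delta[-1][best_prev] + transition_logpot(best_prev, cur) + unary[t][cur]
            ptr[cur] = best_prev
        delta.append(scores)
        back.append(ptr)
    # Backtrack from the best final tag.
    path = [max(TAGS, key=lambda y: delta[-1][y])]
    for ptr in reversed(back):
        path.append(ptr[path[-1]])
    return path[::-1]

def extract_spans(tags):
    """(start, end) pairs, end exclusive, of the B I* noun phrase chunks."""
    spans, start = [], None
    for t, tag in enumerate(list(tags) + ["O"]):  # sentinel O closes a final chunk
        if tag in ("B", "O") and start is not None:
            spans.append((start, t))
            start = None
        if tag == "B":
            start = t
    return spans

unary = [
    {"B": 1.0, "I": 0.0, "O": 1.2},
    {"B": 0.0, "I": 2.0, "O": 0.5},
    {"B": 0.0, "I": 1.5, "O": 1.0},
    {"B": 0.0, "I": 0.0, "O": 2.0},
]
tags = viterbi(unary)
print(tags, extract_spans(tags))  # ['B', 'I', 'I', 'O'] [(0, 3)]
```

With these made-up scores the locally best tags would be O I I O, which violates the constraint; the constrained decoder returns B I I O instead, i.e., one three-word noun phrase.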


4.4.1.2 Named entity recognition 


In this section we consider the task of named entity extraction, in which we not only tag the
noun phrases, but also classify them into different types. A simple approach to this is to extend
the BIO notation to {B-Per, I-Per, B-Loc, I-Loc, B-Org, I-Org, Other}. However, sometimes it is
ambiguous whether a word is a person, location, or something else. Proper nouns are particularly
difficult to deal with because they belong to an open class, that is, there is an unbounded number
of possible names, unlike the set of nouns and verbs, which is large but essentially fixed. For example,
“British Airways” is an organization, but “British Virgin Islands” is a location.
4.4. CONDITIONAL RANDOM FIELDS (CRFS)

[Figure: a skip-chain CRF over "Mrs. Green spoke today in New York Green chairs the finance committee", with labels such as I-PER, B-LOC, I-LOC, B-PER and OTH. Key: B-PER = begin person name; I-PER = within person name; B-LOC = begin location name; I-LOC = within location name; OTH = not an entity.]

Figure 4.31: A skip-chain CRF for named entity recognition. From Figure 4.E.1 of [KF09a]. Used with kind
permission of Daphne Koller.


We can get better performance by considering long-range correlations between words. For example,
we might add a link between all occurrences of the same word, and encourage the word to receive the same
entity type in each occurrence. (The same technique can also be helpful for resolving the identity of pronouns.)
This is known as a skip-chain CRF. See Figure 4.31 for an illustration, where we show that the
word “Green” is interpreted as a person in both occurrences within the same sentence.

We see that the graph structure itself changes depending on the input, which is an additional
advantage of CRFs over generative models. Unfortunately, inference in this model is generally
more expensive than in a simple chain with local connections because of the larger treewidth (see
Section 9.4.2).
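The following plain-Python sketch shows how the graph of a skip-chain CRF is built from the input itself. The helper name `skip_chain_edges` is hypothetical, and linking only repeated capitalized tokens is an illustrative heuristic of ours, not a rule from the text.

```python
def skip_chain_edges(tokens):
    """Hypothetical helper: edges of a skip-chain CRF over label nodes 0..T-1.

    Chain edges join neighbouring positions. Skip edges join every pair of
    positions holding the same capitalized token (an illustrative heuristic),
    so the graph structure depends on the input sentence x.
    """
    chain = [(t, t + 1) for t in range(len(tokens) - 1)]
    positions = {}
    for t, word in enumerate(tokens):
        if word[:1].isupper():
            positions.setdefault(word, []).append(t)
    skips = [(ts[a], ts[b]) for ts in positions.values()
             for a in range(len(ts)) for b in range(a + 1, len(ts))]
    return chain, skips

tokens = "Mrs. Green spoke today in New York Green chairs the finance committee".split()
chain, skips = skip_chain_edges(tokens)
print(skips)  # [(1, 7)]: the two occurrences of "Green" are linked
```

A different sentence yields a different set of skip edges, which is exactly why the treewidth, and hence the cost of exact inference, can vary from input to input.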


4.4.1.3 Natural language parsing 


A generalization of chain-structured models for language is to use probabilistic grammars. In
particular, a probabilistic context free grammar or PCFG in Chomsky normal form is a set of
re-write or production rules of the form $\sigma \to \sigma' \sigma''$ or $\sigma \to x$, where $\sigma, \sigma', \sigma'' \in \Sigma$ are non-terminals
(analogous to parts of speech), and $x \in \mathcal{X}$ are terminals, i.e., words. Each such rule has an associated
probability. The resulting model defines a probability distribution over sequences of words. We can
compute the probability of observing a particular sequence $x = x_1 \ldots x_T$ by summing over all trees
that generate it. This can be done in $O(T^3)$ time using the inside-outside algorithm; see e.g.,
[JM08; MS99; Eis16] for details.
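A minimal plain-Python sketch of the inside pass (the part of inside-outside that computes $p(x)$), assuming a toy grammar with made-up rule probabilities; `inside_probability` is our own name. Dynamic programming over spans gives the $O(T^3)$ dependence on sentence length.

```python
from collections import defaultdict

def inside_probability(words, binary_rules, lexical_rules, start="S"):
    """p(x): total probability of all parse trees of `words` under a PCFG in CNF.

    binary_rules[(A, B, C)] = p(A -> B C); lexical_rules[(A, word)] = p(A -> word).
    beta[i][j][A] is the inside probability that A derives words[i:j].
    The three nested loops over (span, i, k) give the O(T^3) cost.
    """
    T = len(words)
    beta = [[defaultdict(float) for _ in range(T + 1)] for _ in range(T + 1)]
    for i, w in enumerate(words):
        for (A, word), p in lexical_rules.items():
            if word == w:
                beta[i][i + 1][A] += p
    for span in range(2, T + 1):
        for i in range(T - span + 1):
            j = i + span
            for k in range(i + 1, j):  # split point between the two children
                for (A, B, C), p in binary_rules.items():
                    beta[i][j][A] += p * beta[i][k][B] * beta[k][j][C]
    return beta[0][T][start]

binary = {("S", "NP", "VP"): 1.0, ("NP", "Det", "N"): 1.0, ("VP", "V", "NP"): 1.0}
lexical = {("Det", "the"): 1.0, ("N", "dog"): 0.5, ("N", "cat"): 0.5, ("V", "chased"): 1.0}
print(inside_probability("the dog chased the cat".split(), binary, lexical))  # 0.25
```

This toy sentence has a single parse, so $p(x)$ is just the product of its rule probabilities; for ambiguous sentences the same recursion sums over all parses.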

PCFGs are generative models. It is possible to make discriminative versions which encode
the probability of a labeled tree, $y$, given a sequence of words, $x$, by using a CRF of the form
$p(y|x) \propto \exp(w^\top \Psi(x, y))$. For example, we might define $\Psi(x, y)$ to count the number of times each
production rule was used (which is analogous to the number of state transitions in a chain-structured
model), as illustrated in Figure 4.32. We can also use a deep neural net to define the features, as in
the neural CRF parser method of [DK15b].
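The feature map $\Psi(x, y)$ used in Figure 4.32 can be sketched in a few lines of plain Python. The tuple encoding of trees and the rule strings are our own conventions, and the weights in the usage example are arbitrary.

```python
from collections import Counter

def rule_counts(tree):
    """Psi(x, y): number of times each production rule is used in the parse tree y.

    A tree node is a tuple (label, child_1, ..., child_n); leaves are word strings.
    """
    counts = Counter()

    def visit(node):
        label, *children = node
        rhs = [c if isinstance(c, str) else c[0] for c in children]
        counts[f"{label} -> {' '.join(rhs)}"] += 1
        for c in children:
            if not isinstance(c, str):
                visit(c)

    visit(tree)
    return counts

def score(tree, w):
    """w^T Psi(x, y) = -E(y|x); the CRF sets p(y|x) proportional to exp(score)."""
    return sum(w.get(rule, 0.0) * n for rule, n in rule_counts(tree).items())

tree = ("S", ("NP", ("Det", "the"), ("N", "dog")),
             ("VP", ("V", "chased"), ("NP", ("Det", "the"), ("N", "cat"))))
print(rule_counts(tree))
```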
[Figure: the parse tree of "The dog chased the cat" (tags Det N V Det N) is mapped by $f: \mathcal{X} \to \mathcal{Y}$ to a vector of rule counts $\Psi(x, y)$, e.g., NP → Det N: 2; VP → V NP: 1; Det → dog: 0; Det → the: 2; N → dog: 1; V → chased: 1; N → cat: 1.]

Figure 4.32: Illustration of a simple parse tree based on a context free grammar in Chomsky normal form.
The feature vector $\Psi(x, y)$ counts the number of times each production rule was used, and is used to define the
energy of a particular tree structure, $E(y|x) = -w^\top \Psi(x, y)$. The probability distribution over trees is given
by $p(y|x) \propto \exp(-E(y|x))$. From Figure 5.2 of [AHT07]. Used with kind permission of Yasemin Altun.

[Figure: a grid of label nodes $y_1, \ldots, y_4$, each attached to a local evidence node $x_i$.]

Figure 4.33: A grid-structured CRF with label nodes $y_i$ and local evidence nodes $x_i$.


4.4.2 2d CRFs

It is also possible to apply CRFs to image processing problems, which are usually defined on 2d
grids, as illustrated in Figure 4.33. (Compare this to the generative model in Figure 4.26.) This
corresponds to the following conditional model:

$$p(y|x) = \frac{1}{Z(x)} \left[\prod_{i \sim j} \psi_{ij}(y_i, y_j)\right] \left[\prod_i p(y_i|x_i)\right] \qquad (4.140)$$
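For a grid small enough to enumerate, Eq. (4.140) can be evaluated exactly. This plain-Python sketch assumes binary labels, a Potts form for $\psi_{ij}$, and made-up local evidence; it shows how the pairwise terms pull an uncertain pixel toward its neighbours' label. The cost is exponential in the number of pixels, so it is only a check, not a practical inference method.

```python
import itertools
import math

def grid_edges(H, W):
    """Edges i ~ j between 4-neighbours of an H x W grid (nodes numbered row by row)."""
    edges = []
    for r in range(H):
        for c in range(W):
            i = r * W + c
            if c + 1 < W:
                edges.append((i, i + 1))
            if r + 1 < H:
                edges.append((i, i + W))
    return edges

def grid_crf_posterior(local, H, W, coupling):
    """Exact p(y|x) for a tiny grid CRF, by enumerating all K^(H*W) labelings.

    local[i][k] stands in for p(y_i = k | x_i); psi_ij(y_i, y_j) is taken to be
    a Potts potential exp(coupling * [y_i == y_j]) (an illustrative choice).
    """
    K = len(local[0])
    edges = grid_edges(H, W)
    weights = {}
    for y in itertools.product(range(K), repeat=H * W):
        agree = sum(y[i] == y[j] for i, j in edges)
        weights[y] = math.prod(local[i][y[i]] for i in range(H * W)) * math.exp(coupling * agree)
    Z = sum(weights.values())  # the partition function Z(x)
    return {y: w / Z for y, w in weights.items()}

# 2 x 2 image: three pixels locally favour label 0, one pixel weakly favours label 1.
local = [[0.9, 0.1], [0.9, 0.1], [0.9, 0.1], [0.4, 0.6]]
for coupling in (0.0, 2.0):
    post = grid_crf_posterior(local, 2, 2, coupling)
    print(coupling, max(post, key=post.get))  # 0.0 (0, 0, 0, 1); 2.0 (0, 0, 0, 0)
```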


In the sections below, we discuss some applications of this and other CRF models in computer vision.


4.4.2.1 Semantic segmentation


The task of semantic segmentation is to assign a label to every pixel in an image. We can easily
solve this problem using a CNN with one softmax output node per pixel. However, this may fail to
capture long-range dependencies, since convolution is a local operation.

One way to get better results is to feed the output of the CNN into a CRF. Since the CNN already
uses convolution, its outputs will usually already be locally smooth, so the benefits from using a CRF
with a local grid structure may be quite small. However, we can sometimes get better results if we
use a fully connected CRF, which has connections between all the pixels. This can capture long-range
connections which the grid-structured CRF cannot. See Figure 4.34 for an illustration, and
[Che+17a] for details.

[Figure: pipeline from the input image through a deep convolutional neural network to a coarse score map (e.g., for the class "Aeroplane"), followed by bi-linear interpolation and a fully connected CRF, giving the final output.]

Figure 4.34: A fully connected CRF is added to the output of a CNN, in order to increase the sharpness of
the segmentation boundaries. From Figure 3 of [Che+15]. Used with kind permission of Jay Chen.

Unfortunately, exact inference in a fully connected CRF is intractable, but in the case of Gaussian
potentials, it is possible to devise an efficient mean field algorithm, as described in [KK11].
Interestingly, [Zhe+15] showed how the mean field update equations can be implemented using a
recurrent neural network (see Section 16.3.4), allowing end-to-end training. Alternatively, if we are
willing to use a finite number of iterations, we can just “unroll” the computation graph and treat
it as a fixed-size feedforward circuit. The result is a graph-structured neural network, or graph neural
network (GNN), where the topology of the GNN is derived from the graphical model (cf. Section 9.3.10). The advantage of
this compared to standard CRF methods is that we can train this entire model end-to-end using
standard gradient descent methods; we no longer have to worry about the partition function (see
Section 4.4.3), or the lack of convergence that can arise when combining approximate inference with
standard CRF learning.
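Here is a naive plain-Python sketch of parallel mean field updates for a fully connected CRF with a single Gaussian kernel on pixel features and a Potts compatibility. This is one common choice, not the exact model of [KK11], which combines several kernels and uses fast approximate filtering to avoid the $O(N^2)$ cost that this sketch pays. The loop runs a fixed number of iterations, i.e., the unrolled computation described above. The pixel features, probabilities and the parameters `w` and `theta` are made up.

```python
import math

def dense_crf_mean_field(unary_prob, feats, w=1.0, theta=1.0, n_iters=5):
    """Naive O(N^2 K) mean field for a fully connected CRF (illustrative sketch).

    unary_prob[i][k]: per-pixel class probabilities (e.g., CNN softmax outputs).
    feats[i]: feature vector of pixel i (e.g., its position).
    Pairwise energy: w * exp(-||f_i - f_j||^2 / (2 theta^2)) * [y_i != y_j].
    """
    N, K = len(unary_prob), len(unary_prob[0])
    unary = [[-math.log(p) for p in row] for row in unary_prob]
    kern = [[0.0 if i == j else
             w * math.exp(-sum((a - b) ** 2 for a, b in zip(feats[i], feats[j])) / (2 * theta ** 2))
             for j in range(N)] for i in range(N)]
    Q = [list(row) for row in unary_prob]
    for _ in range(n_iters):  # fixed number of parallel updates ("unrolled")
        new_Q = []
        for i in range(N):
            # Expected Potts penalty for label k: sum_j kern(i, j) * (1 - Q_j(k)).
            energy = [unary[i][k] + sum(kern[i][j] * (1.0 - Q[j][k]) for j in range(N))
                      for k in range(K)]
            m = min(energy)
            unnorm = [math.exp(-(e - m)) for e in energy]
            s = sum(unnorm)
            new_Q.append([u / s for u in unnorm])
        Q = new_Q
    return Q

# Five "pixels" on a line; the middle one weakly prefers class 1.
probs = [[0.8, 0.2], [0.8, 0.2], [0.4, 0.6], [0.8, 0.2], [0.8, 0.2]]
feats = [[0.0], [1.0], [2.0], [3.0], [4.0]]
Q = dense_crf_mean_field(probs, feats, w=2.0, theta=1.5, n_iters=5)
print([round(q[0], 3) for q in Q])  # probability of class 0 at each pixel
```

Because every step is differentiable, the same loop written in JAX could be trained end-to-end; with `w=0.0` the output reduces to the CNN probabilities.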


4.4.2.2 Deformable parts models 


Consider the problem of object detection, i.e., finding the location(s) of an object of a given class 
(e.g., a person or a car) in an image. One way to tackle this is to train a binary classifier that takes 
as input an image patch and specifies if the patch contains the object or not. We can then apply this 
to every image patch, and return the locations where the classifier has high confidence detections; 
this is known as a sliding window detector, and works quite well for rigid objects such as cars 
or frontal faces. Such an approach can be made efficient by using convolutional neural networks 
(CNNs); see Section 16.3.2 for details. 

However, such methods can work poorly when there is occlusion, or when the shape is deformable,
such as a person’s or animal’s body, because there is too much variation in the overall appearance.
A natural strategy to deal with such problems is to break the object into parts, and then to detect
each part separately. But we still need to enforce spatial coherence of the parts. This can be done
using a pairwise CRF, where node $y_i$ specifies the location of part $i$ in the image (assuming it is
present), and where we connect adjacent parts by a potential function that encourages them to be
close together. For example, we can use a pairwise potential of the form $\psi(y_i, y_j|x) = \exp(-d(y_i, y_j))$,
where $y_i \in \{1, \ldots, K\}$ is the location of part $i$ (a discretization of the 2d image plane), and $d(y_i, y_j)$
is the distance between parts $i$ and $j$. (We can make this “distance” also depend on the inputs $x$ if
we want; for example, we may relax the distance penalty if we detect an edge.) In addition we will
have a local evidence term of the form $p(y_i|x)$, which can be any kind of discriminative classifier,
such as a CNN, which predicts the distribution over locations for part $i$ given the image $x$. The
overall model has the form

$$p(y|x) = \frac{1}{Z(x)} \left[\prod_i p(y_i|f(x)_i)\right] \left[\prod_{(i,j) \in E} \psi(y_i, y_j|x)\right] \qquad (4.141)$$

where $E$ is the set of edges in the CRF, and $f(x)_i$ is the $i$th output of the CNN.

Figure 4.35: Pictorial structures model for a face and body. Each body part corresponds to a node in the
CRF whose state space represents the location of that part. The edges (springs) represent pairwise spatial
constraints. The local evidence nodes are not shown. Adapted from a figure by Pedro Felzenszwalb.


We can think of this CRF as a series of parts connected by springs, where the energy of the system
increases if the parts are moved too far from their expected relative distance. This is illustrated in
Figure 4.35. The resulting model is known as a pictorial structure [FE73], or deformable parts
model [Fel+10]. Furthermore, since this is a conditional model, we can make the spring strengths
be image dependent.


We can find the globally optimal joint configuration $y^* = \arg\max_y p(y|x, \theta)$ using brute force
enumeration in $O(K^T)$ time, where $T$ is the number of nodes and $K$ is the number of states (locations)
per node. While $T$ is often small (e.g., just 10 body parts in Figure 4.35), $K$ is often very large,
since there are millions of possible locations in an image. By using tree-structured graphs, exact
inference can be done in $O(TK^2)$ time, as we explain in Section 9.2.2. Furthermore, by exploiting
the fact that the discrete states are ordinal, inference time can be further reduced to $O(TK)$, as
explained in [Fel+10].
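To make the $O(TK^2)$ claim concrete, here is a plain-Python sketch of max-product dynamic programming for the simplest tree, a chain of parts with 1d locations, checked against $O(K^T)$ brute-force enumeration. The quadratic `spring` (with a hypothetical preferred offset and stiffness) and the random local evidence scores are illustrative assumptions; the $O(TK)$ distance-transform speedup of [Fel+10] is not implemented.

```python
import itertools
import random

def score(y, unary, spring):
    """Unnormalized log-probability: local evidence plus springs between consecutive parts."""
    return (sum(unary[i][k] for i, k in enumerate(y))
            + sum(spring(a, b) for a, b in zip(y, y[1:])))

def chain_map(unary, spring):
    """MAP configuration of a chain of parts by max-product dynamic programming, O(T K^2)."""
    K = len(unary[0])
    best = list(unary[0])  # best[k]: best log-score of parts 0..i with part i at location k
    back = []
    for i in range(1, len(unary)):
        ptr = [max(range(K), key=lambda k: best[k] + spring(k, l)) for l in range(K)]
        best = [best[ptr[l]] + spring(ptr[l], l) + unary[i][l] for l in range(K)]
        back.append(ptr)
    y = [max(range(K), key=lambda k: best[k])]
    for ptr in reversed(back):
        y.append(ptr[y[-1]])
    return y[::-1]

def brute_force_map(unary, spring):
    """Enumerate all K^T configurations; only feasible for tiny problems."""
    K, T = len(unary[0]), len(unary)
    return list(max(itertools.product(range(K), repeat=T),
                    key=lambda y: score(y, unary, spring)))

def spring(k, l, offset=2, stiffness=0.5):
    """Hypothetical spring: quadratic penalty for deviating from the expected offset."""
    return -stiffness * (l - k - offset) ** 2

rng = random.Random(0)
unary = [[rng.gauss(0.0, 1.0) for _ in range(8)] for _ in range(4)]  # T=4 parts, K=8 locations
y_dp = chain_map(unary, spring)
y_bf = brute_force_map(unary, spring)
print(y_dp, y_bf)
```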


Note that by “augmenting” standard deep neural network libraries with a dynamic programming
inference “module”, we can represent DPMs as a kind of CNN, as shown in [Gir+15]. The key
property is that we can backpropagate gradients through the inference algorithm.


4.4.3 Parameter estimation 


In this section, we discuss how to perform maximum likelihood estimation for CRFs. This is a small 
extension of the MRF case in Section 4.3.9.1. 


4.4.3.1 Log-linear potentials 


In this section we assume the log potential functions are linear in the parameters, i.e., 


$$\psi_c(y_c|x; \theta) = \exp(\theta_c^\top \phi_c(x, y_c)) \qquad (4.142)$$


Hence the log likelihood becomes 


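$$\ell(\theta) \triangleq \frac{1}{N} \sum_n \log p(y_n|x_n, \theta) = \frac{1}{N} \sum_n \left[\sum_c \theta_c^\top \phi_c(x_n, y_{nc}) - \log Z(x_n; \theta)\right] \qquad (4.143)$$
where
$$Z(x_n; \theta) = \sum_y \exp\left(\sum_c \theta_c^\top \phi_c(x_n, y_c)\right) \qquad (4.144)$$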


is the partition function for example n. 
We know from Section 2.3.3 that the derivative of the log partition function yields the expected 
sufficient statistics, so the gradient of the log likelihood can be written as follows: 


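$$\frac{\partial \ell}{\partial \theta_c} = \frac{1}{N} \sum_n \left[\phi_c(x_n, y_{nc}) - \frac{\partial}{\partial \theta_c} \log Z(x_n; \theta)\right] \qquad (4.145)$$
$$= \frac{1}{N} \sum_n \left[\phi_c(x_n, y_{nc}) - \mathbb{E}_{p(y|x_n, \theta)}\left[\phi_c(x_n, y_c)\right]\right] \qquad (4.146)$$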


Since the negative log likelihood is convex in $\theta$, we can use a variety of solvers to find the MLE, such as the
stochastic meta descent method of [Vis+06], which is a variant of SGD where the step size is adapted
automatically.
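The gradient in Eq. (4.146), observed minus expected features, can be checked numerically on a toy problem. This plain-Python sketch assumes a 3-step chain with binary labels and two hand-picked features (our own choice), computes $Z$ by enumeration, and compares the analytic gradient to central finite differences.

```python
import itertools
import math

def features(x, y):
    """phi(x, y) for a toy chain CRF with binary labels (illustrative features):
    [sum of x_t over positions labelled 1, number of adjacent equal labels]."""
    unary = sum(xt for xt, yt in zip(x, y) if yt == 1)
    pairwise = sum(1.0 for a, b in zip(y, y[1:]) if a == b)
    return [unary, pairwise]

def dot(a, b):
    return sum(u * v for u, v in zip(a, b))

def log_likelihood_and_grad(theta, data):
    """Average log likelihood (Eq. 4.143) and gradient (Eq. 4.146), with the
    partition function and expectations computed by brute-force enumeration."""
    ll, grad = 0.0, [0.0] * len(theta)
    for x, y in data:
        configs = list(itertools.product([0, 1], repeat=len(x)))
        feats = [features(x, yy) for yy in configs]
        scores = [dot(theta, f) for f in feats]
        m = max(scores)
        log_Z = m + math.log(sum(math.exp(s - m) for s in scores))
        probs = [math.exp(s - log_Z) for s in scores]
        expected = [sum(p * f[d] for p, f in zip(probs, feats)) for d in range(len(theta))]
        observed = features(x, y)
        ll += dot(theta, observed) - log_Z
        grad = [g + o - e for g, o, e in zip(grad, observed, expected)]
    return ll / len(data), [g / len(data) for g in grad]

data = [([0.5, -1.0, 2.0], (0, 0, 1)), ([1.5, 0.3, -0.2], (1, 1, 0))]
theta = [0.3, -0.7]
ll, grad = log_likelihood_and_grad(theta, data)
print(ll, grad)
```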


4.4.3.2 General case 


In the general case, a CRF can be written as follows: 


$$p(y|x; \theta) = \frac{\exp(f(x, y; \theta))}{Z(x; \theta)} = \frac{\exp(f(x, y; \theta))}{\sum_{y'} \exp(f(x, y'; \theta))} \qquad (4.147)$$


where f(x,y; 0) is a scoring (negative energy) function, where high scores correspond to probable 
configurations. The gradient of the log likelihood is 


$$\nabla_\theta \ell(\theta) = \frac{1}{N} \sum_{n=1}^N \left[\nabla_\theta f(x_n, y_n; \theta) - \nabla_\theta \log Z(x_n; \theta)\right] \qquad (4.148)$$
[Figure: diagram relating probabilistic models, graphical models, and the directed and undirected families.]

Figure 4.36: DPGMs and UPGMs can perfectly represent different sets of distributions. Some distributions
can be perfectly represented by either DPGMs or UPGMs; the corresponding graph must be chordal.


Computing derivatives of the log partition function is tractable provided we can compute the 
corresponding expectations, as we discuss in Section 4.3.9.2. Note, however, that we need to compute 
these derivatives for every training example, which is slower than the MRF case, where the log 
partition function is a constant independent of the observed data (but dependent on the model 
parameters). 
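For chain-structured CRFs these quantities are tractable: $\log Z(x_n; \theta)$ can be computed with the forward algorithm in $O(TK^2)$ time, and its gradient with respect to the log-potentials (obtainable, e.g., by automatic differentiation in JAX) gives the required expected sufficient statistics. The plain-Python sketch below only computes $\log Z$ and checks it against brute-force enumeration; the random potentials are arbitrary test values.

```python
import itertools
import math
import random

def logsumexp(vals):
    m = max(vals)
    return m + math.log(sum(math.exp(v - m) for v in vals))

def chain_log_partition(unary, pairwise):
    """log Z(x) for a linear-chain CRF via the forward algorithm, O(T K^2) time.

    unary[t][k] = log psi_t(y_t = k | x); pairwise[j][k] = log psi(y_{t-1} = j, y_t = k).
    """
    K = len(unary[0])
    log_alpha = list(unary[0])  # log-sum of scores of all prefixes ending in state k
    for t in range(1, len(unary)):
        log_alpha = [unary[t][k] + logsumexp([log_alpha[j] + pairwise[j][k] for j in range(K)])
                     for k in range(K)]
    return logsumexp(log_alpha)

def brute_force_log_partition(unary, pairwise):
    """Sum over all K^T label sequences; exponential cost, for checking only."""
    K, T = len(unary[0]), len(unary)
    return logsumexp([sum(unary[t][y[t]] for t in range(T))
                      + sum(pairwise[a][b] for a, b in zip(y, y[1:]))
                      for y in itertools.product(range(K), repeat=T)])

rng = random.Random(1)
unary = [[rng.gauss(0.0, 1.0) for _ in range(3)] for _ in range(5)]     # T=5, K=3
pairwise = [[rng.gauss(0.0, 1.0) for _ in range(3)] for _ in range(3)]
print(chain_log_partition(unary, pairwise), brute_force_log_partition(unary, pairwise))
```

Note that `unary` depends on the input $x_n$, so this computation has to be repeated for every training example, which is the extra cost relative to the MRF case mentioned above.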


4.4.4 Other approaches to structured prediction 


Many other approaches to structured prediction have been proposed, going beyond CRFs. For
example, max margin Markov networks [TGK03], and the closely related structural support
vector machine [Tso+05], can be seen as non-probabilistic alternatives to CRFs. More recently,
[BYM17] proposed Structured Prediction Energy Networks, which are a form of energy based
model (Chapter 24), where we predict using an optimization procedure, $\hat{y}(x) = \arg\min_y \mathcal{E}(x, y)$.
In addition, it is common to use graph neural networks (Section 16.3.6) and sequence-to-sequence
models such as transformers (Section 16.3.5) for this task.


4.5 Comparing directed and undirected PGMs


In this section, we compare DPGMs and UPGMs in terms of their modeling power, discuss how
to convert from one to the other, and present a unified representation.


4.5.1 CI properties


Which model has more “expressive power”, a DPGM or a UPGM? To formalize the question, recall
from Section 4.2.4 that $G$ is an I-map of a distribution $p$ if $I(G) \subseteq I(p)$, meaning that all the CI
statements encoded by the graph $G$ are true of the distribution $p$. Now define $G$ to be a perfect map
of $p$ if $I(G) = I(p)$; in other words, the graph can represent all (and only) the CI properties of the
distribution. It turns out that DPGMs and UPGMs are perfect maps for different sets of distributions
(see Figure 4.36). In this sense, neither is more powerful than the other as a representation language.

Figure 4.37: A UPGM and two failed attempts to represent it as a DPGM. From Figure 3.10 of [KF09a].
Used with kind permission of Daphne Koller.

As an example of some CI relationships that can be perfectly modeled by a DPGM but not a
UPGM, consider a v-structure $A \to C \leftarrow B$. This asserts that $A \perp B$, and $A \not\perp B|C$. If we drop
the arrows, we get $A - C - B$, which asserts $A \perp B|C$ and $A \not\perp B$, which is not consistent with
the independence statements encoded by the DPGM. In fact, there is no UPGM that can precisely
represent all and only the two CI statements encoded by a v-structure. In general, CI properties in
UPGMs are monotonic, in the following sense: if $A \perp B|C$, then $A \perp B|(C \cup D)$. But in DPGMs,
CI properties can be non-monotonic, since conditioning on extra variables can eliminate conditional
independencies due to explaining away.
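The non-monotonic behaviour of a v-structure can be checked exactly. This plain-Python sketch assumes binary $A$ and $B$ with probability 1/2 each and the deterministic CPD $C = A \lor B$ (our own choice of example): $A$ and $B$ are marginally independent, but become dependent once $C$ is observed, and learning $B = 1$ "explains away" $C = 1$.

```python
from fractions import Fraction
import itertools

def v_structure_joint(p_a=Fraction(1, 2), p_b=Fraction(1, 2)):
    """Exact joint p(A, B, C) for A -> C <- B, with the deterministic CPD C = A OR B."""
    joint = {}
    for a, b, c in itertools.product([0, 1], repeat=3):
        pa = p_a if a else 1 - p_a
        pb = p_b if b else 1 - p_b
        pc = 1 if c == (a | b) else 0
        joint[(a, b, c)] = pa * pb * pc
    return joint

def prob(joint, **event):
    """Probability of an event such as prob(joint, A=1, C=1)."""
    index = {"A": 0, "B": 1, "C": 2}
    return sum(p for abc, p in joint.items()
               if all(abc[index[v]] == val for v, val in event.items()))

joint = v_structure_joint()
p_c = prob(joint, C=1)
print(prob(joint, A=1, B=1) == prob(joint, A=1) * prob(joint, B=1))  # True: A and B independent
print(prob(joint, A=1, C=1) / p_c)                                  # 2/3
print(prob(joint, A=1, B=1, C=1) / prob(joint, B=1, C=1))           # 1/2: B=1 explains away C=1
```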

As an example of some CI relationships that can be perfectly modeled by a UPGM but not a
DPGM, consider the 4-cycle shown in Figure 4.37(a). One attempt to model this with a DPGM is
shown in Figure 4.37(b). This correctly asserts that $A \perp C|B, D$. However, it incorrectly asserts
that $B \perp D|A$. Figure 4.37(c) is another incorrect DPGM: it correctly encodes $A \perp C|B, D$, but
incorrectly encodes $B \perp D$. In fact there is no DPGM that can precisely represent all and only the
CI statements encoded by this UPGM.

Some distributions can be perfectly modeled by either a DPGM or a UPGM; the resulting graphs 
are called decomposable or chordal. Roughly speaking, this means the following: if we collapse 
together all the variables in each maximal clique, to make “mega-variables”, the resulting graph will 
be a tree. Of course, if the graph is already a tree (which includes chains as a special case), it will 
already be chordal. 
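Chordality can be tested efficiently. The plain-Python sketch below uses maximum cardinality search: a graph is chordal if and only if, in a maximum cardinality search order, the previously visited neighbours of every node form a clique (equivalently, the reversed order is a perfect elimination ordering). The adjacency-dict encoding and the function name are our own.

```python
def is_chordal(adj):
    """Chordality test via maximum cardinality search (MCS).

    adj: dict mapping each node to the set of its neighbours (undirected graph).
    """
    weight = {v: 0 for v in adj}
    order, visited = [], set()
    while len(order) < len(adj):
        # Visit the unvisited node with the most already-visited neighbours.
        v = max((u for u in adj if u not in visited), key=lambda u: weight[u])
        order.append(v)
        visited.add(v)
        for u in adj[v]:
            if u not in visited:
                weight[u] += 1
    position = {v: i for i, v in enumerate(order)}
    for v in order:
        earlier = [u for u in adj[v] if position[u] < position[v]]
        # The previously visited neighbours of v must be pairwise adjacent.
        for i in range(len(earlier)):
            for j in range(i + 1, len(earlier)):
                if earlier[j] not in adj[earlier[i]]:
                    return False
    return True

cycle = {"A": {"B", "D"}, "B": {"A", "C"}, "C": {"B", "D"}, "D": {"A", "C"}}
chorded = {"A": {"B", "C", "D"}, "B": {"A", "C"}, "C": {"A", "B", "D"}, "D": {"A", "C"}}
print(is_chordal(cycle), is_chordal(chorded))  # False True
```

The 4-cycle of Figure 4.37(a) is not chordal, which is consistent with the fact that no DPGM represents it perfectly; adding the chord A-C makes it chordal.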


4.5.2 Converting between a directed and undirected model 


Although DPGMs and UPGMs are not in general equivalent, if we are willing to allow the graph to 
encode fewer CI properties than may strictly hold, then we can safely convert one to the other, as we 
explain below. 


4.5.2.1 Converting a DPGM to a UPGM 


We can easily convert a DPGM to a UPGM as follows. First, any “unmarried” parents that share a
child must get “married”, by adding an edge between them; this process is known as moralization.
Then we can drop the arrows, resulting in an undirected graph. The reason we need to do this is to
ensure that every CI property encoded by the UPGM also holds in the DPGM (the UPGM may encode fewer
CI properties), as explained in Section 4.3.6.2. It also ensures there is a clique that can “store” the CPDs of each family.
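Moralization is easy to implement. This plain-Python sketch takes a DAG as a dictionary from each node to its list of parents (our own encoding); applied to the student network of Figure 4.38, with the parent sets read off from the factorization of its joint, it recovers exactly the D-I, G-J and L-S marriage edges.

```python
import itertools

def moralize(parents):
    """Moral graph of a DAG given as {child: [parents]}: marry co-parents, drop arrows.

    Returns (all undirected edges, the added marriage edges); edges are frozensets.
    """
    edges = {frozenset((p, child)) for child, pas in parents.items() for p in pas}
    married = set()
    for pas in parents.values():
        for p, q in itertools.combinations(pas, 2):
            e = frozenset((p, q))
            if e not in edges:  # co-parents that are not already adjacent
                married.add(e)
    return edges | married, married

student = {"C": [], "D": ["C"], "I": [], "G": ["D", "I"], "S": ["I"],
           "L": ["G"], "J": ["L", "S"], "H": ["G", "J"]}
edges, married = moralize(student)
print(sorted("".join(sorted(e)) for e in married))  # ['DI', 'GJ', 'LS']
```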
Figure 4.38: Left: The full student DPGM. Right: the equivalent UPGM. We add moralization arcs D-I, G-J
and L-S. Adapted from Figure 9.8 of [KF09a].


Let us consider an example from [KF09a]. We will use the full version of the student network
shown in Figure 4.38(a). The corresponding joint has the following form:


$$P(C, D, I, G, S, L, J, H) \qquad (4.149)$$
$$= P(C) P(D|C) P(I) P(G|I, D) P(S|I) P(L|G) P(J|L, S) P(H|G, J) \qquad (4.150)$$


Next, we define a potential or factor for every CPD, yielding 


$$\psi(C, D, I, G, S, L, J, H) = \psi_C(C) \psi_D(D, C) \psi_I(I) \psi_G(G, I, D) \qquad (4.151)$$
$$\psi_S(S, I) \psi_L(L, G) \psi_J(J, L, S) \psi_H(H, G, J) \qquad (4.152)$$


Since all the potentials are locally normalized CPDs, there is no need for a global
normalization constant, so $Z = 1$. The corresponding undirected graph is shown in Figure 4.38(b).
We see that we have added D-I, G-J, and L-S moralization edges.¹¹


4.5.2.2 Converting a UPGM to a DPGM

To convert a UPGM to a DPGM, we proceed as follows. For each potential function $\psi_c(x_c; \theta_c)$, we
create a “dummy node”, call it $Y_c$, which is “clamped” to a special observed state, call it $y_c^*$. We
then define $p(Y_c = y_c^*|x_c) \propto \psi_c(x_c; \theta_c)$, where the potential is rescaled so that this is a valid probability.
This “local evidence” CPD encodes the same factor as in the UPGM. Taking the original variables to be root
nodes with uniform priors, the overall joint has the form $p_{\text{undir}}(x) \propto p_{\text{dir}}(x, y^*)$.

As an example, consider the UPGM in Figure 4.39(a), which defines the joint $p(A, B, C, D) =
\psi(A, B, C, D)/Z$. We can represent this as a DPGM by adding a dummy $E$ node, which is a child
of all the other nodes. We set $E = 1$ and define the CPD $p(E = 1|A, B, C, D) \propto \psi(A, B, C, D)$. By
conditioning on this observed child, all the parents become dependent, as in the UPGM.
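The dummy-node construction can be checked numerically for Figure 4.39. In this plain-Python sketch the potential $\psi(A, B, C, D)$ is random, the four parents get uniform priors, and $p(E = 1|A, B, C, D) = \psi / \max \psi$ so that it is a valid probability; these are our assumptions. Conditioning on $E = 1$ then reproduces $\psi / Z$.

```python
import itertools
import random

def undirected_joint(psi):
    """p(a, b, c, d) = psi(a, b, c, d) / Z for the single-clique UPGM."""
    Z = sum(psi.values())
    return {x: v / Z for x, v in psi.items()}

def directed_posterior(psi):
    """p(a, b, c, d | E = 1) in the DPGM with an observed dummy child E.

    Assumes uniform priors on the four parents and
    p(E = 1 | a, b, c, d) = psi(a, b, c, d) / max(psi), which lies in [0, 1].
    """
    prior = 1.0 / len(psi)
    scale = max(psi.values())
    joint_e1 = {x: prior * v / scale for x, v in psi.items()}  # p(x) p(E=1 | x)
    p_e1 = sum(joint_e1.values())
    return {x: v / p_e1 for x, v in joint_e1.items()}

rng = random.Random(0)
psi = {x: rng.uniform(0.1, 5.0) for x in itertools.product([0, 1], repeat=4)}
p_undir, p_dir = undirected_joint(psi), directed_posterior(psi)
print(max(abs(p_undir[x] - p_dir[x]) for x in psi))  # ~0, up to floating point error
```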


11. We will see this example again in Section 9.4, where we use it to illustrate the variable elimination inference
algorithm.
Figure 4.39: (a) An undirected graphical model. (b) A directed equivalent, obtained by adding a dummy
observed child node.


Figure 4.40: Two discriminative models for sequential data. (a) An undirected model (CRF). (b) A directed
model (MEMM).


4.5.3 Conditional directed vs undirected PGMs and the label bias problem 


Directed and undirected models behave somewhat differently in the conditional (discriminative)
setting. As an example of this, let us compare the 1d undirected CRF in Figure 4.40a with the
directed Markov chain in Figure 4.40b. (This latter model is called a maximum entropy Markov model
(MEMM), which is a reference to the connection with maxent models discussed in Section 4.3.4.)
The MEMM suffers from a subtle problem compared to the CRF known (rather obscurely) as the
label bias problem [LMP01]. The problem is that local features at time $t$ do not influence states
prior to time $t$. That is, $y_{t-1} \perp x_t$ as long as $y_t$ is unobserved, because $y_t$ is a hidden child of both
(a v-structure), thus blocking information flow backwards in time.

To understand what this means in practice, consider the part of speech tagging task which we 
discussed in Section 4.4.1.1. Suppose we see the word “banks”; this could be a verb (as in “he banks 
at Chase”), or a noun (as in “the river banks were overflowing”). Locally the part of speech tag for 
the word is ambiguous. However, suppose that later in the sentence, we see the word “fishing”; this 
gives us enough context to infer that the sense of “banks” is “river banks” and not “financial banks”. 
However, in an MEMM the “fishing” evidence will not flow backwards, so we will not be able to infer 
the correct label for “banks”. The CRF does not have this problem. 
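The label bias effect can be reproduced on a two-word toy example. This plain-Python sketch scores both models with the same made-up local log-potentials (the weights `W`, `V` and the start state are assumptions). It changes only the second input $x_2$ and shows that the MEMM's marginal for $y_1$ does not move, while the CRF's does.

```python
import itertools
import math

LABELS = (0, 1)
START = 0                       # hypothetical fixed start state y_0
W = ((0.5, -0.5), (-1.0, 1.0))  # made-up transition weights W[y_prev][y]
V = (0.0, 1.0)                  # made-up input weights: label 1 likes large x_t

def log_potential(prev, cur, xt):
    """Shared local score for y_{t-1} = prev, y_t = cur given the input x_t."""
    return W[prev][cur] + V[cur] * xt

def memm_marginal_y1(x):
    """p(y_1 | x) for an MEMM: product of locally normalized p(y_t | y_{t-1}, x_t)."""
    def cpd(prev, cur, xt):
        z = sum(math.exp(log_potential(prev, c, xt)) for c in LABELS)
        return math.exp(log_potential(prev, cur, xt)) / z
    marg = [0.0] * len(LABELS)
    for y in itertools.product(LABELS, repeat=len(x)):
        p = 1.0
        for prev, cur, xt in zip((START,) + y[:-1], y, x):
            p *= cpd(prev, cur, xt)
        marg[y[0]] += p
    return marg

def crf_marginal_y1(x):
    """p(y_1 | x) for a CRF with the same scores but a single global normalizer Z(x)."""
    weights = {y: math.exp(sum(log_potential(p, c, xt)
                               for p, c, xt in zip((START,) + y[:-1], y, x)))
               for y in itertools.product(LABELS, repeat=len(x))}
    Z = sum(weights.values())
    return [sum(w for y, w in weights.items() if y[0] == k) / Z for k in LABELS]

for x2 in (-3.0, 3.0):  # change only the *later* input x_2
    print(x2, memm_marginal_y1([0.2, x2]), crf_marginal_y1([0.2, x2]))
```

In the MEMM, the later factors sum to one over $y_2$ for every value of $y_1$, so evidence at $t = 2$ cannot reach $y_1$; in the CRF, the global normalizer couples them.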
Figure 4.41: A grid-structured MRF with hidden nodes $x_i$ and local evidence nodes $y_i$. The prior $p(x)$ is an
undirected Ising model, and the likelihood $p(y|x) = \prod_i p(y_i|x_i)$ is a directed fully factored model.


The label bias problem in MEMMs occurs because directed models are locally normalized, 
meaning each CPD sums to 1. By contrast, MRFs and CRFs are globally normalized, which 
means that local factors do not need to sum to 1, since the partition function Z, which sums over all 
joint configurations, will ensure the model defines a valid distribution. 

However, this solution comes at a price: in a CRF, we do not get a valid probability distribution over
$y_{1:T}$ until we have seen the whole sentence, since only then can we normalize over all configurations.
Consequently, CRFs are not as useful as DPGMs for online
or real-time inference. Furthermore, the fact that $Z$ is a function of all the parameters makes CRFs
less modular and much slower to train than DPGMs, as we discuss in Section 4.4.3.


4.5.4 Combining directed and undirected graphs


We can also define graphical models that contain directed and undirected edges. We discuss a few 
examples below. 


4.5.4.1 Chain graphs 


A chain graph is a PGM which may have both directed and undirected edges, but without any
directed cycles. A simple example is shown in Figure 4.41, which defines the following joint model:

$$p(x_{1:D}, y_{1:D}) = p(x_{1:D}) p(y_{1:D}|x_{1:D}) = \left[\frac{1}{Z} \prod_{i \sim j} \psi_{ij}(x_i, x_j)\right] \left[\prod_{d=1}^D p(y_d|x_d)\right] \qquad (4.153)$$


In this example, the prior $p(x)$ is specified by a UPGM, and the likelihood $p(y|x)$ is specified as a
fully factorized DPGM.
More generally, a chain graph can be defined in terms of a partially directed acyclic graph
(PDAG). This is a graph which can be decomposed into a directed graph of chain components,
where the nodes within each chain component are connected with each other only with undirected
edges. See Figure 4.42 for an example.


Figure 4.42: A partially directed acyclic graph (PDAG). The chain components are {A}, {B}, {C, D, E},
{F, G}, {H} and {I}. Adapted from Figure 4.15 of [KF09a].

Figure 4.43: (a) A DAG with two hidden variables (shaded). (b) The corresponding ADMG. The bidirected
edges reflect correlation due to the hidden variable. (c) A Markov equivalent ADMG. From Figure 3 of [SG09].
Used with kind permission of Ricardo Silva.

We can use a PDAG to define a joint distribution using $\prod_i p(C_i|\text{pa}(C_i))$, where each $C_i$ is a chain
component, and each CPD is a conditional random field. For example, referring to Figure 4.42, we
have

$$p(A, B, \ldots, I) = p(A) p(B) p(C, D, E|A, B) p(F, G|C, D) p(H) p(I|C, E, H) \qquad (4.154)$$
$$p(C, D, E|A, B) = \frac{1}{Z(A, B)} \phi(A, C) \phi(B, E) \phi(C, D) \phi(D, E) \qquad (4.155)$$
$$p(F, G|C, D) = \frac{1}{Z(C, D)} \phi(F, C) \phi(G, D) \phi(F, G) \qquad (4.156)$$

For more details, see e.g., [KF09a, Sec 4.6.2].
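Chain components can be found mechanically: delete the directed edges and take connected components of what remains. In this plain-Python sketch the edge lists for Figure 4.42 are reconstructed from the factors in Eqs. (4.154)-(4.156), which is an assumption about the figure; the function names are our own.

```python
def chain_components(nodes, undirected_edges):
    """Chain components of a PDAG: connected components after deleting directed edges."""
    adj = {v: set() for v in nodes}
    for a, b in undirected_edges:
        adj[a].add(b)
        adj[b].add(a)
    seen, comps = set(), []
    for v in nodes:
        if v in seen:
            continue
        stack, comp = [v], set()
        while stack:
            u = stack.pop()
            if u not in comp:
                comp.add(u)
                stack.extend(adj[u] - comp)
        seen |= comp
        comps.append(comp)
    return comps

def component_parents(comp, directed_edges):
    """pa(C): nodes outside C with a directed edge into some node of C."""
    return {a for a, b in directed_edges if b in comp and a not in comp}

nodes = list("ABCDEFGHI")
undirected = [("C", "D"), ("D", "E"), ("F", "G")]
directed = [("A", "C"), ("B", "E"), ("C", "F"), ("D", "G"), ("C", "I"), ("E", "I"), ("H", "I")]
for comp in chain_components(nodes, undirected):
    print(sorted(comp), sorted(component_parents(comp, directed)))
```

Each printed pair corresponds to one CPD $p(C_i|\text{pa}(C_i))$ in Eq. (4.154).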


4.5.4.2 Acyclic directed mixed graphs 


One can show [Pea09b, p51] that every latent variable DPGM can be rewritten in a way such that 
every latent variable is a root node with exactly two observed children. This is called the projection 
of the latent variable PGM, and is observationally indistinguishable from the original model. 

Each such latent variable root node induces a dependence between its two children. We can
represent this with a bidirected arc. The resulting graph is called an acyclic directed mixed graph
or ADMG. See Figure 4.43 for an example. (A mixed graph is one with undirected, unidirected,
and bidirected edges.)
Figure 4.44: A VAR(2) process represented as a dynamic chain graph. From [DE00]. Used with kind
permission of Rainer Dahlhaus.


One can determine CI properties of ADMGs using a technique called m-separation [Ric03].
This is equivalent to d-separation in a graph where every bidirected edge $Y_i \leftrightarrow Y_j$ is replaced by
$Y_i \leftarrow X_{ij} \rightarrow Y_j$, where $X_{ij}$ is a hidden variable for that edge.
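The reduction from m-separation to d-separation can be sketched directly: replace each bidirected edge by a latent common parent, then test d-separation, here via the standard criterion of separation in the moralized ancestral graph. This plain-Python sketch uses our own graph encoding and latent-node names; the usage example is a hypothetical bi-directed chain $Y_1 \leftrightarrow Y_2 \leftrightarrow Y_3$.

```python
import itertools

def admg_to_dag(directed, bidirected):
    """Replace every bidirected edge Yi <-> Yj by Yi <- X_ij -> Yj, with X_ij a new latent node."""
    parents = {}
    for a, b in directed:
        parents.setdefault(a, set())
        parents.setdefault(b, set()).add(a)
    for a, b in bidirected:
        latent = f"X_{a}_{b}"
        parents[latent] = set()
        for child in (a, b):
            parents.setdefault(child, set()).add(latent)
    return parents

def d_separated(parents, xs, ys, zs):
    """Test whether xs and ys are d-separated given zs in a DAG {node: set of parents}."""
    # 1. Restrict to xs, ys, zs and all of their ancestors.
    ancestors, stack = set(), list(xs | ys | zs)
    while stack:
        v = stack.pop()
        if v not in ancestors:
            ancestors.add(v)
            stack.extend(parents[v])
    # 2. Moralize that subgraph: connect each node to its parents, marry co-parents.
    adj = {v: set() for v in ancestors}
    for v in ancestors:
        for p in parents[v]:
            adj[v].add(p)
            adj[p].add(v)
        for p, q in itertools.combinations(parents[v], 2):
            adj[p].add(q)
            adj[q].add(p)
    # 3. Delete zs and check whether any node in xs can still reach ys.
    seen, stack = set(zs), list(xs)
    while stack:
        v = stack.pop()
        if v in ys:
            return False
        if v not in seen:
            seen.add(v)
            stack.extend(adj[v])
    return True

dag = admg_to_dag([], [("Y1", "Y2"), ("Y2", "Y3")])
print(d_separated(dag, {"Y1"}, {"Y3"}, set()), d_separated(dag, {"Y1"}, {"Y3"}, {"Y2"}))  # True False
```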

The most common example of ADMGs is when everything is linear-Gaussian. This is known as a 
structural equation model and is discussed in Section 4.7.2. 


4.5.5 Comparing directed and undirected Gaussian PGMs 


In this section, we compare directed and undirected Gaussian graphical models. In Section 4.2.3, 
we saw that directed GGMs correspond to sparse regression matrices. In Section 4.3.5, we saw that 
undirected GGMs correspond to sparse precision matrices. 

The advantage of the DAG formulation is that we can make the regression weights $W$, and hence $\Sigma$,
be conditional on covariate information [Pou04], without worrying about positive definite constraints.
The disadvantage of the DAG formulation is its dependence on the order, although in certain
domains, such as time series, there is already a natural ordering of the variables.

It is actually possible to combine both directed and undirected representations, resulting in a
model known as a (Gaussian) chain graph. For example, consider a discrete-time, second-order
Markov chain in which the observations are continuous, $x_t \in \mathbb{R}^D$. The transition function can be
represented as a (vector-valued) linear-Gaussian CPD:

$$p(x_t|x_{t-1}, x_{t-2}, \theta) = \mathcal{N}(x_t|A_1 x_{t-1} + A_2 x_{t-2}, \Sigma) \qquad (4.157)$$

This is called a vector autoregressive or VAR process of order 2. Such models are widely used in
econometrics for time-series forecasting.


The time series aspect is most naturally modeled using a DPGM. However, if $\Sigma^{-1}$ is sparse, then
the correlation amongst the components within a time slice is most naturally modeled using a UPGM.


Figure 4.45: (a) A bi-directed graph. (b) The equivalent DAG. Here the $z$ nodes are latent confounders.
Adapted from Figures 5.12-5.13 of [Cho11].


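For example, suppose we have

$$A_1 = \begin{pmatrix} \frac{3}{5} & 0 & \frac{1}{5} & 0 & 0 \\ 0 & \frac{3}{5} & 0 & -\frac{1}{5} & 0 \\ \frac{2}{5} & \frac{1}{3} & \frac{3}{5} & 0 & 0 \\ 0 & 0 & 0 & -\frac{1}{2} & \frac{1}{5} \\ 0 & 0 & \frac{1}{5} & 0 & \frac{2}{5} \end{pmatrix}, \quad A_2 = \begin{pmatrix} 0 & 0 & -\frac{1}{5} & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 \\ 0 & 0 & \frac{1}{5} & 0 & \frac{1}{3} \\ 0 & 0 & 0 & 0 & -\frac{1}{5} \end{pmatrix} \qquad (4.158)$$

and

$$\Sigma = \begin{pmatrix} 1 & \frac{1}{2} & \frac{1}{3} & 0 & 0 \\ \frac{1}{2} & 1 & -\frac{1}{3} & 0 & 0 \\ \frac{1}{3} & -\frac{1}{3} & 1 & 0 & 0 \\ 0 & 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 0 & 1 \end{pmatrix}, \quad \Sigma^{-1} = \begin{pmatrix} 2.13 & -1.47 & -1.2 & 0 & 0 \\ -1.47 & 2.13 & 1.2 & 0 & 0 \\ -1.2 & 1.2 & 1.8 & 0 & 0 \\ 0 & 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 0 & 1 \end{pmatrix} \qquad (4.159)$$

The resulting graphical model is illustrated in Figure 4.44. Zeros in the transition matrices $A_1$ and
$A_2$ correspond to absent directed arcs from $x_{t-1}$ and $x_{t-2}$ into $x_t$. Zeros in the precision matrix
$\Sigma^{-1}$ correspond to absent undirected arcs between nodes in $x_t$.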
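The correspondence between zeros of $\Sigma^{-1}$ and missing undirected arcs can be checked exactly. This plain-Python sketch inverts the $\Sigma$ of Eq. (4.159) with rational arithmetic (so zeros are exact rather than floating-point approximations) and lists the nonzero off-diagonal entries; components are 0-indexed, and the helper names are our own.

```python
from fractions import Fraction as F

def inverse(M):
    """Exact inverse of a square matrix by Gauss-Jordan elimination over fractions."""
    n = len(M)
    A = [[F(a) for a in row] + [F(int(i == j)) for j in range(n)] for i, row in enumerate(M)]
    for col in range(n):
        pivot = next(r for r in range(col, n) if A[r][col] != 0)
        A[col], A[pivot] = A[pivot], A[col]
        pv = A[col][col]
        A[col] = [a / pv for a in A[col]]
        for r in range(n):
            if r != col and A[r][col] != 0:
                factor = A[r][col]
                A[r] = [a - factor * b for a, b in zip(A[r], A[col])]
    return [row[n:] for row in A]

def undirected_edges(precision):
    """Pairs (i, j), i < j, with a nonzero precision entry: undirected arcs within x_t."""
    n = len(precision)
    return {(i, j) for i in range(n) for j in range(i + 1, n) if precision[i][j] != 0}

Sigma = [[1, F(1, 2), F(1, 3), 0, 0],
         [F(1, 2), 1, F(-1, 3), 0, 0],
         [F(1, 3), F(-1, 3), 1, 0, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 0, 0, 1]]
Lam = inverse(Sigma)
print([[round(float(v), 2) for v in row] for row in Lam])
print(sorted(undirected_edges(Lam)))  # [(0, 1), (0, 2), (1, 2)]
```

The rounded entries match the $\Sigma^{-1}$ shown in Eq. (4.159); components 3 and 4 have no undirected arcs within the time slice.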


4.5.5.1 Covariance graphs 


Sometimes we have a sparse covariance matrix rather than a sparse precision matrix. This can
be represented using a bi-directed graph, where each edge has arrows in both directions, as in
Figure 4.45(a). Here nodes that are not connected are unconditionally independent. For example,
in Figure 4.45(a) we see that $Y_1 \perp Y_3$. In the Gaussian case, this means $\Sigma_{1,3} = \Sigma_{3,1} = 0$. (A
graph representing a sparse covariance matrix is called a covariance graph; see e.g., [Pen13].) By
contrast, if this were an undirected model, we would have that $Y_1 \perp Y_3|Y_2$, and $\Lambda_{1,3} = \Lambda_{3,1} = 0$,
where $\Lambda = \Sigma^{-1}$.
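A small numerical contrast between covariance and precision graphs, using a hypothetical positive definite $3 \times 3$ covariance for $Y_1 \leftrightarrow Y_2 \leftrightarrow Y_3$: $\Sigma_{1,3} = 0$ encodes $Y_1 \perp Y_3$, yet $\Lambda_{1,3} \neq 0$, so $Y_1$ and $Y_3$ are dependent given $Y_2$. Plain Python with exact fractions; the numbers are our own choice.

```python
from fractions import Fraction as F

def inv3(S):
    """Inverse of a 3 x 3 matrix via the adjugate (exact with Fractions)."""
    (a, b, c), (d, e, f), (g, h, i) = S
    det = a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g)
    adj = [[e * i - f * h, c * h - b * i, b * f - c * e],
           [f * g - d * i, a * i - c * g, c * d - a * f],
           [d * h - e * g, b * g - a * h, a * e - b * d]]
    return [[F(x) / det for x in row] for row in adj]

# Hypothetical covariance for the bi-directed chain Y1 <-> Y2 <-> Y3: Sigma_13 = 0.
Sigma = [[1, F(1, 2), 0],
         [F(1, 2), 1, F(1, 2)],
         [0, F(1, 2), 1]]
Lam = inv3(Sigma)
print(Sigma[0][2], Lam[0][2])  # 0 1/2
```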

A bidirected graph can be converted to a DAG with latent variables, where each bidirected edge is 
replaced with a hidden variable representing a hidden common cause, or confounder, as illustrated 
in Figure 4.45(b). The relevant CI properties can then be determined using d-separation. 
Figure 4.46: (a) A simple UPGM. (b) A factor graph representation assuming one potential per maximal
clique. (c) Same as (b), but graph is visualized differently. (d) A factor graph representation assuming one
potential per edge.


4.6 PGM extensions


In this section, we discuss some extensions of the basic PGM framework.


4.6.1 Factor graphs


A factor graph [KFL01; Loe04] is a graphical representation that unifies directed and undirected
models. They come in two main “flavors”. The original version uses a bipartite graph, where we have
nodes for random variables and nodes for factors, as we discuss in Section 4.6.1.1. An alternative
form, known as a Forney factor graph [For01], just has nodes for factors, and the variables are
associated with edges, as we explain in Section 4.6.1.2.


4.6.1.1 Bipartite factor graphs


A factor graph is an undirected bipartite graph with two kinds of nodes. Round nodes represent
variables, square nodes represent factors, and there is an edge from each variable to every factor
that mentions it. For example, consider the MRF in Figure 4.46(a). If we assume one potential per
maximal clique, we get the factor graph in Figure 4.46(b), which represents the function

$$f(x_1, x_2, x_3, x_4) = f_{124}(x_1, x_2, x_4) f_{234}(x_2, x_3, x_4) \qquad (4.160)$$
We can represent this in a topologically equivalent way as in Figure 4.46(c).

Figure 4.47: (a) A simple DPGM. (b) Its corresponding factor graph.

One advantage of factor graphs over UPGM diagrams is that they are more fine-grained. For 
example, suppose we associate one potential per edge, rather than per clique. In this case, we get 
the factor graph in Figure 4.46(d), which represents the function 


$$f(x_1, x_2, x_3, x_4) = f_{14}(x_1, x_4) f_{12}(x_1, x_2) f_{34}(x_3, x_4) f_{23}(x_2, x_3) f_{24}(x_2, x_4) \qquad (4.161)$$


We can also convert a DPGM to a factor graph: just create one factor per CPD, and connect 
that factor to all the variables that use that CPD. For example, Figure 4.47 represents the following 
factorization: 


$$f(x_1, x_2, x_3, x_4, x_5) = f_1(x_1) f_2(x_2) f_{123}(x_1, x_2, x_3) f_{34}(x_3, x_4) f_{35}(x_3, x_5) \qquad (4.162)$$

where we define $f_{123}(x_1, x_2, x_3) = p(x_3|x_1, x_2)$, etc. If each node has at most one parent (and hence
the graph is a chain or simple tree), then there will be one factor per edge (root nodes can have their
prior CPDs absorbed into their children’s factors). Such models are equivalent to pairwise MRFs.
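The DPGM-to-factor-graph recipe above amounts to one factor per CPD. This plain-Python sketch (our own encoding and function names) rebuilds the factors of Eq. (4.162), assuming, as that equation implies, that $x_3$ has parents $x_1$ and $x_2$ and that $x_4$ and $x_5$ each have parent $x_3$.

```python
def dpgm_to_factor_graph(parents):
    """One factor per CPD p(x_i | pa(x_i)), connected to x_i and its parents.

    parents: {node: list of parents}. Returns {factor name: tuple of variables}.
    """
    factors = {}
    for child, pas in parents.items():
        scope = tuple(sorted(pas)) + (child,)
        factors["f" + "".join(str(v) for v in sorted(scope))] = scope
    return factors

def variable_neighbours(factors):
    """Bipartite adjacency from each variable to the factors that mention it."""
    nbrs = {}
    for name, scope in factors.items():
        for v in scope:
            nbrs.setdefault(v, set()).add(name)
    return nbrs

parents = {1: [], 2: [], 3: [1, 2], 4: [3], 5: [3]}
factors = dpgm_to_factor_graph(parents)
print(factors)  # {'f1': (1,), 'f2': (2,), 'f123': (1, 2, 3), 'f34': (3, 4), 'f35': (3, 5)}
```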


4.6.1.2 Forney factor graphs 


A Forney factor graph (FFG), also called a normal factor graph, is a graph in which nodes 
represent factors, and edges represent variables [For01; Loe04; Loe+07; CLV19]. This is more similar 
to standard neural network diagrams, and electrical engineering diagrams, where signals (represented 
as electronic pulses, or tensors, or probability distributions) propagate along wires and are modified 
by functions represented as nodes. 

For example, consider the following factorized function: 


$$f(x_1, \ldots, x_5) = f_a(x_1) f_b(x_1, x_2) f_c(x_2, x_3, x_4) f_d(x_4, x_5) \qquad (4.163)$$

We can visualize this as an FFG as in Figure 4.48a. The edge labeled $x_3$ is called a half-edge, since
it is only connected to one node; this is because $x_3$ only participates in one factor. (Similarly for $x_5$.)
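Whether a factorization can be drawn directly as an FFG depends only on how many factors mention each variable. This plain-Python sketch (hypothetical helper `ffg_edges`) classifies the variables of Eq. (4.163) into full edges and half-edges, and rejects variables used by more than two factors, which is the case handled by equality constraint nodes.

```python
from collections import Counter

def ffg_edges(factors):
    """Classify variables as FFG edges. factors: {factor name: tuple of variable names}.

    Each variable is an edge, so it may touch at most two factors;
    a variable that appears in exactly one factor is a half-edge.
    """
    counts = Counter(v for scope in factors.values() for v in scope)
    too_many = sorted(v for v, n in counts.items() if n > 2)
    if too_many:
        raise ValueError(f"variables {too_many} need an equality constraint node")
    full = {v for v, n in counts.items() if n == 2}
    half = {v for v, n in counts.items() if n == 1}
    return full, half

factors = {"fa": ("x1",), "fb": ("x1", "x2"), "fc": ("x2", "x3", "x4"), "fd": ("x4", "x5")}
full, half = ffg_edges(factors)
print(sorted(full), sorted(half))  # ['x1', 'x2', 'x4'] ['x3', 'x5']
```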
Figure 4.49: An FFG with an equality constraint node (left) and its corresponding simplified form (right).


The directionality associated with the edges is a useful mnemonic device if there is a natural order in
which the variables are generated. In addition, associating directions with each edge allows us to
uniquely name “messages” that are sent along each edge, which will prove useful when we discuss
inference algorithms in Section 9.2.

In addition to being more similar to neural network diagrams, FFGs have the advantage over 
bipartite FGs in that they support hierarchical (compositional) construction, in which complex 
dependency structure between variables can be represented as a blackbox, with the input/output 
interface being represented by edges corresponding to the variables exposed by the blackbox. See 
Figure 4.48b for an example, which represents the function 


$$f(x_1, \ldots, x_5) = f_{\mathrm{prior}}(x_1, x_2, x_3, x_4)\, f_{\mathrm{lik}}(x_4, x_5) \tag{4.164}$$


The factor f_prior represents a (potentially complex) joint distribution p(x_1, x_2, x_3, x_4), and the factor
f_lik represents the likelihood term p(x_5|x_4). Such models are widely used to build error-correcting
codes (see Section 9.3.8).


To allow for variables to participate in more than 2 factors, equality constraint nodes are introduced,
as illustrated in Figure 4.49(a). Formally, this is a factor defined as follows:


$$f_=(x, x_1, x_2) = \delta(x - x_1)\, \delta(x - x_2) \tag{4.165}$$
where δ(u) is a Dirac delta if u is continuous, and a Kronecker delta if u is discrete. The effect of
this factor is to ensure all the variables connected to the factor have the same value; intuitively, this
factor acts like a “wire splitter”. Thus the function represented in Figure 4.49(a) is equivalent to the
following:


$$f(x, y_1, y_2) = f_x(x)\, f_{y|x}(y_1, x)\, f_{y|x}(y_2, x) \tag{4.166}$$


This simplified form is represented in Figure 4.49(b), where we reuse the x variable across multiple
edges. We have chosen the edge orientations to reflect our interpretation of the factors f_{y|x}(y, x) as
likelihood terms, p(y|x). We have also chosen to reuse the same f_{y|x} factor for both y variables; this
is an example of parameter tying.
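For discrete x the equality node is a product of Kronecker deltas, so a short numerical check can confirm the equivalence. The sketch below assumes the left panel of Figure 4.49 connects f_x to the equality node and each copy x_1, x_2 to one likelihood factor; the factor tables are hypothetical. Summing out the copies recovers Equation (4.166) for every (x, y_1, y_2).

```python
import itertools

K = 3                                  # number of states of x (arbitrary choice)
f_x = [0.2, 0.5, 0.3]                  # hypothetical prior factor f_x(x)
f_y_given_x = [[0.9, 0.1],             # hypothetical tied factor f_{y|x}(y, x),
               [0.4, 0.6],             # stored as f_y_given_x[x][y]
               [0.2, 0.8]]

def f_eq(x, x1, x2):
    """Equality-constraint factor: Kronecker delta(x - x1) * delta(x - x2)."""
    return 1.0 if (x == x1 and x == x2) else 0.0

def with_equality_node(x, y1, y2):
    # FFG with an equality node: sum out the two copies x1, x2 of x.
    return sum(
        f_x[x] * f_eq(x, x1, x2) * f_y_given_x[x1][y1] * f_y_given_x[x2][y2]
        for x1, x2 in itertools.product(range(K), repeat=2)
    )

def simplified(x, y1, y2):
    # Simplified form: x reused on both edges, same factor for y1 and y2.
    return f_x[x] * f_y_given_x[x][y1] * f_y_given_x[x][y2]

for x, y1, y2 in itertools.product(range(K), range(2), range(2)):
    assert abs(with_equality_node(x, y1, y2) - simplified(x, y1, y2)) < 1e-12
```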


4.6.2 Probabilistic circuits 


A probabilistic circuit is a kind of graphical model that supports efficient exact inference. It
includes arithmetic circuits [Dar03; Dar09], sum-product networks (SPNs) [PD11; SCPD22],
and other kinds of models.

Here we briefly describe SPNs. An SPN is a probabilistic model, based on a directed tree-structured
graph, in which terminal nodes represent univariate probability distributions and non-terminal nodes
represent convex combinations (weighted sums) and products of probability functions. SPNs are
similar to deep mixture models, in which product nodes combine together different dimensions. SPNs
leverage context-specific independence to reduce the complexity of exact inference to time that is
proportional to the number of links in the graph, rather than exponential in the treewidth of the graph
(see Section 9.4.2).
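A tiny hand-built SPN over two binary variables illustrates why inference is linear in the number of links: one bottom-up pass evaluates the network, and marginalizing a variable just means setting its (normalized) leaf to 1. The structure and all weights below are hypothetical.

```python
def bernoulli_leaf(p_one):
    """Univariate leaf over a binary variable; value None means 'marginalize out'."""
    def leaf(value):
        if value is None:
            return 1.0  # a normalized leaf sums to 1 over its variable
        return p_one if value == 1 else 1.0 - p_one
    return leaf

# Hypothetical SPN: root = 0.3 * [L1a(X1) * L2a(X2)] + 0.7 * [L1b(X1) * L2b(X2)]
L1a, L2a = bernoulli_leaf(0.9), bernoulli_leaf(0.2)
L1b, L2b = bernoulli_leaf(0.1), bernoulli_leaf(0.7)

def spn(x1, x2):
    prod_a = L1a(x1) * L2a(x2)          # product node over disjoint scopes {X1}, {X2}
    prod_b = L1b(x1) * L2b(x2)
    return 0.3 * prod_a + 0.7 * prod_b  # sum node: convex combination of its children

joint = {(a, b): spn(a, b) for a in (0, 1) for b in (0, 1)}
print(sum(joint.values()))   # 1 up to rounding: the SPN is normalized
print(spn(1, None))          # exact marginal p(X1 = 1) = 0.3*0.9 + 0.7*0.1, in one pass
```

The marginal query costs the same single pass as a full joint evaluation, instead of a sum over all states of X2.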

SPNs are particularly useful for tasks such as missing data imputation of tabular data (see e.g., 
[Cla20; Ver+19]). A recent extension of SPNs, known as einsum networks, is proposed in [Peh+20] 
(see Section 9.6.1 for details on the connection between Einstein summation and PGM inference).


4.6.3 Directed relational PGMs 


A Bayesian network defines a joint probability distribution over a fixed number of random variables. 
By using plate notation (Section 4.2.8), we can define models with certain kinds of repetitive structure, 
and tied parameters, but many models are not expressible in this way. For example, it is not possible 
to represent even a simple HMM using plate notation (see Figure 29.12). Various notational extensions 
of plates have been proposed to handle repeated structure (see e.g., [HMK04; Die10]) but have not 
been widely adopted. The problem becomes worse when we have more complex domains, involving 
multiple objects which interact via multiple relationships.¹² Such models are called relational
probability models or RPMs. In this section, we focus on directed RPMs; see Section 4.6.4 for 
the undirected case. 

As in first order logic, RPMs have constant symbols (representing objects), function symbols 
(mapping one set of constants to another), and predicate symbols (representing relations between 
objects). We will assume that each function has a type signature. To illustrate this, consider an 
example from [RN19, Sec 15.1], which concerns online book reviews on sites such as Amazon. Suppose 


12. See e.g., this blog post from Rob Zinkov: https://www.zinkov.com/posts/2013-07-28-stop-using-plates. 


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 




there are two types of objects, Book and Customer, and the following functions and predicates: 


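$$\mathrm{Honest} : \mathrm{Customer} \to \{\mathrm{True}, \mathrm{False}\} \tag{4.167}$$
$$\mathrm{Kindness} : \mathrm{Customer} \to \{1, 2, 3, 4, 5\} \tag{4.168}$$
$$\mathrm{Quality} : \mathrm{Book} \to \{1, 2, 3, 4, 5\} \tag{4.169}$$
$$\mathrm{Recommendation} : \mathrm{Customer} \times \mathrm{Book} \to \{1, 2, 3, 4, 5\} \tag{4.170}$$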


The constant symbols refer to specific objects. To keep things simple, we assume there are two
books, B_1 and B_2, and two customers, C_1 and C_2. The basic random variables are obtained
by instantiating each function with each possible combination of objects to create a set of ground
terms. In this example, these variables are H(C_1), Q(B_1), R(C_1, B_2), etc. (We use the abbreviations
H, K, Q and R for the functions Honest, Kindness, Quality and Recommendation.¹³)

We now need to specify the (conditional) distribution over these random variables. We define these 
distributions in terms of the generic indexed form of the variables, rather than the specific ground 
form. For example, we may use the following priors for the root nodes (variables with no parents): 


$$H(c) \sim \mathrm{Cat}(0.99, 0.01) \tag{4.171}$$
$$K(c) \sim \mathrm{Cat}(0.1, 0.1, 0.2, 0.3, 0.3) \tag{4.172}$$
$$Q(b) \sim \mathrm{Cat}(0.05, 0.2, 0.4, 0.2, 0.15) \tag{4.173}$$


For the recommendation nodes, we need to define a conditional distribution of the form 
$$R(c, b) \sim \mathrm{RecCPD}(H(c), K(c), Q(b)) \tag{4.174}$$


where RecCPD is the conditional probability distribution (CPD) for the recommendation node. If
represented as a conditional probability table (CPT), this has 2 × 5 × 5 = 50 rows, each with 5
entries. This table can encode our assumptions about what kind of ratings a book receives based on
the quality of the book, but also on properties of the reviewer, such as their honesty and kindness. (More


sophisticated models of human raters in the context of crowd-sourced data collection can be found in 
e.g., [LRC19].) 

We can convert the above formulae into a graphical model “template”, as shown in Figure 4.50a.
Given a set of objects, we can “unroll” the template to create a “ground network”, as shown in
Figure 4.50b. With C customers and B books, there are C × B + 2C + B random variables, with a
corresponding joint state space (set of possible worlds) of size 2^C 5^{C+B+CB}, which can get quite
large. However, if we are only interested in answering specific queries, we can dynamically unroll
small pieces of the network that are relevant to that query [GC90; Bre92].
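The unrolling step and the count above can be checked mechanically. The sketch below grounds each function symbol on every combination of objects of the right type (the object names are the ones used in the text; the dictionary layout is our own choice) and multiplies the domain sizes to get the number of possible worlds.

```python
import itertools
import math

customers = ["C1", "C2"]
books = ["B1", "B2"]
objects = {"Customer": customers, "Book": books}

# Function symbol -> (argument types, number of values in its range).
signatures = {
    "H": (("Customer",), 2),           # Honest
    "K": (("Customer",), 5),           # Kindness
    "Q": (("Book",), 5),               # Quality
    "R": (("Customer", "Book"), 5),    # Recommendation
}

# Instantiate every function with every combination of objects -> ground terms.
ground_terms = {}
for name, (arg_types, num_values) in signatures.items():
    for args in itertools.product(*(objects[t] for t in arg_types)):
        ground_terms[f"{name}({', '.join(args)})"] = num_values

C, B = len(customers), len(books)
num_vars = len(ground_terms)
num_worlds = math.prod(ground_terms.values())

assert num_vars == C * B + 2 * C + B
assert num_worlds == 2**C * 5**(C + B + C * B)
print(num_vars, num_worlds)  # 10 1562500
```

Even this toy domain has over a million possible worlds, which is why unrolling only the query-relevant part of the network matters.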


Let us assume that only a subset of the R(c,b) entries are observed, and we would like to 
predict the missing entries of this matrix. This is essentially a simplified recommender system. 
(Unfortunately it ignores key aspects of the problem, such as the content/topic of the books, and 
the interests/preferences of the customers.) We can use standard probabilistic inference methods for
graphical models (which we discuss in Chapter 9) to solve this problem.


Things get more interesting when we don’t know which objects are being referred to. For example,
customer C_1 might write a review of a book called “Probabilistic Machine Learning”, but do they


13. A unary function of an object that returns a basic type, such as Boolean or an integer, is often called an attribute
of that object.

Figure 4.50: RPM for the book review domain. (a) Template for a generic customer C_i and book B_j pair. R
is rating, Q is quality, H is honesty, and K is kindness. (b) Unrolled model for 2 books and 2 customers.


mean edition 1 (B_1) or edition 2 (B_2)? To handle this kind of relational uncertainty, we can add
all possible referents as parents to each relation. This is illustrated in Figure 4.51, where now Q(B_1)
and Q(B_2) are both parents of R(C_1, B_1). This is necessary because their review score might
depend on either Q(B_1) or Q(B_2), depending on which edition they are writing about. To disambiguate
this, we create a new variable, L(C_i), which specifies which version of each book customer i
is referring to. The new CPD for the recommendation node, p(R(c, b)|H(c), K(c), Q(1:B), L(c)),
has the form


$$R(c, b) \sim \mathrm{RecCPD}(H(c), K(c), Q(b')) \quad \text{where } b' = L(c) \tag{4.175}$$


This CPD acts like a multiplexer, where the L(c) node specifies which of the parents Q(1 : B) to 
actually use. 
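The multiplexer behaviour of Equation (4.175) can be sketched as follows. The rating CPD `rec_cpd` is hypothetical and deliberately simplified (it ignores H and K); the only point is that L(c) selects which parent Q(b') is passed through, so the unselected parents have no effect.

```python
quality_values = [1, 2, 3, 4, 5]

def rec_cpd(h, k, q):
    """Hypothetical p(R | H, K, Q) as a list over ratings 1..5.

    Toy choice: with probability 0.6 the rating copies the quality Q, otherwise it
    is uniform over 1..5. H and K are accepted but ignored to keep the example tiny.
    """
    return [0.6 + 0.4 / 5 if r == q else 0.4 / 5 for r in quality_values]

def rec_multiplexed(h, k, qualities, l):
    """p(R(c, b) | H(c), K(c), Q(1:B), L(c)): L(c) = l selects which Q(b') to use."""
    return rec_cpd(h, k, qualities[l])

qualities = (5, 1)   # hypothetical Q(B1), Q(B2) for the two editions
for l in range(len(qualities)):
    dist = rec_multiplexed(True, 3, qualities, l)
    mode = quality_values[dist.index(max(dist))]
    print(f"L(c) = B{l + 1}: most likely rating {mode}")
```

With L(c) pointing at B1 the most likely rating follows Q(B1) = 5; pointing at B2 it follows Q(B2) = 1.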

Although the above problem may seem contrived, identity uncertainty is a widespread problem 
in many areas, such as citation analysis, credit card histories, and object tracking (see Section 4.6.5). 
In particular, the problem of entity resolution or record linkage, which refers to the task of
mapping particular strings (such as names) to particular objects (such as people), is a whole
field of research (see e.g., https://en.wikipedia.org/wiki/Record_linkage for an overview and
[SHF15] for a Bayesian approach).


4.6.4 Undirected relational PGMs 


We can create relational UGMs in a manner which is analogous to relational DGMs (Section 4.6.3). 
This is particularly useful in the discriminative setting, for the same reasons that undirected CRFs 
are preferable to conditional DGMs (see Section 4.4). 
Figure 4.51: An extension of the book review RPM to handle identity uncertainty about which book a given
customer is actually reviewing. The R(c,b) node now depends on all books, since we don’t know which one is 
being referred to. We can select one of these parents based on the mapping specified by the user’s library, L(c). 


4.6.4.1 Collective classification 


As an example of a relational UGM, suppose we are interested in the problem of classifying web
pages of a university into types (e.g., student, professor, admin, etc.). Obviously we can do this based
on the contents of the page (e.g., words, pictures, layout, etc.). However, we might also suppose
there is information in the hyper-link structure itself. For example, it might be likely for students to
cite professors, and professors to cite other professors, but there may be no links between admin
pages and students / professors. When faced with a web page whose label is ambiguous, we can
bias our estimate based on the estimated labels of its neighbors, as in a CRF. This process is known
as collective classification (see e.g., [Sen+08]). To specify the CRF structure for a web-graph of
arbitrary size and shape, we just specify a template graph and potential functions, and then unroll
the template appropriately to match the topology of the web, making use of parameter tying.


4.6.4.2 Markov logic networks


One particularly popular way of specifying relational UGMs is to use first-order logic rather than
a graphical description of the template. The result is known as a Markov logic network [RD06;
Dom+06; DL09].
For example, consider the sentences “Smoking causes cancer” and “If two people are friends, and
one smokes, then so does the other”. We can write these sentences in first-order logic as follows:


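$$\forall x.\ \mathrm{Sm}(x) \Rightarrow \mathrm{Ca}(x) \tag{4.176}$$
$$\forall x \forall y.\ \mathrm{Fr}(x, y) \land \mathrm{Sm}(x) \Rightarrow \mathrm{Sm}(y) \tag{4.177}$$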


where Sm and Ca are predicates, and Fr is a relation.
Fr(A,A)  Fr(B,B)  Fr(B,A)  Fr(A,B) | Sm(A)  Sm(B) | Ca(A)  Ca(B)
   1        1        0        1    |   1      1   |   1      1
   1        1        0        1    |   1      0   |   0      0
   1        1        0        1    |   1      1   |   0      1


Table 4.5: Some possible joint instantiations of the 8 variables in the smoking example. 




Figure 4.52: An example of a ground Markov logic network represented as a pairwise MRF for 2 people. 
Adapted from Figure 2.1 from [DL09]. Used with kind permission of Pedro Domingos. 


It is convenient to write all formulas in conjunctive normal form (CNF), also known as clausal 
form. In this case, we get 


$$\neg \mathrm{Sm}(x) \lor \mathrm{Ca}(x) \tag{4.178}$$
$$\neg \mathrm{Fr}(x, y) \lor \neg \mathrm{Sm}(x) \lor \mathrm{Sm}(y) \tag{4.179}$$


The first clause can be read as “Either x does not smoke or he has cancer”, which is logically equivalent 
to Equation (4.176). (Note that in a clause, any unbound variable, such as x, is assumed to be 
universally quantified.) 

Suppose there are just two objects (people) in the world, Anna and Bob, which we will denote by
constant symbols A and B. We can then create 8 binary random variables Sm(x), Ca(x), and
Fr(x, y) for x, y ∈ {A, B}. This defines 2^8 possible worlds, some of which are shown in Table 4.5.¹⁴

Our goal is to define a probability distribution over these joint assignments. We can do this by
creating a UGM with these variables, and adding a potential function to capture each logical rule or
constraint. For example, we can encode the rule ¬Sm(x) ∨ Ca(x) by creating a potential function


14. Note that we have not encoded the fact that Fr is a symmetric relation, so Fr(A, B) and Fr(B, A) might have
different values. Similarly, we have the “degenerate” nodes Fr(A, A) and Fr(B, B), since we did not enforce x ≠ y in
Equation (4.177). (If we add such constraints, then the model compiler, which generates the ground network, should
avoid creating redundant nodes.)
ψ(Sm(x), Ca(x)), where we define


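$$\psi(\mathrm{Sm}(x), \mathrm{Ca}(x)) = \begin{cases} 1 & \text{if } \neg\mathrm{Sm}(x) \lor \mathrm{Ca}(x) = \mathrm{T} \\ 0 & \text{if } \neg\mathrm{Sm}(x) \lor \mathrm{Ca}(x) = \mathrm{F} \end{cases} \tag{4.180}$$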


The result is the UGM in Figure 4.52. 
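In this tiny domain the hard-constraint model can be checked by enumeration. The sketch below (atom naming is our own) builds the 8 ground atoms, evaluates every grounding of the clauses ¬Sm(x) ∨ Ca(x) and ¬Fr(x, y) ∨ ¬Sm(x) ∨ Sm(y), and counts the worlds in which all groundings hold; with 0/1 potentials, these are exactly the worlds that receive non-zero probability.

```python
import itertools

people = ["A", "B"]
atoms = ([f"Sm({x})" for x in people] + [f"Ca({x})" for x in people]
         + [f"Fr({x},{y})" for x in people for y in people])

def clause_groundings(world):
    """Truth value of every grounding of both CNF clauses in a world (a dict)."""
    values = []
    for x in people:                                   # not Sm(x) or Ca(x)
        values.append((not world[f"Sm({x})"]) or world[f"Ca({x})"])
    for x, y in itertools.product(people, repeat=2):   # not Fr(x,y) or not Sm(x) or Sm(y)
        values.append((not world[f"Fr({x},{y})"]) or (not world[f"Sm({x})"])
                      or world[f"Sm({y})"])
    return values

worlds = [dict(zip(atoms, bits))
          for bits in itertools.product([False, True], repeat=len(atoms))]
valid = [w for w in worlds if all(clause_groundings(w))]
print(len(atoms), len(worlds), len(valid))  # 8 256 112
```

Only 112 of the 256 worlds survive the hard constraints, which is what motivates the soft version below.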

The above approach will assign non-zero probability to all logically valid worlds. However, logical 
rules may not always be true. For example, smoking does not always cause cancer. We can relax the 
hard constraints by using non-zero potential functions. In particular, we can associate a weight with 
each rule, and thus get potentials such as 


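$$\psi(\mathrm{Sm}(x), \mathrm{Ca}(x)) = \begin{cases} e^{w} & \text{if } \neg\mathrm{Sm}(x) \lor \mathrm{Ca}(x) = \mathrm{T} \\ e^{0} & \text{if } \neg\mathrm{Sm}(x) \lor \mathrm{Ca}(x) = \mathrm{F} \end{cases} \tag{4.181}$$
where the value of w > 0 controls how strongly we want to enforce the corresponding rule.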
The overall joint distribution has the form
$$p(\boldsymbol{x}) = \frac{1}{Z(\boldsymbol{w})} \exp\Big(\sum_i w_i n_i(\boldsymbol{x})\Big) \tag{4.182}$$


where n_i(x) is the number of instances of clause i which evaluate to true in assignment x.
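For two people, Equation (4.182) can be evaluated by brute force. In the sketch below the clause weights are hypothetical (w_1 = 1.5 for the smoking rule, w_2 = 1.1 for the friendship rule), and Z(w) is computed by summing over all 2^8 worlds, which is precisely the step that becomes infeasible in large ground models.

```python
import itertools
import math

people = ["A", "B"]
atoms = ([f"Sm({x})" for x in people] + [f"Ca({x})" for x in people]
         + [f"Fr({x},{y})" for x in people for y in people])
weights = [1.5, 1.1]   # hypothetical clause weights w_1, w_2

def counts(world):
    """n_i(x): number of true groundings of each clause in the world."""
    n1 = sum((not world[f"Sm({x})"]) or world[f"Ca({x})"] for x in people)
    n2 = sum((not world[f"Fr({x},{y})"]) or (not world[f"Sm({x})"]) or world[f"Sm({y})"]
             for x, y in itertools.product(people, repeat=2))
    return [n1, n2]

def score(world):
    """Unnormalized probability exp(sum_i w_i n_i(x))."""
    return math.exp(sum(w * n for w, n in zip(weights, counts(world))))

worlds = [dict(zip(atoms, bits))
          for bits in itertools.product([False, True], repeat=len(atoms))]
Z = sum(score(w) for w in worlds)

def prob(world):
    return score(world) / Z

# A world where A smokes without cancer violates one grounding of clause 1:
# it is exp(-w_1) times less likely than the all-false world, but not impossible.
base = {a: False for a in atoms}
violating = {**base, "Sm(A)": True}
print(prob(base), prob(violating))
```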

Given a grounded MLN model, we can then perform inference using standard methods. Of course, 
the ground models are often extremely large, so more efficient inference methods, which avoid creating 
the full ground model (known as lifted inference), must be used. See [DL09; KNP11] for details. 

One way to gain tractability is to relax the discrete problem to a continuous one. This is the 
basic idea behind hinge-loss MRFs [Bac+15b], which support exact inference using scalable convex 
optimization. There is a template language for this model family known as probabilistic soft logic,
which has a similar “flavor” to MLN, although it is not quite as expressive.


Recently MLNs have been combined with DL in various ways. For example, [Zha+20f] uses graph 
neural networks for inference. And [WP18] uses MLNs for evidence fusion, where the noisy predictions
come from DNNs trained using weak supervision.


Finally, it is worth noting one subtlety which arises with undirected models, namely that the size
of the unrolled model, which depends on the number of objects in the universe, can affect the results
of inference, even if we have no data about the new objects. For example, consider an undirected
chain of length T, with T hidden nodes z_t and T observed nodes y_t; call this model M_1. Now suppose
we double the length of the chain to 2T, without adding more evidence; call this model M_2. We find
that p(z_t|y_{1:T}, M_1) ≠ p(z_t|y_{1:T}, M_2), for t = 1:T, even though we have not added new information,
due to the different partition functions. This does not happen with a directed chain, because the
newly added nodes can be marginalized out without affecting the original nodes, since the model is
locally normalized and therefore modular. See [JBB09; Poo+12] for further discussion.
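The effect is easy to reproduce numerically. The sketch below uses brute-force enumeration on a binary chain with T = 2; all potentials are hypothetical, the evidence from y_1, ..., y_T is folded into unary terms, and (as a simplifying assumption) the newly added nodes carry no terms other than the chain edges. Doubling the chain changes p(z_1 | y_{1:T}) when the edge potential is unnormalized, but not when it is a transition CPD whose rows sum to one.

```python
import itertools

def chain_marginal(length, pair, evidence):
    """p(z_1 | evidence) for a binary chain z_1, ..., z_length, by enumeration.

    pair[a][b] links z_t = a to z_{t+1} = b: an unnormalized potential in the
    undirected case, a transition CPD p(z_{t+1} | z_t) in the directed case.
    evidence[t][a] is the local evidence term for node t + 1 (e.g. p(y | z));
    nodes beyond len(evidence) get no evidence. No extra unary term is put on
    z_1 (equivalently, a uniform p(z_1) in the directed case).
    """
    weights = [0.0, 0.0]
    for z in itertools.product([0, 1], repeat=length):
        w = 1.0
        for t in range(length - 1):
            w *= pair[z[t]][z[t + 1]]
        for t, phi in enumerate(evidence):
            w *= phi[z[t]]
        weights[z[0]] += w
    total = sum(weights)
    return [w / total for w in weights]

T = 2
evidence = [[0.9, 0.2], [0.3, 0.6]]   # hypothetical terms for the observed y_1, y_2
psi = [[3.0, 1.0], [1.0, 1.0]]        # undirected edge potential (rows not normalized)
A = [[0.7, 0.3], [0.4, 0.6]]          # directed transition matrix (rows sum to 1)

undirected = (chain_marginal(T, psi, evidence), chain_marginal(2 * T, psi, evidence))
directed = (chain_marginal(T, A, evidence), chain_marginal(2 * T, A, evidence))
print("undirected p(z_1 = 0):", undirected[0][0], "vs", undirected[1][0])  # differ
print("directed   p(z_1 = 0):", directed[0][0], "vs", directed[1][0])      # agree
```

In the directed case, summing out each new node multiplies by a row sum of 1, so nothing upstream changes; in the undirected case the row sums differ, which acts like extra evidence.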


4.6.5 Open-universe probability models


In Section 4.6.3, we discussed relational probability models, as well as the topic of identity uncertainty. 


However, we also implicitly made a closed world assumption, namely that the set of all objects is
fixed and specified ahead of time. In many real world problems, this is an unrealistic assumption.
For example, in Section 29.9.3.5, we discuss the problem of tracking an unknown number of objects
over time. As another example, consider the problem of enforcing the UN Comprehensive Nuclear
Test Ban Treaty (CTBT). This requires monitoring seismic events, and determining if they were
caused by nature or man-made explosions. Thus the number of objects of each type, as well as their
source, is uncertain [ARS13].

As another (more peaceful) example, suppose we want to perform citation matching, in which
we want to know whether to cite an arXiv version of a paper or the version on some conference
website. Are these the same object? It is often hard to tell, since the titles and authors might be the
same, yet the content may have been updated. It is often necessary to use subtle cues, such as the 
date stored in the meta-data, to infer if the two “textual measurements” refer to the same underlying 
object (paper) or not [Pas+02]. 

In problems such as these, the number of objects of each type, as well as their relationships, is 
uncertain. This requires the use of open-universe probability models or OUPM, which can 
generate new objects as well as their properties [Rus15; MR10; LB19]. The first formal language 
for OUPMs was BLOG [Mil+05], which stands for “Bayesian LOGic”. This used a general purpose, 
but slow, MCMC inference scheme to sample over possible worlds of variable size and shape. 
[Las08; LLC20] describe another open-universe modeling language called multi-entity Bayesian
networks.

Very recently, Facebook has released the Bean Machine library, available at https://beanmachine.org/,
which supports more efficient inference in OUPMs. Details can be found in [Teh+20], as well
as their blog post.¹⁵


4.6.6 Programs as probability models 


OUPMs, discussed in Section 4.6.5, let us define probability models over complex dynamic state
spaces of unbounded and variable size. The set of possible worlds corresponds to objects and their
attributes and relationships. Another approach is to use a probabilistic programming language
or PPL, in which we define the set of possible worlds as the set of execution traces generated by
the program when it is endowed with a random choice mechanism. (This is a procedural approach
to the problem, whereas OUPMs are a declarative approach.)

The difference between a probabilistic programming language and a standard one was described in 
[Gor+14] as follows: “Probabilistic programs are usual functional or imperative programs with two 
added constructs: (1) the ability to draw values at random from distributions, and (2) the ability to
condition values of variables in a program via observation”. The former is a way to define p(z, y), 
and the latter is the same as standard Bayesian conditioning p(z|y). 
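To make the two constructs concrete, here is a toy, hypothetical mini-PPL (it is not the API of Gen, Pyro or Turing): `sample` draws a random value and records it in an execution trace, and `observe` conditions on data by multiplying the trace's weight by the likelihood. Averaging over weighted traces (likelihood weighting, a simple form of importance sampling, rather than the SMC methods mentioned below) approximates the posterior.

```python
import random

class Trace:
    """One execution of a probabilistic program: its random choices and its weight."""

    def __init__(self, rng):
        self.rng = rng
        self.choices = {}
        self.weight = 1.0

    def sample(self, name, p):
        """Construct (1): draw a Bernoulli(p) value at random and record it."""
        value = self.rng.random() < p
        self.choices[name] = value
        return value

    def observe(self, p, value):
        """Construct (2): condition on an observed Bernoulli(p) outcome."""
        self.weight *= p if value else 1.0 - p


def model(trace):
    # Hypothetical program: latent cause z, and a noisy observation y = True.
    z = trace.sample("z", 0.3)
    trace.observe(0.9 if z else 0.2, True)
    return z


def posterior_prob(program, num_runs=100_000, seed=0):
    """Estimate p(program returns True | observations) by likelihood weighting."""
    rng = random.Random(seed)
    weighted_true = total_weight = 0.0
    for _ in range(num_runs):
        trace = Trace(rng)
        result = program(trace)
        total_weight += trace.weight
        weighted_true += trace.weight * result
    return weighted_true / total_weight


print(posterior_prob(model))   # close to the exact 0.3*0.9 / (0.3*0.9 + 0.7*0.2)
```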

Some recent examples of PPLs include Gen [CT+19], Pyro [Bin+19] and Turing [GXG18].
Inference in such models is often based on SMC, which we discuss in Chapter 13. For more details 
on PPLs, see e.g. [Mee+18]. 


4.7 Structural causal models 


While probabilities encode our beliefs about a static world, causality tells us whether and how 
probabilities change when the world changes, be it by intervention or by act of imagination. — 


15. See https://tinyurl.com/2svy5tmh. 
Judea Pearl [PM18b].


In this section, we discuss how we can use directed graphical model notation to represent causal 
models. We discuss causality in greater detail in Chapter 36, but we introduce some basic ideas and 
notation here, since it is foundational material that we will need in other parts of the book. 

The core idea behind causal models is to create a mechanistic model of the world in which we 
can reason about the effects of local changes. The canonical example is an electronic circuit: we 
can predict the effects of any action, such as “knocking out” a particular transistor, or changing the 
resistance level of a wire, by modifying the circuit locally, and then “re-running” it from the same 
initial conditions. 

We can generalize this idea to create a structural causal model or SCM [PGJ16], also called a
functional causal model [Sch19]. An SCM is a triple M = (U, V, F), where U = {U_i : i = 1:N}
is a set of unexplained or exogenous “noise” variables, which are passed as input to the model,
V = {V_i : i = 1:N} is a set of endogenous variables that are part of the model itself, and
F = {f_i : i = 1:N} is a set of deterministic functions of the form V_i = f_i(V_{pa_i}, U_i), where pa_i are the
parents of variable i, and U_i ∈ U are the external inputs. We assume the equations can be structured
in a recursive way, so the dependency graph of nodes given their parents is a DAG. Finally, we 
assume our model is causally sufficient, which means that V and U are all of the causally relevant 
factors (although they may not all be observed). This is called the “causal Markov assumption”. 

Of course, a model typically cannot represent all the variables that might influence observations or 
decisions. After all, models are abstractions of reality. The variables that we choose not to model 
explicitly in a functional way can be lumped into the unmodeled exogenous terms. To represent 
our ignorance about these terms, we can use a distribution p(U) over their values. By “pushing” 
this external noise through the deterministic part of the model, we induce a distribution over the 
endogenous variables, p(V), as in a probabilistic graphical model. However, SCMs make stronger
assumptions than PGMs. 
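The "pushing" of exogenous noise through deterministic mechanisms can be sketched in a few lines. The three-variable chain, its mechanisms and the standard normal noise below are all hypothetical; the only structural ingredients are a parent list per endogenous variable, a deterministic function f_i, and independent noise U_i, evaluated in topological order.

```python
import random

# Hypothetical SCM: V1 -> V2 -> V3, each with its own exogenous noise U_i.
parents = {"V1": [], "V2": ["V1"], "V3": ["V2"]}
mechanisms = {
    "V1": lambda pa, u: u,                       # V1 = f_1(U_1)
    "V2": lambda pa, u: 2.0 * pa["V1"] + u,      # V2 = f_2(V1, U_2)
    "V3": lambda pa, u: max(pa["V2"], 0.0) + u,  # V3 = f_3(V2, U_3)
}
order = ["V1", "V2", "V3"]  # a topological order of the DAG

def sample_noise(rng):
    # Markovian SCM: independent noise terms, here standard normal.
    return {v: rng.gauss(0.0, 1.0) for v in order}

def solve(noise):
    """Deterministically compute all endogenous variables from the noise."""
    values = {}
    for v in order:
        pa = {p: values[p] for p in parents[v]}
        values[v] = mechanisms[v](pa, noise[v])
    return values

rng = random.Random(0)
samples = [solve(sample_noise(rng)) for _ in range(10_000)]   # draws from p(V)
```

All randomness lives in p(U); given a noise setting, `solve` is a pure function, which is what later makes interventions a matter of swapping out one entry of `mechanisms`.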

We usually assume p(U) is factorized (i.e., the U_i are independent); this is called a Markovian
SCM. If the exogenous noise terms are not independent, it would break the assumption that
outcomes can be determined locally using deterministic functions. If there are believed to be
dependencies between some of the U_i, we can add extra hidden parents to represent this; this is often
depicted as a bidirected or undirected edge connecting the U_i, and is known as a semi-Markovian
SCM.


4.7.1 Example: causal impact of education on wealth


We now give a simple example of an SCM, based on [PM18b, p276]. Suppose we are interested in
the causal effect of education on wealth. Let X represent the level of education of a person (on
some numeric scale, say 0 = high school, 1 = college, 2 = graduate school), and Y represent their
wealth (at some moment in time). In some cases we might expect that increasing X would increase Y
(although it of course depends on the nature of the degree, the nature of the job, etc). Thus we add
an edge from X to Y. However, getting more education can cost a lot of money (in certain countries),
which is a potential counteracting factor on wealth. Let Z be the debt incurred by a person based
on their education. We add an edge from X to Z to reflect the fact that larger X means larger Z (in
general), and we add an edge from Z to Y to reflect that larger Z means lower Y (in general).

We can represent our structural assumptions graphically as shown in Figure 4.53(a). The


Figure 4.53: (a) PGM for modeling relationship between salary, education and debt. (b) Corresponding SCM. 


corresponding SCM has the form: 


$$X = f_x(U_x) \tag{4.183}$$
$$Z = f_z(X, U_z) \tag{4.184}$$
$$Y = f_y(X, Z, U_y) \tag{4.185}$$


for some set of functions f_x, f_y, f_z, and some prior distribution p(U_x, U_y, U_z). We can also explicitly
represent the exogenous noise terms as shown in Figure 4.53(b); this makes clear our assumption
that the noise terms are a priori independent. (We return to this point later.)


4.7.2 Structural equation models 


A structural equation model [Bol89; BP13], also known as a path diagram, is a special case of
a structural causal model in which all the functional relationships are linear, and the prior on the 
noise terms is Gaussian. SEMs are widely used in economics and social science, due to the fact that 
they have a causal interpretation, yet they are computationally tractable. 

For example, let us make an SEM version of our education example. We have 


$$X = \mu_x + U_x \tag{4.186}$$
$$Z = c_z + w_{xz} X + U_z \tag{4.187}$$
$$Y = c_y + w_{xy} X + w_{zy} Z + U_y \tag{4.188}$$


If we assume p(U_x) = N(U_x|0, σ_x^2), p(U_z) = N(U_z|0, σ_z^2), and p(U_y) = N(U_y|0, σ_y^2), then the model
can be converted to the following Gaussian DGM:


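$$p(X) = \mathcal{N}(X \mid \mu_x, \sigma_x^2) \tag{4.189}$$
$$p(Z \mid X) = \mathcal{N}(Z \mid c_z + w_{xz} X, \sigma_z^2) \tag{4.190}$$
$$p(Y \mid X, Z) = \mathcal{N}(Y \mid c_y + w_{xy} X + w_{zy} Z, \sigma_y^2) \tag{4.191}$$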
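A quick simulation can confirm that the SEM and the Gaussian DGM describe the same distribution. The parameter values below are hypothetical; sampling the structural equations (with zero-mean Gaussian noise and X centred at μ_x) should reproduce the means implied by the conditional Gaussians, E[Z] = c_z + w_xz μ_x and E[Y] = c_y + w_xy μ_x + w_zy E[Z].

```python
import random
import statistics

# Hypothetical SEM parameters (education X, debt Z, wealth Y on arbitrary scales).
mu_x, c_z, c_y = 1.0, 0.5, 2.0
w_xz, w_xy, w_zy = 1.5, 3.0, -1.0
sigma_x, sigma_z, sigma_y = 0.5, 0.3, 1.0

def sample_sem(rng):
    """Draw the noise, then evaluate the linear structural equations in order."""
    u_x = rng.gauss(0.0, sigma_x)
    u_z = rng.gauss(0.0, sigma_z)
    u_y = rng.gauss(0.0, sigma_y)
    x = mu_x + u_x
    z = c_z + w_xz * x + u_z
    y = c_y + w_xy * x + w_zy * z + u_y
    return x, z, y

rng = random.Random(0)
xs, zs, ys = zip(*(sample_sem(rng) for _ in range(50_000)))

mean_z = c_z + w_xz * mu_x                  # implied by p(X) and p(Z | X)
mean_y = c_y + w_xy * mu_x + w_zy * mean_z  # implied by p(Y | X, Z)
print(statistics.fmean(zs), mean_z)
print(statistics.fmean(ys), mean_y)
```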


Figure 4.54: An SCM in which we intervene on Z. (a) Hard intervention, in which we clamp Z and thus cut 
its incoming edges (shown as dotted). (b) Soft intervention, in which we change Z’s mechanism. The square 
node is an “action” node, using the influence diagram notation from Section 34.2. 


We can relax the linearity assumption, to allow arbitrarily flexible functions, and relax the Gaussian 
assumption, to allow any noise distribution. The resulting “nonparametric SEMs” are equivalent to
structural causal models. (For a more detailed comparison between SEMs and SCMs, see [Pea12;
BP13; Shi00b].)


4.7.3 Do operator and augmented DAGs 


One of the main advantages of SCMs is that they let us predict the effect of interventions, which
are actions that change one or more local mechanisms. A simple intervention is to force a variable to
have a given value, e.g., we can force a gene to be “on” or “off”. This is called a perfect intervention
and is written as do(X_i = x_i), where we have introduced new notation for the “do” operator (as
in the verb “to do”). This notation means we actively clamp variable X_i to value x_i (as opposed to
just observing that it has this value). Since the value of X_i is now independent of its usual parents,
we should “cut” the incoming edges to node X_i in the graph. This is called the “graph surgery”
operation.
In Figure 4.54a we illustrate this for our education SCM, where we force Z to have a given value.
For example, we may set Z = 0, by paying off everyone’s student debt. Note that p(X|do(Z = z)) ≠
p(X|Z = z), since the intervention changes the model. For example, if we see someone with a debt of
0, we may infer that they probably did not get higher education, i.e., p(X ≥ 1|Z = 0) is small; but if
we pay off everyone’s college loans, then observing someone with no debt in this modified world should
not change our beliefs about whether they got higher education, i.e., p(X ≥ 1|do(Z = 0)) = p(X ≥ 1).
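The difference between conditioning and intervening can be computed exactly in a small discrete version of the education SCM. Everything numeric here is hypothetical: education X ∈ {0, 1, 2}, a binary debt indicator Z that can only be 1 for people with higher education (X ≥ 1), and noise tables chosen for illustration. Graph surgery for do(Z = 0) is implemented by replacing Z's mechanism with a constant.

```python
from itertools import product

# Hypothetical exogenous noise distributions.
P_UX = {0: 0.5, 1: 0.3, 2: 0.2}   # education level: 0=high school, 1=college, 2=grad
P_UZ = {0: 0.2, 1: 0.8}           # 1 = would take out loans to study further

def f_x(u_x):
    return u_x

def f_z(x, u_z):
    # Debt indicator: only people with higher education (x >= 1) can have debt.
    return 1 if (x >= 1 and u_z == 1) else 0

def joint(mechanisms):
    """Yield (probability, X, Z) for every setting of the exogenous noise."""
    for (u_x, p_x), (u_z, p_z) in product(P_UX.items(), P_UZ.items()):
        x = mechanisms["X"](u_x)
        z = mechanisms["Z"](x, u_z)
        yield p_x * p_z, x, z

def prob(event, given, mechanisms):
    num = sum(p for p, x, z in joint(mechanisms) if event(x, z) and given(x, z))
    den = sum(p for p, x, z in joint(mechanisms) if given(x, z))
    return num / den

observational = {"X": f_x, "Z": f_z}
# Graph surgery: do(Z = 0) replaces Z's mechanism by a constant, cutting X -> Z.
intervened = {"X": f_x, "Z": lambda x, u_z: 0}

higher_ed = lambda x, z: x >= 1
no_debt = lambda x, z: z == 0
always = lambda x, z: True

p_cond = prob(higher_ed, no_debt, observational)   # p(X >= 1 | Z = 0)
p_do = prob(higher_ed, no_debt, intervened)         # p(X >= 1 | do(Z = 0))
p_prior = prob(higher_ed, always, observational)    # p(X >= 1)
print(p_cond, p_do, p_prior)
```

Observing no debt lowers the probability of higher education to 1/6 here, whereas forcing no debt leaves it at its prior value of 0.5.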
In more realistic scenarios, we may not be able to set a variable to a specific value, but we may
be able to change it from its current value in some way. For example, we may be able to reduce
everyone’s debt by some fixed amount, say Δ = -10,000. Thus we replace Z = f_z(X, U_z) with
Z = f'_z(X, U_z), where f'_z(X, U_z) = f_z(X, U_z) + Δ. This is called an additive intervention.


To model this kind of scenario, we can create an augmented DAG, in which every variable is
augmented with an additional parent node, representing whether or not the variable’s mechanism is
Level                                   | Activity  | Questions                                  | Examples
1: Association, p(Y|a)                  | Seeing    | How would seeing A change my belief in Y?  | Someone took aspirin, how likely is it their headache will be cured?
2: Intervention, p(Y|do(a))             | Doing     | What if I do A?                            | If I take aspirin, will my headache be cured?
3: Counterfactuals, p(Y^{a'}|do(a), y)  | Imagining | Was it A that caused Y?                    | Would my headache be cured had I not taken aspirin?


Table 4.6: Pearl’s causal hierarchy. Adapted from Table 1 of [Pea19]. 


changed in some way [Daw02; Daw15; CPD17]. These extra variables are represented by square nodes, 
and correspond to decision variables or actions, as in the influence diagram formalism (Section 34.2). 
The same formalism is used in MDPs for reinforcement learning (see Section 34.5). 

We give an example of this in Figure 4.54b, where we add the A_z ∈ {0, 1} node to specify whether
we use the debt reduction policy or not. The modified mechanism for Z becomes 


$$Z = \tilde{f}_z(X, U_z, A_z) = \begin{cases} f_z(X, U_z) & \text{if } A_z = 0 \\ f_z(X, U_z) + \Delta & \text{if } A_z = 1 \end{cases} \tag{4.192}$$

With this new definition, conditioning on the effects of an action can be performed using standard
probabilistic inference. That is, p(Q|do(A_z = a), E = e) = p(Q|A_z = a, E = e), where Q is the query
(e.g., the event X ≥ 1) and E are the (possibly empty) evidence variables. This is because the A_z
node has no parents, so it has no incoming edges to cut when we clamp it.
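The claim that conditioning on a parentless action node is the same as intervening on it can be checked with a small discrete sketch. All numbers are hypothetical: debt is in dollars, `f_z` is an arbitrary mechanism, Δ = -10,000 as in the text, and the prior on A_z exists only so that we can condition on it. Conditioning on A_z = 1 in the augmented model gives the same expected debt as directly replacing Z's mechanism by f_z + Δ.

```python
from itertools import product

P_X = {0: 0.5, 1: 0.3, 2: 0.2}       # hypothetical education distribution
P_UZ = {0: 0.5, 5_000: 0.5}          # hypothetical noise in debt (dollars)
P_AZ = {0: 0.5, 1: 0.5}              # arbitrary prior on the action node A_z
DELTA = -10_000                      # additive intervention from the text

def f_z(x, u_z):
    return 20_000 * x + u_z          # hypothetical debt mechanism

def f_z_aug(x, u_z, a_z):
    # Augmented mechanism: apply the additive intervention only when A_z = 1.
    return f_z(x, u_z) + (DELTA if a_z == 1 else 0)

def expected_z_given_action(a):
    """E[Z | A_z = a], computed by ordinary conditioning in the augmented model."""
    num = den = 0.0
    for (x, px), (u, pu), (a_z, pa) in product(P_X.items(), P_UZ.items(), P_AZ.items()):
        if a_z == a:
            p = px * pu * pa
            num += p * f_z_aug(x, u, a_z)
            den += p
    return num / den

def expected_z_under_policy():
    """E[Z] in the model whose Z mechanism is replaced by f_z + DELTA."""
    return sum(px * pu * (f_z(x, u) + DELTA)
               for (x, px), (u, pu) in product(P_X.items(), P_UZ.items()))

print(expected_z_given_action(0), expected_z_given_action(1), expected_z_under_policy())
```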

Although the augmented DAG allows us to use standard notation (no explicit do operators) and 
inference machinery, the use of “surgical” interventions, which delete incoming edges to a node that 
is set to a value, results in a simpler graph, which can simplify many calculations, particularly in the 
non-parametric setting (see [Pea09b, p361] for a discussion). It is therefore a useful abstraction, even
if it is less general than the augmented DAG approach. 


4.7.4 Counterfactuals 


So far we have been focused on predicting the effects of causes, so we can choose the optimal action
(e.g., if I have a headache, I have to decide whether to take an aspirin or not). This can be tackled
using standard techniques from Bayesian decision theory, as we have seen (see [Daw00; Daw15; LR19; 
Roh21; DM22] for more details). 

Now suppose we are interested in the causes of effects. For example, suppose I took the aspirin 
and my headache did go away. I might be interested in the counterfactual question “if I had not 
taken the aspirin, would my headache have gone away anyway?”. This kind of reasoning is crucial for 
legal reasoning (see e.g., [DMM17]), as well as for tasks like explainability and fairness. 

Counterfactual reasoning requires strictly more assumptions than reasoning about interventions 
(see e.g., [DM22]). Indeed, Judea Pearl has proposed what he calls the causal hierarchy [Pea09b; 
PGJ16; PM18b], which has three levels of analysis, each more powerful than the last, but each 
making stronger assumptions. See Table 4.6 for a summary. 

In counterfactual reasoning, we want to answer questions of the type p(Y^{a'}|do(a), y), which is read
as: “what is the probability distribution over outcomes Y if I were to do a', given that I have already
Figure 4.55: Illustration of the potential outcomes framework as a SCM. The nodes with dashed edges are
unobserved. In this example, for unit 1, we select action A_1 = 0 and observe Y_1 = Y_1^0 = y_1, whereas for unit
2, we select action A_2 = 1 and observe Y_2 = Y_2^1 = y_2.


done a and observed outcome y”. (We can also condition on any other evidence that was observed,
such as covariates x.) The quantity Y^{a'} is often called a potential outcome [Rub74], since it is
the outcome that would occur in a hypothetical world in which you did a' instead of a. (Note that
p(Y^{a'} = y) is equivalent to p(Y = y|do(a')), and is an interventional prediction, not a counterfactual
one.)

The assumptions behind the potential outcomes framework can be clearly expressed using a
structural causal model. We illustrate this in Figure 4.55 for a simple case where there are two
possible actions. We see that we have a set of “units”, such as individual patients, indexed by
subscripts. Each unit is associated with a hidden exogenous random noise source, $U_i$, that captures
everything that is unique about that unit. This noise gets deterministically mapped to two potential
outcomes, $Y_i^0$ and $Y_i^1$, depending on which action is taken. For any given unit, we only get to observe
one of the outcomes, namely the one corresponding to the action that was actually chosen. In the
figure, for unit 1, we chose action $A_1 = 0$, so we get to see $Y_1^0 = y_1$, whereas for unit 2, we chose
action $A_2 = 1$, so we get to see $Y_2^1 = y_2$. The fact that we cannot simultaneously see both outcomes
for the same unit is called the “fundamental problem of causal inference” [Hol86].

We will assume the noise sources are independent, which is known as the “stable unit treatment
value assumption” or SUTVA. (This would not be true if the treatment on person $j$ could somehow
affect the outcome of person $i$, e.g., due to spreading disease or information between $i$ and $j$.) We
also assume that the deterministic mechanisms that map noise to outcomes are the same across
all units (represented by the shared parameter vector $\theta$ in Figure 4.55). We need to make one final
assumption, namely that the exogenous noise is not affected by our actions. (This is a formalization
of the assumption known as “all else being equal”, or (in legal terms) “ceteris paribus”.)


With the above assumptions, we can predict what the outcome for an individual unit would have
been in the alternative universe where we picked the other action. The procedure is as follows. First,
we perform abduction using the SCM $\mathcal{G}$ to infer $p(U_i | A_i = a, Y_i = y_i)$, which is the posterior over
the latent factors for unit $i$ given the observed evidence in the actual world. Second, we perform
intervention, in which we modify the causal mechanisms of $\mathcal{G}$ by replacing $A_i = a$ with $A_i = a'$ to
4.7. STRUCTURAL CAUSAL MODELS


get $\mathcal{G}_{a'}$. Third, we perform prediction, in which we propagate the distribution of the latent factors,
$p(U_i | A_i = a, Y_i = y_i)$, through the modified SCM $\mathcal{G}_{a'}$ to get $p(Y_i^{a'} | A_i = a, Y_i = y_i)$.
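The three steps of abduction, intervention and prediction can be made concrete on a toy discrete SCM. This is a minimal sketch, not the model behind Figure 4.55: encoding $U_i$ as one of four hypothetical “response types”, the mechanism `f`, and the prior `PRIOR_U` are all assumptions chosen for illustration, with $Y = 1$ meaning the headache went away.

```python
# Hypothetical toy SCM: the exogenous noise U is a "response type" that fixes both
# potential outcomes (Y^0, Y^1); the mechanism f is shared by all units.
RESPONSE_TYPES = {
    0: (0, 0),  # headache persists whether or not aspirin is taken
    1: (0, 1),  # aspirin helps
    2: (1, 0),  # aspirin hurts
    3: (1, 1),  # headache goes away regardless
}
PRIOR_U = {0: 0.1, 1: 0.3, 2: 0.1, 3: 0.5}  # assumed prior p(U)


def f(a, u):
    """Deterministic mechanism Y = f(A, U); Y = 1 means the headache went away."""
    return RESPONSE_TYPES[u][a]


def counterfactual(a, y, a_prime):
    """p(Y^{a'} | A = a, Y = y) via abduction, intervention and prediction."""
    # 1. Abduction: posterior over U given the factual evidence (A = a, Y = y).
    consistent = {u: p for u, p in PRIOR_U.items() if f(a, u) == y}
    z = sum(consistent.values())
    posterior = {u: p / z for u, p in consistent.items()}
    # 2. Intervention: replace the mechanism for A by the constant a_prime.
    # 3. Prediction: push the posterior over U through the modified model.
    dist = {0: 0.0, 1: 0.0}
    for u, p in posterior.items():
        dist[f(a_prime, u)] += p
    return dist


def interventional(a_prime):
    """p(Y = 1 | do(A = a')), which uses the prior over U rather than the posterior."""
    return sum(p for u, p in PRIOR_U.items() if f(a_prime, u) == 1)


# I took the aspirin (a = 1) and my headache went away (y = 1).
cf = counterfactual(a=1, y=1, a_prime=0)
print(cf[1], interventional(0))
```

With these made-up numbers, the counterfactual probability that the headache would have gone away anyway is 0.625, whereas the interventional probability $p(Y = 1 | do(A = 0))$ is 0.6: abduction discards the response types that are inconsistent with what was actually observed.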

In Figure 4.55, we see that we have two copies of every possible outcome variable, to represent 
the set of possible worlds. Of course, we only get to see one such world, based on the actions that 
we actually took. More generally, a model in which we “clone” all the deterministic variables, with 
the noise being held constant between the two branches of the graph for the same unit, is called a 
twin network [Pea09b]. We will see a more practical example in Section 29.12.6, where we discuss 
assessing the counterfactual causal impact of an intervention in a time series. (See also [RR11; RR13], 
who propose a related formalism known as single world intervention graph or SWIG.) 

We see from the above that the potential outcomes framework is mathematically equivalent to 
structural causal models, but does not use graphical model notation. This has led to heated debate 
between the founders of the two schools of thought.¹⁶ The SCM approach is more popular in
computer science (see e.g., [PJS17; Sch19; Sch+21b]), and the PO approach is more popular in 
economics (see e.g. [AP09; Imb19]). Modern textbooks on causality usually use both formalisms (see 
e.g., [HR20a; Nea20]). 


16. The potential outcomes framework is based on the work of Donald Rubin and others, and is therefore sometimes
called the Rubin Causal Model (see e.g., https://en.wikipedia.org/wiki/Rubin_causal_model). The structural
causal models framework is based on the work of Judea Pearl and others. See e.g., http://causality.cs.ucla.edu/blog/index.php/2012/12/03/judea-pearl-on-potential-outcomes/
for a discussion of the two.
5 Information theory


Machine learning is fundamentally about information processing. But what do we mean by
“information”? We discuss this in Sections 5.1 to 5.3. We then go on to briefly discuss two main
applications of information theory. The first application is data compression or source coding, 
which is the problem of removing redundancy from data so it can be represented more compactly, 
either in a lossless way (e.g., ZIP files) or a lossy way (e.g., MP3 files). See Section 5.4 for details. 
The second application is error correction or channel coding, which means encoding data in 
such a way that it is robust to errors when sent over a noisy channel, such as a telephone line or a 
satellite link. See Section 5.5 for details. 

It turns out that methods for data compression and error correction both rely on having an accurate 
probabilistic model of the data. For compression, a probabilistic model is needed so the sender can 
assign shorter codewords to data vectors which occur most often, and hence save space. For error 
correction, a probabilistic model is needed so the receiver can infer the most likely source message by 
combining the received noisy message with a prior over possible messages. 

It is clear that probabilistic machine learning is useful for information theory. However, information 
theory is also useful for machine learning. Indeed, we have seen that Bayesian machine learning is 
about representing and reducing our uncertainty, and so is fundamentally about information. In 
Section 5.6.2, we explore this direction in more detail, where we discuss the information bottleneck. 

For more information on information theory, see e.g., [Mac03; CT06]. 


5.1 KL divergence 


This section is written with Alex Alemi. 


To discuss information theory, we need some way to measure or quantify information itself. Let’s 
say we start with some distribution describing our degrees of belief about a random variable, call it 
q(x). We then want to update our degrees of belief to some new distribution p(x), perhaps because 
we've taken some new measurements or merely thought about the problem a bit longer. What we 
seek is a mathematical way to quantify the magnitude of this update, which we’ll denote $I[p \| q]$.
What sort of criteria would be reasonable for such a measure? We discuss this issue below, and then 
define a quantity that satisfies these criteria. 
5.1.1 Desiderata

For simplicity, imagine we are describing a distribution over $N$ possible events. In this case, the
probability distribution $q(x)$ consists of $N$ non-negative real numbers that add up to 1. To be even
more concrete, imagine we are describing the random variable representing the suit of the next card
we'll draw from a deck: $S \in \{\clubsuit, \spadesuit, \heartsuit, \diamondsuit\}$. Imagine we initially believe the distribution over suits to
be uniform: $q = [\frac{1}{4}, \frac{1}{4}, \frac{1}{4}, \frac{1}{4}]$. If our friend told us they removed all of the red cards, we could update
to $q' = [\frac{1}{2}, \frac{1}{2}, 0, 0]$. Alternatively, we might believe some diamonds changed into clubs and want to
update to $q'' = [\frac{3}{8}, \frac{2}{8}, \frac{2}{8}, \frac{1}{8}]$. Is there a good way to quantify how much we’ve updated our beliefs?
Which is a larger update: $q \to q'$ or $q \to q''$?
It seems desirable that any such measure should satisfy the following properties:


1. continuous in its arguments: If we slightly perturb either our starting or ending distribution,
this should have only a small effect on the magnitude of the update. For example:
$I[p \| [\frac{1}{4} + \epsilon, \frac{1}{4}, \frac{1}{4}, \frac{1}{4} - \epsilon]]$ should be close to $I[p \| q]$ for small $\epsilon$, where $q = [\frac{1}{4}, \frac{1}{4}, \frac{1}{4}, \frac{1}{4}]$.

2. non-negative: $I[p \| q] \geq 0$ for all $p(x)$ and $q(x)$. The magnitude of our updates is non-negative.


3. permutation invariant: The magnitude of the update should not depend on the order we choose for
the elements of $x$. For example, it shouldn’t matter whether I list my probabilities for the suits of cards
in the order $\clubsuit, \spadesuit, \heartsuit, \diamondsuit$ or $\clubsuit, \diamondsuit, \heartsuit, \spadesuit$; if I keep the order consistent across all of the distributions,
I should get the same answer. For example: $I[a, b, c, d \| e, f, g, h] = I[a, d, c, b \| e, h, g, f]$.


4. monotonic for uniform distributions: While it’s hard to say how large the updates in our beliefs
are in general, there are some special cases for which we have a strong intuition. If our beliefs
update from a uniform distribution on $N$ elements to one that is uniform on $N'$ elements, the
information gain should be an increasing function of $N$ and a decreasing function of $N'$. For
instance, changing from a uniform distribution on all four suits, $[\frac{1}{4}, \frac{1}{4}, \frac{1}{4}, \frac{1}{4}]$ (so $N = 4$), to only one
suit, such as all clubs, $[1, 0, 0, 0]$ where $N' = 1$, is a larger update than if I only updated to the
card being black, $[\frac{1}{2}, \frac{1}{2}, 0, 0]$ where $N' = 2$.

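5. satisfy a natural chain rule: So far we’ve been describing our beliefs in what will happen on the next
card draw as a single random variable representing the suit of the next card ($S \in \{\clubsuit, \spadesuit, \heartsuit, \diamondsuit\}$).
We could equivalently describe the same physical process in two steps. First we consider the
random variable representing the color of the card ($C \in \{\blacksquare, \square\}$), which could be either black
($\blacksquare = \{\clubsuit, \spadesuit\}$) or red ($\square = \{\heartsuit, \diamondsuit\}$). Then, if we draw a red card, we describe our belief that it is $\heartsuit$
versus $\diamondsuit$. If it was instead black, we would assign beliefs to it being $\clubsuit$ versus $\spadesuit$. We can convert
any distribution over the four suits into this conditional factorization, for example:

$$p(S) = \left[\frac{3}{8}, \frac{2}{8}, \frac{2}{8}, \frac{1}{8}\right] \qquad (5.1)$$
becomes
$$p(C) = \left[\frac{5}{8}, \frac{3}{8}\right], \quad p(\{\clubsuit, \spadesuit\} | C = \blacksquare) = \left[\frac{3}{5}, \frac{2}{5}\right], \quad p(\{\heartsuit, \diamondsuit\} | C = \square) = \left[\frac{2}{3}, \frac{1}{3}\right]. \qquad (5.2)$$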

In the same way we could decompose our uniform distribution $q$. Obviously, for our measure of
information to be of use, the magnitude of the update needs to be the same regardless of how we
5.1. KL DIVERGENCE


choose to describe what is ultimately the same physical process. What we need is some way to 
relate what would be four different invocations of our information function: 


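$$I_S = I[p(S) \| q(S)] \qquad (5.3)$$
$$I_C = I[p(C) \| q(C)] \qquad (5.4)$$
$$I_{\blacksquare} = I[p(\{\clubsuit, \spadesuit\} | C = \blacksquare) \| q(\{\clubsuit, \spadesuit\} | C = \blacksquare)] \qquad (5.5)$$
$$I_{\square} = I[p(\{\heartsuit, \diamondsuit\} | C = \square) \| q(\{\heartsuit, \diamondsuit\} | C = \square)]. \qquad (5.6)$$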


Clearly $I_S$ should be some function of $\{I_C, I_{\blacksquare}, I_{\square}\}$. Our last desideratum is that the way we measure
the magnitude of our updates will have $I_S$ be a linear combination of $I_C, I_{\blacksquare}, I_{\square}$. In particular,
we will require that they combine as a weighted linear combination, with weights set by the
probability that we would find ourselves in that branch according to the distribution $p$:

$$I_S = I_C + p(C = \blacksquare) I_{\blacksquare} + p(C = \square) I_{\square} = I_C + \frac{5}{8} I_{\blacksquare} + \frac{3}{8} I_{\square}. \qquad (5.7)$$


Stating this requirement more generally: if we partition $x$ into two pieces $[x_L, x_R]$, so that we
can write $p(x) = p(x_L) p(x_R | x_L)$ and similarly for $q$, the magnitude of the update should be

$$I[p(x) \| q(x)] = I[p(x_L) \| q(x_L)] + \mathbb{E}_{p(x_L)}\left[I[p(x_R | x_L) \| q(x_R | x_L)]\right]. \qquad (5.8)$$

Notice that this requirement breaks the symmetry between our two distributions: The right hand 
side asks us to take the expected conditional information gain with respect to the marginal, but 
we need to decide which of two marginals to take the expectation with respect to. 
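The conditional factorization of Equation (5.2) is easy to compute mechanically. The sketch below, which assumes suits and colors are encoded as plain string labels, recovers $p(C)$ and $p(S | C)$ from the $p(S)$ of Equation (5.1) using exact fractions; the entries of $p(C)$ are the branch weights $\frac{5}{8}$ and $\frac{3}{8}$ that appear in Equation (5.7).

```python
from fractions import Fraction as F

# p(S) from Equation (5.1), as exact fractions.
p_suit = {"clubs": F(3, 8), "spades": F(2, 8), "hearts": F(2, 8), "diamonds": F(1, 8)}
COLOR = {"clubs": "black", "spades": "black", "hearts": "red", "diamonds": "red"}


def factorize(p):
    """Split p(S) into the marginal p(C) and the conditionals p(S | C)."""
    p_color = {}
    for suit, prob in p.items():
        p_color[COLOR[suit]] = p_color.get(COLOR[suit], 0) + prob
    p_cond = {
        c: {s: p[s] / p_color[c] for s in p if COLOR[s] == c} for c in p_color
    }
    return p_color, p_cond


p_color, p_cond = factorize(p_suit)
print(p_color)          # {'black': Fraction(5, 8), 'red': Fraction(3, 8)}
print(p_cond["black"])  # {'clubs': Fraction(3, 5), 'spades': Fraction(2, 5)}
print(p_cond["red"])    # {'hearts': Fraction(2, 3), 'diamonds': Fraction(1, 3)}
```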


5.1.2 The KL divergence uniquely satisfies the desiderata 


We will now define a quantity that is the only measure (up to a multiplicative constant) that satisfies 
the above desiderata. The Kullback-Leibler divergence or KL divergence, also known as the 
information gain or relative entropy, is defined as follows: 


$$D_{\mathrm{KL}}(p \| q) \triangleq \sum_{k} p_k \log \frac{p_k}{q_k} \qquad (5.9)$$

This naturally extends to continuous distributions:

$$D_{\mathrm{KL}}(p \| q) \triangleq \int dx \, p(x) \log \frac{p(x)}{q(x)} \qquad (5.10)$$

Next we will verify that this definition satisfies all of our desiderata. (The proof that it is the unique
measure which captures these properties can be found in e.g., [Hob69; Rén61].)
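As a small worked example, we can answer the question posed in Section 5.1.1 by computing the divergence of each updated card distribution from the uniform one, $D_{\mathrm{KL}}(q' \| q)$ and $D_{\mathrm{KL}}(q'' \| q)$. The helper `kl` below is a minimal sketch of Equation (5.9); its handling of zero probabilities anticipates the conventions discussed in Section 5.1.2.1.

```python
import math


def kl(p, q):
    """Discrete D_KL(p || q) in nats, using the convention 0 log(0 / q) = 0."""
    total = 0.0
    for pk, qk in zip(p, q):
        if pk == 0:
            continue  # the limit of p log(p / q) as p -> 0 is 0
        if qk == 0:
            return math.inf  # p puts mass where q has none
        total += pk * math.log(pk / qk)
    return total


q = [1/4, 1/4, 1/4, 1/4]    # uniform over (clubs, spades, hearts, diamonds)
q1 = [1/2, 1/2, 0, 0]       # q': the red cards were removed
q2 = [3/8, 2/8, 2/8, 1/8]   # q'': some diamonds turned into clubs

print(kl(q1, q) / math.log(2))  # 1.0 (bits)
print(kl(q2, q) / math.log(2))  # roughly 0.094 (bits)
print(kl(q, q1))                # inf
```

Under this measure, learning that all red cards were removed ($q \to q'$, exactly 1 bit) is a much larger update than the diamonds-to-clubs change ($q \to q''$, roughly 0.094 bits). The last line shows the support problem: $q'$ assigns zero probability to suits that $q$ considers possible, so $D_{\mathrm{KL}}(q \| q')$ is infinite.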


5.1.2.1 Continuity of KL 


One of our desiderata was that our measure of information gain should be continuous. The KL
divergence is manifestly continuous in its arguments except potentially when $p_k$ or $q_k$ is zero. In the
first case, notice that the limit as $p \to 0$ is well behaved:

$$\lim_{p \to 0} p \log \frac{p}{q} = 0. \qquad (5.11)$$

Taking this as the definition of the value of the integrand when $p = 0$ will make it continuous there.
Notice that we do have a problem, however, if $q = 0$ in some place where $p \neq 0$. Our information
gain requires that our original distribution of beliefs $q$ has some support everywhere the updated
distribution does. Intuitively, it would require an infinite amount of information for us to update our
beliefs in some outcome to change from being exactly 0 to some positive value.


5.1.2.2 Non-negativity of KL divergence 


In this section, we prove that the KL divergence as defined is always non-negative. We will make use 
of Jensen’s inequality, which states that for any convex function f, we have that 


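$$f\left(\sum_{i=1}^{n} \lambda_i x_i\right) \leq \sum_{i=1}^{n} \lambda_i f(x_i) \qquad (5.12)$$
where $\lambda_i \geq 0$ and $\sum_{i=1}^{n} \lambda_i = 1$. This can be proved by induction, where the base case with $n = 2$
follows by definition of convexity.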

Theorem 5.1.1. (Information inequality) $D_{\mathrm{KL}}(p \| q) \geq 0$ with equality iff $p = q$.


Proof. We now prove the theorem, following [CT06, p28]. As we noted in the previous section, the
KL divergence requires special consideration when $p(x)$ or $q(x) = 0$, and the same is true here. Let
$A = \{x : p(x) > 0\}$ be the support of $p(x)$. Using the concavity of the log function and Jensen’s
inequality, we have that

$$-D_{\mathrm{KL}}(p \| q) = -\sum_{x \in A} p(x) \log \frac{p(x)}{q(x)} = \sum_{x \in A} p(x) \log \frac{q(x)}{p(x)} \qquad (5.13)$$
$$\leq \log \sum_{x \in A} p(x) \frac{q(x)}{p(x)} = \log \sum_{x \in A} q(x) \qquad (5.14)$$
$$\leq \log \sum_{x \in \mathcal{X}} q(x) = \log 1 = 0 \qquad (5.15)$$

Since $\log(x)$ is a strictly concave function ($-\log(x)$ is convex), we have equality in Equation (5.14) iff
$p(x) = c\,q(x)$ for some $c$ that tracks the fraction of the whole space $\mathcal{X}$ contained in $A$. We have equality
in Equation (5.15) iff $\sum_{x \in A} q(x) = \sum_{x \in \mathcal{X}} q(x) = 1$, which implies $c = 1$. Hence $D_{\mathrm{KL}}(p \| q) = 0$ iff
$p(x) = q(x)$ for all $x$.
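A quick empirical sanity check of Theorem 5.1.1, assuming nothing beyond Equation (5.9): sample many random pairs of distributions on five outcomes and confirm that none has a negative KL divergence, and that $D_{\mathrm{KL}}(p \| p) = 0$. This illustrates the theorem; it does not prove it, and the sampling scheme (normalized uniform draws) is an arbitrary choice.

```python
import math
import random


def kl(p, q):
    """Discrete D_KL(p || q) in nats, assuming strictly positive p and q."""
    return sum(pk * math.log(pk / qk) for pk, qk in zip(p, q))


def random_dist(n, rng):
    """A random strictly positive probability vector (normalized uniform draws)."""
    w = [rng.random() + 1e-12 for _ in range(n)]
    s = sum(w)
    return [x / s for x in w]


rng = random.Random(0)
values = [kl(random_dist(5, rng), random_dist(5, rng)) for _ in range(10_000)]
print(min(values) >= 0.0)  # True
p = random_dist(5, rng)
print(kl(p, p))  # 0.0
```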


The non-negativity of KL divergence often feels like one of the most useful results in
information theory. It is a good result to keep in your back pocket. Any time you can rearrange an
expression into a sum that includes KL divergence terms, dropping those terms immediately yields a bound,
since they are guaranteed to be non-negative.


5.1.2.3 KL divergence is invariant to reparameterizations


We wanted our measure of information to be invariant to permutations of the labels. The discrete
form is manifestly permutation invariant, as summation is. The KL divergence actually satisfies a
much stronger property of reparameterization invariance. Namely, we can transform our random
variable through an arbitrary invertible map and it won’t change the value of the KL divergence.

If we transform our random variable from $x$ to some $y = f(x)$, we know that $p(x)\,dx = p(y)\,dy$ and
$q(x)\,dx = q(y)\,dy$. Hence the KL divergence remains the same for both random variables:

$$D_{\mathrm{KL}}(p(x) \| q(x)) = \int dx \, p(x) \log \frac{p(x)}{q(x)} = \int dy \, p(y) \log \frac{p(y) \left|\frac{dy}{dx}\right|}{q(y) \left|\frac{dy}{dx}\right|} = D_{\mathrm{KL}}(p(y) \| q(y)). \qquad (5.16)$$


Because of this reparameterization invariance we can rest assured that when we measure the KL 
divergence between two distributions we are measuring something about the distributions and not the 
way we choose to represent the space in which they are defined. We are therefore free to transform 
our data into a convenient basis of our choosing, such as a Fourier basis for images, without affecting
the result. 
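Equation (5.16) can be checked numerically. In the sketch below, the choices $p(x) = \mathcal{N}(0, 1)$, $q(x) = \mathcal{N}(1, 2^2)$ and the invertible map $y = \exp(x)$ are assumptions made for illustration. The KL divergence is computed in closed form in $x$-space (using the standard formula for two univariate Gaussians) and estimated by Monte Carlo in $y$-space from the log-normal densities, which include the Jacobian factor $1/y$.

```python
import math
import random


def normal_logpdf(x, mu, sigma):
    """log N(x | mu, sigma^2)."""
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma * math.sqrt(2 * math.pi))


def lognormal_logpdf(y, mu, sigma):
    """log density of y = exp(x) with x ~ N(mu, sigma^2); the -log(y) is the Jacobian."""
    return normal_logpdf(math.log(y), mu, sigma) - math.log(y)


MU_P, SIG_P = 0.0, 1.0  # assumed p(x) = N(0, 1)
MU_Q, SIG_Q = 1.0, 2.0  # assumed q(x) = N(1, 2^2)

# Closed-form KL between two univariate Gaussians, in x-space.
kl_x = math.log(SIG_Q / SIG_P) + (SIG_P**2 + (MU_P - MU_Q) ** 2) / (2 * SIG_Q**2) - 0.5

# Monte Carlo estimate in y-space: sample y ~ p(y) and average log p(y) - log q(y).
rng = random.Random(0)
ys = [math.exp(rng.gauss(MU_P, SIG_P)) for _ in range(200_000)]
kl_y = sum(lognormal_logpdf(y, MU_P, SIG_P) - lognormal_logpdf(y, MU_Q, SIG_Q)
           for y in ys) / len(ys)
print(kl_x, kl_y)
```

The Jacobian terms cancel inside the log ratio, so the two numbers agree up to Monte Carlo error (both are about 0.443 nats).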


5.1.2.4 Monotonicity for uniform distributions


Consider updating a probability distribution from a uniform distribution on N elements to a uniform 
distribution on N’ elements. The KL divergence is: 


$$D_{\mathrm{KL}}(p \| q) = \sum_{k=1}^{N'} \frac{1}{N'} \log \frac{1/N'}{1/N} = \log \frac{N}{N'}, \qquad (5.17)$$

that is, the log of the ratio of the number of elements before and after the update. This satisfies our monotonicity
requirement.

We can interpret this result as follows: Consider finding an element of a sorted array by means of 
bisection. A well designed yes/no question can cut the search space in half. Measured in bits, the 
KL divergence tells us how many well designed yes/no questions are required on average to move 
from q to p. 
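Equation (5.17) measured in bits can be evaluated for the cases from desideratum 4 and for the bisection interpretation. The function name `kl_uniform_bits` is our own, and we assume the $N'$ retained elements are a subset of the original $N$.

```python
import math


def kl_uniform_bits(n_before, n_after):
    """D_KL(p || q) in bits, q uniform on n_before items, p uniform on n_after of them."""
    p = [1 / n_after] * n_after
    q = [1 / n_before] * n_after  # q restricted to the support of p
    return sum(pk * math.log2(pk / qk) for pk, qk in zip(p, q))


print(kl_uniform_bits(4, 1))     # 2.0: all four suits -> clubs only
print(kl_uniform_bits(4, 2))     # 1.0: all four suits -> black cards only
print(kl_uniform_bits(1024, 1))  # 10.0: bisection finds 1 of 1024 items in 10 questions
```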


5.1.2.5 Chain rule for KL divergence 


Here we show that the KL divergence satisfies a natural chain rule: 


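$$D_{\mathrm{KL}}(p(x, y) \| q(x, y)) = \int dx \, dy \, p(x, y) \log \frac{p(x, y)}{q(x, y)} \qquad (5.18)$$
$$= \int dx \, dy \, p(x, y) \left[\log \frac{p(x)}{q(x)} + \log \frac{p(y | x)}{q(y | x)}\right] \qquad (5.19)$$
$$= D_{\mathrm{KL}}(p(x) \| q(x)) + \mathbb{E}_{p(x)}\left[D_{\mathrm{KL}}(p(y | x) \| q(y | x))\right]. \qquad (5.20)$$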


We can rest assured that we can decompose our distributions into their conditionals and the KL 
divergences will just add. 

As a notational convenience, the conditional KL divergence is defined to be the expected value 
of the KL divergence between two conditional distributions: 


$$D_{\mathrm{KL}}(p(y | x) \| q(y | x)) \triangleq \int dx \, p(x) \int dy \, p(y | x) \log \frac{p(y | x)}{q(y | x)}. \qquad (5.21)$$


This allows us to drop many expectation symbols. 
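The chain rule can be verified numerically on the card example of Section 5.1.1. The sketch below assumes suits are ordered $\clubsuit, \spadesuit, \heartsuit, \diamondsuit$, takes $p$ to be the distribution of Equation (5.1) and $q$ uniform, and checks that $I_S$ equals $I_C + \frac{5}{8} I_{\blacksquare} + \frac{3}{8} I_{\square}$, as required by Equation (5.7).

```python
import math


def kl(p, q):
    """Discrete D_KL(p || q) in nats, assuming strictly positive p and q."""
    return sum(pk * math.log(pk / qk) for pk, qk in zip(p, q))


def split(dist):
    """Factor a distribution over (clubs, spades, hearts, diamonds) by color."""
    black, red = dist[0] + dist[1], dist[2] + dist[3]
    return [black, red], [dist[0] / black, dist[1] / black], [dist[2] / red, dist[3] / red]


p = [3/8, 2/8, 2/8, 1/8]  # Equation (5.1)
q = [1/4, 1/4, 1/4, 1/4]  # uniform prior beliefs
p_c, p_black, p_red = split(p)
q_c, q_black, q_red = split(q)

i_s = kl(p, q)
i_c, i_black, i_red = kl(p_c, q_c), kl(p_black, q_black), kl(p_red, q_red)
rhs = i_c + p_c[0] * i_black + p_c[1] * i_red  # weights p(C) = [5/8, 3/8]
print(i_s, rhs)  # agree up to floating-point rounding
```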
5.1.3 Thinking about KL


In this section, we discuss some qualitative properties of the KL divergence. 


5.1.3.1 Units of KL 


Above we said that the desiderata we listed determined the KL divergence up to a multiplicative 
constant. Because the KL divergence is logarithmic, and logarithms in different bases are the same 
up to a multiplicative constant, our choice of the base of the logarithm when we compute the KL 
divergence is a choice akin to choosing which units to measure the information in. 

If the KL divergence is measured with the base-2 logarithm, it is said to have units of bits, short
for “binary digits”. If measured using the natural logarithm, as we normally do for mathematical
convenience, it is said to be measured in nats, for “natural units”.
To convert between the systems, we use $\log_2 y = \frac{\log y}{\log 2}$. Hence

$$1 \text{ bit} = \log 2 \text{ nats} \approx 0.693 \text{ nats} \qquad (5.22)$$
$$1 \text{ nat} = \frac{1}{\log 2} \text{ bits} \approx 1.44 \text{ bits}. \qquad (5.23)$$


5.1.3.2 Asymmetry of the KL divergence 


The KL divergence is not symmetric in its two arguments. While many find this asymmetry confusing 
at first, we can see that the asymmetry stems from our requirement that we have a natural chain 
rule. When we decompose the distribution into its conditional, we need to take an expectation with 
respect to the variables being conditioned on. In the KL divergence we take this expectation with 
respect to the first argument p(x). This breaks the symmetry between the two distributions. 

At a more intuitive level, we can see that the information required to move from $q$ to $p$ is in general
different from the information required to move from $p$ to $q$. For example, consider the KL divergence
between two Bernoulli distributions, the first with the probability of success given by 0.443 and the
second with 0.975:

$$D_{\mathrm{KL}} = 0.975 \log \frac{0.975}{0.443} + 0.025 \log \frac{0.025}{0.557} = 0.692 \text{ nats} \approx 1.0 \text{ bits}. \qquad (5.24)$$

So it takes 1 bit of information to update from a $[0.443, 0.557]$ distribution to a $[0.975, 0.025]$ Bernoulli
distribution. What about the reverse?

$$D_{\mathrm{KL}} = 0.443 \log \frac{0.443}{0.975} + 0.557 \log \frac{0.557}{0.025} = 1.38 \text{ nats} \approx 2.0 \text{ bits}, \qquad (5.25)$$

so it takes two bits, or twice as much information, to move the other way. Thus we see that starting
with a distribution that is nearly even and moving to one that is nearly certain takes about 1 bit of
information, or one well designed yes/no question. To instead move us from near certainty in an
outcome to something that is akin to the flip of a coin requires more persuasion.
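The two numbers in Equations (5.24) and (5.25) can be reproduced directly. Here `kl_bernoulli` is a helper name introduced for illustration, and the conversion to bits divides by $\log 2$, as in Equation (5.23).

```python
import math


def kl_bernoulli(a, b):
    """D_KL(Ber(a) || Ber(b)) in nats."""
    return a * math.log(a / b) + (1 - a) * math.log((1 - a) / (1 - b))


forward = kl_bernoulli(0.975, 0.443)  # Equation (5.24): nearly even -> nearly certain
reverse = kl_bernoulli(0.443, 0.975)  # Equation (5.25): nearly certain -> nearly even
print(f"{forward:.3f} nats = {forward / math.log(2):.2f} bits")
print(f"{reverse:.3f} nats = {reverse / math.log(2):.2f} bits")
```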


5.1.3.3 Minimizing forwards vs reverse KL


The asymmetry of KL means that finding a $p$ that is close to $q$ by minimizing $D_{\mathrm{KL}}(p \| q)$ (also called
the exclusive KL) gives different behavior than minimizing $D_{\mathrm{KL}}(q \| p)$ (also called the inclusive
KL). For example, consider the bimodal distribution $q$ shown in blue in Figure 5.1, which we
approximate with a unimodal Gaussian. To prevent $D_{\mathrm{KL}}(q \| p)$ from becoming infinite, we must
have $p > 0$ whenever $q > 0$ (i.e., $p$ must have support everywhere $q$ does), so $p$ tends to cover both
modes; this is called mode-covering or zero-avoiding behavior (orange curve).

By contrast, to prevent $D_{\mathrm{KL}}(p \| q)$ from becoming infinite, we must have $p = 0$ whenever $q = 0$,
which creates mode-seeking or zero-forcing behavior (green curve).

[Figure 5.1 plot: curves labeled min KL[q; p] and min KL[p; q], plotted against x.]

Figure 5.1: Demonstration of the mode-covering or mode-seeking behavior of KL divergence. The original
distribution $q$ is bimodal. When we minimize $D_{\mathrm{KL}}(q \| p)$, then $p$ covers the modes of $q$ (orange). When we
minimize $D_{\mathrm{KL}}(p \| q)$, then $p$ ignores some of the modes of $q$ (green). Generated by minimize_kl_divergence.ipynb.

Figure 5.2: Illustrating forwards vs reverse KL on a symmetric Gaussian. The blue curves are the
contours of the true distribution $p$. The red curves are the contours of a factorized approximation $q$. (a)
Minimizing $D_{\mathrm{KL}}(p \| q)$. (b) Minimizing $D_{\mathrm{KL}}(q \| p)$. Adapted from Figure 10.2 of [Bis06]. Generated by
kl_pq_gauss.ipynb.
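To see the two behaviors numerically, the following sketch fits a single Gaussian $p = \mathcal{N}(\mu, \sigma^2)$ to a bimodal target $q$ by brute-force grid search, once per direction. It is not the code behind Figure 5.1: the target mixture $0.5\,\mathcal{N}(-3, 1) + 0.5\,\mathcal{N}(3, 1)$, the integration grid and the parameter grid are all assumptions chosen for illustration, and the integrals are approximated by Riemann sums.

```python
import math


def normal_pdf(x, mu, sigma):
    """Density of N(mu, sigma^2) at x."""
    return math.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi))


DX = 0.05
XS = [-12 + DX * i for i in range(481)]  # integration grid on [-12, 12]
# Assumed bimodal target q: an equal mixture of N(-3, 1) and N(3, 1).
Q = [0.5 * normal_pdf(x, -3, 1) + 0.5 * normal_pdf(x, 3, 1) for x in XS]


def kl_grid(a, b):
    """Riemann-sum approximation of the integral of a(x) log(a(x) / b(x))."""
    return DX * sum(ai * math.log(ai / bi) for ai, bi in zip(a, b))


def fit(direction):
    """Grid-search N(mu, sigma^2) minimizing D_KL(q || p) ("inclusive") or D_KL(p || q)."""
    best = None
    for mu in [-5 + 0.25 * i for i in range(41)]:
        for sigma in [0.5 + 0.25 * j for j in range(17)]:
            p = [normal_pdf(x, mu, sigma) for x in XS]
            loss = kl_grid(Q, p) if direction == "inclusive" else kl_grid(p, Q)
            if best is None or loss < best[0]:
                best = (loss, mu, sigma)
    return best[1], best[2]


mu_in, sigma_in = fit("inclusive")  # expected: mean near 0, sigma near sqrt(10)
mu_ex, sigma_ex = fit("exclusive")  # expected: sits on one mode, sigma near 1
print(mu_in, sigma_in, mu_ex, sigma_ex)
```

The inclusive fit spreads out to cover both modes (its mean and variance match those of the mixture, anticipating the moment-matching argument below), while the exclusive fit locks onto one mode and ignores the other.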


5.1.3.4 Moment projection 


Suppose we compute q by minimizing the forwards KL: 


$$q = \arg\min_{q} D_{\mathrm{KL}}(p \| q) \qquad (5.26)$$

This is called M-projection, or moment projection, since the optimal $q$ matches the moments of
$p$; this is called moment matching.
To see why, let us assume that $q$ is an exponential family distribution of the form

$$q(x) = h(x) \exp[\eta^\mathsf{T} \mathcal{T}(x) - \log Z(\eta)] \qquad (5.27)$$
where $\mathcal{T}(x)$ is the vector of sufficient statistics, and $\eta$ are the natural parameters. The first order
optimality conditions are as follows:

$$\partial_{\eta_i} D_{\mathrm{KL}}(p \| q) = -\partial_{\eta_i} \int p(x) \log q(x) \qquad (5.28)$$
$$= -\partial_{\eta_i} \int p(x) \log\left(h(x) \exp[\eta^\mathsf{T} \mathcal{T}(x) - \log Z(\eta)]\right) \qquad (5.29)$$
$$= -\partial_{\eta_i} \int p(x) \left(\eta^\mathsf{T} \mathcal{T}(x) - \log Z(\eta)\right) \qquad (5.30)$$
$$= -\int p(x) \mathcal{T}_i(x) + \mathbb{E}_{q(x)}[\mathcal{T}_i(x)] \qquad (5.31)$$
$$= -\mathbb{E}_{p(x)}[\mathcal{T}_i(x)] + \mathbb{E}_{q(x)}[\mathcal{T}_i(x)] = 0 \qquad (5.32)$$

where in the penultimate line we used the fact that the derivative of the log partition function
yields the expected sufficient statistics, as shown in Equation (2.199). Hence the expected sufficient
statistics (moments of the distribution) must match.

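As an example, suppose the true target distribution $p$ is a correlated 2d Gaussian, $p(x) = \mathcal{N}(x | \mu, \Sigma) = \mathcal{N}(x | \mu, \Lambda^{-1})$, where

$$\mu = \begin{pmatrix} \mu_1 \\ \mu_2 \end{pmatrix}, \quad \Sigma = \begin{pmatrix} \Sigma_{11} & \Sigma_{12} \\ \Sigma_{21} & \Sigma_{22} \end{pmatrix}, \quad \Lambda = \begin{pmatrix} \Lambda_{11} & \Lambda_{12} \\ \Lambda_{21} & \Lambda_{22} \end{pmatrix} \qquad (5.33)$$

We will approximate this with a distribution $q$ which is a product of two 1d Gaussians, i.e., a Gaussian
with a diagonal covariance matrix:

$$q(x | m, V) = \mathcal{N}(x_1 | m_1, v_1) \mathcal{N}(x_2 | m_2, v_2) \qquad (5.34)$$

If we perform moment matching, the optimal $q$ must therefore have the following form:

$$q(x) = \mathcal{N}(x_1 | \mu_1, \Sigma_{11}) \mathcal{N}(x_2 | \mu_2, \Sigma_{22}) \qquad (5.35)$$

In Figure 5.2(a), we show the resulting distribution. We see that $q$ covers (includes) $p$, but its support
is too broad (under-confidence).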


5.1.3.5 Information projection 


Now suppose we compute $q$ by minimizing the reverse KL:

$$q = \arg\min_{q} D_{\mathrm{KL}}(q \| p)$$

This is called I-projection, or information projection.
In the case of a 2d Gaussian, one can show [Mac03, p435] that the optimal solution has the form

$$q(x) = \mathcal{N}(x_1 | m_1, \Lambda_{11}^{-1}) \mathcal{N}(x_2 | m_2, \Lambda_{22}^{-1})$$
$$m_1 = \mu_1 - \Lambda_{11}^{-1} \Lambda_{12} (m_2 - \mu_2)$$
$$m_2 = \mu_2 - \Lambda_{22}^{-1} \Lambda_{21} (m_1 - \mu_1)$$


To solve this fixed point equation, we can set $m_i = \mu_i$, so we have again captured the mean exactly.
However, the posterior variance is too narrow, i.e., the approximate posterior is overconfident. This
is shown in Figure 5.2(b). (Note, however, that minimizing the reverse KL does not always result in
an overly compact approximation, as explained in [Tur+08].)
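The two projections can be compared on a concrete target. In this sketch the values $\Sigma_{11} = \Sigma_{22} = 1$, $\Sigma_{12} = 0.9$ and zero means are made up for illustration. Both KLs between a diagonal $q$ and $p$ then have closed forms (the standard Gaussian KL formula), and by symmetry a one-dimensional grid search over a shared variance $v_1 = v_2 = v$ is enough.

```python
import math

# Assumed zero-mean target p = N(0, Sigma) with strong correlation.
S11, S12, S22 = 1.0, 0.9, 1.0
DET = S11 * S22 - S12**2
L11, L22 = S22 / DET, S11 / DET  # diagonal entries of Lambda = Sigma^{-1}


def forward_kl(v1, v2):
    """D_KL(p || q) for q = N(0, diag(v1, v2)), via the Gaussian KL formula."""
    return 0.5 * (S11 / v1 + S22 / v2 - 2 + math.log(v1 * v2 / DET))


def reverse_kl(v1, v2):
    """D_KL(q || p) for q = N(0, diag(v1, v2)), via the Gaussian KL formula."""
    return 0.5 * (L11 * v1 + L22 * v2 - 2 + math.log(DET / (v1 * v2)))


variances = [0.01 * k for k in range(1, 301)]  # candidate variances 0.01, ..., 3.00
v_m = min(variances, key=lambda v: forward_kl(v, v))  # M-projection
v_i = min(variances, key=lambda v: reverse_kl(v, v))  # I-projection
print(v_m, S11)      # M-projection: variance matches Sigma_11
print(v_i, 1 / L11)  # I-projection: variance 1 / Lambda_11, much smaller
```

For this strongly correlated target, the I-projection variance $1/\Lambda_{11} = \Sigma_{11} - \Sigma_{12}^2/\Sigma_{22} = 0.19$ is far smaller than the marginal variance $\Sigma_{11} = 1$ chosen by the M-projection, matching the under-confident versus overconfident picture in Figure 5.2.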


5.1.3.6 KL as expected weight of evidence 


Imagine you have two different hypotheses you wish to select between, which we’ll label P and Q. 
You collect some data D. Bayes’ rule tells us how to update our beliefs in the hypotheses being 
correct: 

$$\Pr(P | D) = \frac{\Pr(D | P)}{\Pr(D)} \Pr(P). \qquad (5.40)$$

Normally this requires being able to evaluate the marginal likelihood $\Pr(D)$, which is difficult. If we
instead consider the ratio of the probabilities for the two hypotheses:

$$\frac{\Pr(P | D)}{\Pr(Q | D)} = \frac{\Pr(D | P)}{\Pr(D | Q)} \frac{\Pr(P)}{\Pr(Q)}, \qquad (5.41)$$

the marginal likelihood drops out. Taking the logarithm of both sides, and identifying the probability
of the data under the model as the likelihood, we find:

$$\log \frac{\Pr(P | D)}{\Pr(Q | D)} = \log \frac{p(D)}{q(D)} + \log \frac{\Pr(P)}{\Pr(Q)}. \qquad (5.42)$$

The posterior log probability ratio for one hypothesis over the other is just our prior log probability
ratio plus a term that I. J. Good called the weight of evidence [Goo85] of $D$ for hypothesis $P$ over
$Q$:

$$w[P/Q; D] \triangleq \log \frac{p(D)}{q(D)}. \qquad (5.43)$$


With this interpretation, the KL divergence is the expected weight of evidence for $P$ over $Q$ given
by each observation, provided $P$ were correct. Thus we see that data will (on average) add rather
than subtract evidence towards the correct hypothesis, since this expected weight of evidence is a KL
divergence and hence non-negative (see Section 5.1.2.2).
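A short simulation of this interpretation, under assumed hypotheses about a coin: $P$ says the probability of heads is 0.7, and $Q$ says it is 0.5. Drawing flips from $P$ and averaging the per-flip weight of evidence (5.43) should approach $D_{\mathrm{KL}}(P \| Q)$, which is about 0.082 nats for these values.

```python
import math
import random

P_HEADS, Q_HEADS = 0.7, 0.5  # assumed hypotheses P and Q about a coin


def weight_of_evidence(x):
    """log p(x) / q(x) for a single flip x in {0, 1}, as in Equation (5.43)."""
    p = P_HEADS if x == 1 else 1 - P_HEADS
    q = Q_HEADS if x == 1 else 1 - Q_HEADS
    return math.log(p / q)


kl_pq = (P_HEADS * math.log(P_HEADS / Q_HEADS)
         + (1 - P_HEADS) * math.log((1 - P_HEADS) / (1 - Q_HEADS)))

# Draw data from P (the correct hypothesis) and average the per-flip evidence.
rng = random.Random(0)
flips = [1 if rng.random() < P_HEADS else 0 for _ in range(100_000)]
avg_evidence = sum(weight_of_evidence(x) for x in flips) / len(flips)
print(kl_pq, avg_evidence)
```

Individual flips can point the wrong way (a tail has negative weight of evidence for $P$), but the average per flip is positive and converges to the KL divergence.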


5.1.4 Properties of KL 


Below are some other useful properties of the KL divergence. 


5.1.4.1 Compression Lemma 


An important general purpose result for the KL divergence is the Compression Lemma: 


Theorem 5.1.2. For any distributions $P$ and $Q$ with a well-defined KL divergence, and for any
scalar function $\phi$ defined on the domain of the distributions, we have that:

$$\mathbb{E}_P[\phi] \leq \log \mathbb{E}_Q[e^{\phi}] + D_{\mathrm{KL}}(P \| Q). \qquad (5.44)$$
Proof. We know that the KL divergence between any two distributions is non-negative. Consider a
distribution of the form:

$$g(x) = \frac{q(x)}{Z} e^{\phi(x)}, \qquad (5.45)$$
where the partition function is given by:
$$Z = \int dx \, q(x) e^{\phi(x)}. \qquad (5.46)$$

Taking the KL divergence between $p(x)$ and $g(x)$ and rearranging gives the bound:

$$D_{\mathrm{KL}}(P \| G) = D_{\mathrm{KL}}(P \| Q) - \mathbb{E}_P[\phi(x)] + \log(Z) \geq 0. \qquad (5.47)$$


One way to view the compression lemma is that it provides what is termed the Donsker-Varadhan 
variational representation of the KL divergence: 


$$D_{\mathrm{KL}}(P \| Q) = \sup_{\phi} \left\{ \mathbb{E}_P[\phi(x)] - \log \mathbb{E}_Q[e^{\phi(x)}] \right\}. \qquad (5.48)$$


In the space of all possible functions $\phi$ defined on the same domain as the distributions, assuming all
of the values above are finite, the KL divergence is the supremum achieved. For any fixed function
$\phi(x)$, the right hand side provides a lower bound on the true KL divergence.
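A numerical illustration of the Donsker-Varadhan representation on a three-outcome example (the distributions `p` and `q` are made up): random choices of $\phi$ never exceed the KL divergence, and $\phi(x) = \log \frac{p(x)}{q(x)}$, the choice that makes $g = p$ in the proof above, attains it.

```python
import math
import random

p = [0.5, 0.3, 0.2]  # assumed P
q = [0.2, 0.3, 0.5]  # assumed Q
kl = sum(pk * math.log(pk / qk) for pk, qk in zip(p, q))


def dv_objective(phi):
    """E_P[phi] - log E_Q[exp(phi)], a lower bound on D_KL(P || Q) for any phi."""
    e_p = sum(pk * f for pk, f in zip(p, phi))
    log_e_q = math.log(sum(qk * math.exp(f) for qk, f in zip(q, phi)))
    return e_p - log_e_q


rng = random.Random(0)
random_values = [dv_objective([rng.gauss(0, 2) for _ in p]) for _ in range(1000)]
phi_star = [math.log(pk / qk) for pk, qk in zip(p, q)]  # makes g = p in the proof
print(max(random_values) <= kl)    # True: every phi gives a lower bound
print(dv_objective(phi_star), kl)  # the bound is tight at phi_star
```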

Another use of the compression lemma is that it provides a way to bound the expectation of
some function with respect to an unknown distribution $P$. In this spirit, the Compression Lemma
can be used to power a set of what are known as PAC-Bayes bounds of losses with respect to the
true distribution in terms of measured losses with respect to a finite training set. See for example
Section 17.4.5 or Banerjee [Ban06].


5.1.4.2 Data processing inequality for KL


We now show that any processing we do on samples from two different distributions makes their
samples approach one another. This is called the data processing inequality, since it shows that
we cannot increase the information gain from $q$ to $p$ by processing our data and then measuring it.

Theorem 5.1.3. Consider two different distributions $p(x)$ and $q(x)$ combined with a probabilistic
channel $t(y | x)$. If $p(y)$ is the distribution that results from sending samples from $p(x)$ through the
channel $t(y | x)$, and similarly for $q(y)$, we have that:

$$D_{\mathrm{KL}}(p(x) \| q(x)) \geq D_{\mathrm{KL}}(p(y) \| q(y)) \qquad (5.49)$$
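Proof. The proof uses Jensen’s inequality from Section 5.1.2.2 again. Call $p(x, y) = p(x) t(y | x)$ and
$q(x, y) = q(x) t(y | x)$.

$$D_{\mathrm{KL}}(p(x) \| q(x)) = \int dx \, p(x) \log \frac{p(x)}{q(x)} \qquad (5.50)$$
$$= \int dx \int dy \, p(x) t(y | x) \log \frac{p(x) t(y | x)}{q(x) t(y | x)} \qquad (5.51)$$
$$= \int dx \int dy \, p(x, y) \log \frac{p(x, y)}{q(x, y)} \qquad (5.52)$$
$$= -\int dy \, p(y) \int dx \, p(x | y) \log \frac{q(x, y)}{p(x, y)} \qquad (5.53)$$
$$\geq -\int dy \, p(y) \log\left(\int dx \, p(x | y) \frac{q(x, y)}{p(x, y)}\right) \qquad (5.54)$$
$$= -\int dy \, p(y) \log\left(\frac{q(y)}{p(y)} \int dx \, q(x | y)\right) \qquad (5.55)$$
$$= \int dy \, p(y) \log \frac{p(y)}{q(y)} = D_{\mathrm{KL}}(p(y) \| q(y)) \qquad (5.56)$$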


One way to interpret this result is that any processing done to random samples makes it harder to 
tell two distributions apart. 
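An empirical check of Theorem 5.1.3, assuming a discrete $x$ with four values, a discrete $y$ with three values, and randomly generated distributions and channels. A channel is stored as a row-stochastic matrix with `channel[x][y]` $= t(y | x)$, and we count how often processing increases the divergence.

```python
import math
import random


def kl(p, q):
    """Discrete D_KL(p || q) in nats."""
    return sum(pk * math.log(pk / qk) for pk, qk in zip(p, q) if pk > 0)


def push_through(p_x, channel):
    """p(y) = sum_x p(x) t(y | x), where channel[x][y] holds t(y | x)."""
    n_y = len(channel[0])
    return [sum(p_x[x] * channel[x][y] for x in range(len(p_x))) for y in range(n_y)]


def random_dist(n, rng):
    """A random strictly positive probability vector."""
    w = [rng.random() + 1e-9 for _ in range(n)]
    s = sum(w)
    return [v / s for v in w]


rng = random.Random(0)
violations = 0
for _ in range(1000):
    p_x, q_x = random_dist(4, rng), random_dist(4, rng)
    channel = [random_dist(3, rng) for _ in range(4)]  # a random 4 -> 3 channel
    if kl(push_through(p_x, channel), push_through(q_x, channel)) > kl(p_x, q_x) + 1e-12:
        violations += 1
print(violations)  # 0
```

The small tolerance only guards against floating-point rounding; the inequality itself holds exactly.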
As a special form of processing, we can simply marginalize out a subset of random variables. 


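Corollary 5.1.1. (Monotonicity of KL divergence)

$$D_{\mathrm{KL}}(p(x, y) \| q(x, y)) \geq D_{\mathrm{KL}}(p(y) \| q(y)) \qquad (5.57)$$
Proof. The proof is essentially the same as the one above.
$$D_{\mathrm{KL}}(p(x, y) \| q(x, y)) = \int dx \int dy \, p(x, y) \log \frac{p(x, y)}{q(x, y)} \qquad (5.58)$$
$$= -\int dy \, p(y) \int dx \, p(x | y) \log \frac{q(y) q(x | y)}{p(y) p(x | y)} \qquad (5.59)$$
$$\geq -\int dy \, p(y) \log\left(\frac{q(y)}{p(y)} \int dx \, q(x | y)\right) \qquad (5.60)$$
$$= \int dy \, p(y) \log \frac{p(y)}{q(y)} \qquad (5.61)$$
$$= D_{\mathrm{KL}}(p(y) \| q(y)) \qquad (5.62)$$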


One intuitive interpretation of this result is that if you only partially observe random variables, it 
is harder to distinguish between two candidate distributions than if you observed all of them. 
5.1.5 KL divergence and MLE


Suppose we want to find the distribution $q$ that is as close as possible to $p$, as measured by KL
divergence:

$$q = \arg\min_{q} D_{\mathrm{KL}}(p \| q) = \arg\min_{q} \int p(x) \log p(x) \, dx - \int p(x) \log q(x) \, dx \qquad (5.63)$$

Now suppose $p$ is the empirical distribution, which puts a probability atom on the observed training
data and zero mass everywhere else:

$$p_{\mathcal{D}}(x) = \frac{1}{N} \sum_{n=1}^{N} \delta(x - x_n) \qquad (5.64)$$

Using the sifting property of delta functions we get

$$D_{\mathrm{KL}}(p_{\mathcal{D}} \| q) = -\int p_{\mathcal{D}}(x) \log q(x) \, dx + C \qquad (5.65)$$
$$= -\int \left[\frac{1}{N} \sum_{n} \delta(x - x_n)\right] \log q(x) \, dx + C \qquad (5.66)$$
$$= -\frac{1}{N} \sum_{n} \log q(x_n) + C \qquad (5.67)$$

where $C = \int p_{\mathcal{D}}(x) \log p_{\mathcal{D}}(x)$ is a constant independent of $q$.
We can rewrite the above as follows:

$$D_{\mathrm{KL}}(p_{\mathcal{D}} \| q) = H_{\mathrm{ce}}(p_{\mathcal{D}}, q) - H(p_{\mathcal{D}}) \qquad (5.68)$$

where

$$H_{\mathrm{ce}}(p, q) \triangleq -\sum_k p_k \log q_k \qquad (5.69)$$

is known as the cross entropy. The quantity $H_{\mathrm{ce}}(p_{\mathcal{D}}, q)$ is the average negative log likelihood
of $q$ evaluated on the training set. Thus we see that minimizing KL divergence to the empirical
distribution is equivalent to maximizing likelihood.
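To see the equivalence concretely, the sketch below fits a Bernoulli model $q_\theta$ to a small hypothetical dataset of coin flips by minimizing the cross entropy $H_{\mathrm{ce}}(p_{\mathcal{D}}, q_\theta)$ over a grid of $\theta$ values. The minimizer coincides with the maximum likelihood estimate, which for a Bernoulli model is the empirical frequency of heads.

```python
import math

data = [1, 0, 1, 1, 0, 1, 1, 1, 0, 1]  # hypothetical training set of coin flips


def cross_entropy(theta):
    """H_ce(p_D, q_theta): average negative log likelihood of Ber(theta) on the data."""
    return -sum(math.log(theta if x == 1 else 1 - theta) for x in data) / len(data)


thetas = [k / 1000 for k in range(1, 1000)]  # grid over (0, 1)
theta_best = min(thetas, key=cross_entropy)
print(theta_best, sum(data) / len(data))  # 0.7 0.7
```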


This perspective points out the flaw with likelihood-based training, namely that it puts too
much weight on the training set. In most applications, we do not really believe that the empirical
distribution is a good representation of the true distribution, since it just puts “spikes” on a finite
set of points, and zero density everywhere else. Even if the dataset is large (say 1M images), the
universe from which the data is sampled is usually even larger (e.g., the set of “all natural images”
is much larger than 1M). Thus we need to somehow smooth the empirical distribution by sharing
probability mass between “similar” inputs.


5.1.6 KL divergence and Bayesian Inference


Bayesian inference itself can be motivated as the solution to a particular minimization problem of
KL.


Consider a prior set of beliefs described by a joint distribution $q(\theta, D) = q(\theta) q(D | \theta)$, involving
some prior $q(\theta)$ and some likelihood $q(D | \theta)$. If we happen to observe some particular dataset $D_0$,
how should we update our beliefs? We could search for the joint distribution that is as close as
possible to our prior beliefs but that respects the constraint that we now know the value of the data:

$$p(\theta, D) = \arg\min D_{\mathrm{KL}}(p(\theta, D) \| q(\theta, D)) \text{ such that } p(D) = \delta(D - D_0), \qquad (5.70)$$

where $\delta(D - D_0)$ is a degenerate distribution that puts all its mass on the dataset $D$ that is identically
equal to $D_0$. Writing the KL out in its chain rule form:

$$D_{\mathrm{KL}}(p(\theta, D) \| q(\theta, D)) = D_{\mathrm{KL}}(p(D) \| q(D)) + D_{\mathrm{KL}}(p(\theta | D) \| q(\theta | D)), \qquad (5.71)$$
makes clear that the solution is given by the joint distribution:
$$p(\theta, D) = p(D) p(\theta | D) = \delta(D - D_0) q(\theta | D). \qquad (5.72)$$


Our updated beliefs have a marginal over $\theta$ of

$$p(\theta) = \int dD \, p(\theta, D) = \int dD \, \delta(D - D_0) q(\theta | D) = q(\theta | D = D_0), \qquad (5.73)$$


which is just the usual Bayesian posterior from our prior beliefs evaluated at the data we observed. 
By contrast, the usual statement of Bayes’ rule is just a trivial observation about the chain rule of 
probabilities: 


$$q(\theta, D) = q(D) q(\theta|D) = q(\theta) q(D|\theta) \implies q(\theta|D) = \frac{q(D|\theta)}{q(D)} q(\theta) \tag{5.74}$$

Notice that this expresses the conditional distribution $q(\theta|D)$ in terms of $q(D|\theta)$, $q(\theta)$ and $q(D)$, but these are all just different ways of writing the same joint distribution. Bayes' rule by itself does not tell us how we ought to update our beliefs in light of evidence; for that we need some other principle [Cat+11].

One of the nice things about this interpretation of Bayesian inference is that it naturally generalizes to other forms of constraints, rather than assuming we have observed the data exactly.

If there were some additional, well-understood measurement error, then instead of pegging our updated beliefs to a delta function on the observed data, we ought to peg them to that well-understood distribution $p(D)$. For example, we might not know the precise value the data takes, but believe after measuring things that it follows a Gaussian distribution with a certain mean and standard deviation.

Because of the chain rule of KL, this has no effect on our updated conditional distribution over parameters, which remains the Bayesian posterior: $p(\theta|D) = q(\theta|D)$. However, it does change our marginal beliefs about the parameters, which are now

$$p(\theta) = \int dD\, p(D) q(\theta|D) \tag{5.75}$$
This generalization of Bayes’ rule is sometimes called Jeffrey’s conditionalization rule [Cat08]. 
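To make Equation (5.75) concrete, here is a minimal sketch with a discrete parameter and a binary observation, assuming a hypothetical three-point grid of coin biases with a uniform prior (none of these values come from the text). With $p(D) = \delta(D - 1)$ it reduces to the ordinary Bayesian posterior; with a softer $p(D)$ it returns the mixture of posteriors that Jeffrey's rule prescribes.

```python
def bayes_posterior(prior, lik, d):
    """q(theta | D = d) for a discrete theta; lik[i][d] = q(D = d | theta_i)."""
    joint = [w * row[d] for w, row in zip(prior, lik)]
    evidence = sum(joint)  # q(D = d)
    return [j / evidence for j in joint]


def jeffrey_update(prior, lik, p_data):
    """p(theta) = sum_d p(D = d) q(theta | D = d): Eq. (5.75) with the integral as a sum."""
    post = [0.0] * len(prior)
    for d, weight in enumerate(p_data):
        if weight == 0.0:
            continue  # outcomes we are certain did not occur contribute nothing
        q = bayes_posterior(prior, lik, d)
        post = [acc + weight * qi for acc, qi in zip(post, q)]
    return post


# Hypothetical example: three candidate coin biases with a uniform prior.
thetas = [0.2, 0.5, 0.8]
prior = [1 / 3] * 3
lik = [[1 - t, t] for t in thetas]  # [q(D = 0 | theta), q(D = 1 | theta)]

exact = jeffrey_update(prior, lik, [0.0, 1.0])  # D = 1 observed exactly
soft = jeffrey_update(prior, lik, [0.3, 0.7])   # only 70% sure that D = 1
print([round(v, 4) for v in exact])  # [0.1333, 0.3333, 0.5333]
print([round(v, 4) for v in soft])   # [0.2533, 0.3333, 0.4133]
```

The soft update keeps the same conditional $q(\theta|D)$ and only reweights it by $p(D)$, which is exactly why the chain rule of KL leaves the conditional untouched.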
5.1.7 KL divergence and exponential families


The KL divergence between two exponential family distributions from the same family has a nice 
closed form, as we explain below. 
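Consider $p(x)$ with natural parameter $\eta$, base measure $h(x)$ and sufficient statistics $T(x)$:

$$p(x) = h(x) \exp[\eta^\top T(x) - A(\eta)] \tag{5.76}$$

where

$$A(\eta) = \log \int h(x) \exp(\eta^\top T(x))\, dx \tag{5.77}$$

is the log partition function, a convex function of $\eta$.
The KL divergence between two exponential family distributions from the same family is as follows:

\begin{align}
D_{\mathbb{KL}}\left(p(x|\eta_1) \,\|\, p(x|\eta_2)\right) &= \mathbb{E}_{\eta_1}\left[(\eta_1 - \eta_2)^\top T(x) - A(\eta_1) + A(\eta_2)\right] \tag{5.78} \\
&= (\eta_1 - \eta_2)^\top \mu_1 - A(\eta_1) + A(\eta_2) \tag{5.79}
\end{align}

where $\mu_1 = \mathbb{E}_{\eta_1}[T(x)]$.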
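As a sanity check on Equation (5.79), the sketch below uses the Poisson distribution, assuming its standard exponential-family form $T(x) = x$, $\eta = \log\lambda$, $A(\eta) = e^\eta$ and $h(x) = 1/x!$, so that $\mu = \lambda$. It compares the closed form against a direct, truncated sum over $x$; the rates 3 and 5 are arbitrary illustrative values.

```python
import math


def log_partition(eta):
    """A(eta) for the Poisson family: A(eta) = exp(eta) = lambda."""
    return math.exp(eta)


def kl_expfam_poisson(lam1, lam2):
    """Eq. (5.79): (eta1 - eta2) * mu1 - A(eta1) + A(eta2)."""
    eta1, eta2 = math.log(lam1), math.log(lam2)
    mu1 = lam1  # E_{eta1}[T(x)] = E[x] = lambda1
    return (eta1 - eta2) * mu1 - log_partition(eta1) + log_partition(eta2)


def kl_direct_poisson(lam1, lam2, x_max=200):
    """sum_x p1(x) log(p1(x) / p2(x)), truncated at x_max (ample for small rates)."""
    total = 0.0
    for x in range(x_max + 1):
        log_p1 = x * math.log(lam1) - lam1 - math.lgamma(x + 1)
        log_p2 = x * math.log(lam2) - lam2 - math.lgamma(x + 1)
        total += math.exp(log_p1) * (log_p1 - log_p2)
    return total


print(kl_expfam_poisson(3.0, 5.0), kl_direct_poisson(3.0, 5.0))
```

Only the log partition function and the mean parameter enter the closed form; the base measure $h(x)$ cancels, which is why it never appears in `kl_expfam_poisson`.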


5.1.7.1 Example: KL divergence between two Gaussians 


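An important example is the KL divergence between two multivariate Gaussian distributions, which is given by

$$D_{\mathbb{KL}}\left(\mathcal{N}(x|\mu_1, \Sigma_1) \,\|\, \mathcal{N}(x|\mu_2, \Sigma_2)\right) = \frac{1}{2}\left[\operatorname{tr}(\Sigma_2^{-1}\Sigma_1) + (\mu_2 - \mu_1)^\top \Sigma_2^{-1} (\mu_2 - \mu_1) - D + \log\frac{\det(\Sigma_2)}{\det(\Sigma_1)}\right] \tag{5.80}$$

where $D$ is the dimensionality of $x$. In the scalar case, this becomes

$$D_{\mathbb{KL}}\left(\mathcal{N}(x|\mu_1, \sigma_1^2) \,\|\, \mathcal{N}(x|\mu_2, \sigma_2^2)\right) = \log\frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} - \frac{1}{2} \tag{5.81}$$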
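A minimal sketch of Equations (5.80) and (5.81), restricted (as an assumption, to avoid a linear-algebra dependency) to diagonal covariances, where the trace, Mahalanobis and log-determinant terms all become sums over dimensions. The diagonal version should therefore equal the sum of scalar KLs.

```python
import math


def kl_gauss_1d(mu1, sigma1, mu2, sigma2):
    """Eq. (5.81): KL(N(mu1, sigma1^2) || N(mu2, sigma2^2)) in nats; sigmas are std devs."""
    return (math.log(sigma2 / sigma1)
            + (sigma1 ** 2 + (mu1 - mu2) ** 2) / (2 * sigma2 ** 2)
            - 0.5)


def kl_gauss_diag(mu1, var1, mu2, var2):
    """Eq. (5.80) when both covariances are diagonal; var lists hold the diagonals."""
    dim = len(mu1)
    trace_term = sum(v1 / v2 for v1, v2 in zip(var1, var2))               # tr(S2^-1 S1)
    mahalanobis = sum((m2 - m1) ** 2 / v2 for m1, m2, v2 in zip(mu1, mu2, var2))
    log_det_ratio = sum(math.log(v2 / v1) for v1, v2 in zip(var1, var2))  # log(det S2 / det S1)
    return 0.5 * (trace_term + mahalanobis - dim + log_det_ratio)


print(kl_gauss_1d(0.0, 1.0, 1.0, 1.0))  # 0.5
```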


5.1.8 Bregman divergence 


Let $f: \Omega \to \mathbb{R}$ be a continuously differentiable, strictly convex function defined on a closed convex set $\Omega$. We define the Bregman divergence associated with $f$ as follows [Bre67]:

$$B_f(w \| v) = f(w) - f(v) - (w - v)^\top \nabla f(v) \tag{5.82}$$

To understand this, let

$$f_v(w) = f(v) + (w - v)^\top \nabla f(v) \tag{5.83}$$

be a first-order Taylor series approximation to $f$ centered at $v$. Then the Bregman divergence is the difference from this linear approximation:

$$B_f(w \| v) = f(w) - f_v(w) \tag{5.84}$$


Figure 5.3: (a) Illustration of Bregman divergence. (b) A locally linear approximation to a non-convex function.


See Figure 5.3a for an illustration. Since $f$ is convex, $f_v$ is a linear lower bound on $f$, so $B_f(w \| v) \geq 0$.
Below we mention some important special cases of Bregman divergences. 


• If $f(w) = \|w\|^2$, then $B_f(w \| v) = \|w - v\|^2$ is the squared Euclidean distance.
• If $f(w) = w^\top Q w$, then $B_f(w \| v) = (w - v)^\top Q (w - v)$ is the squared Mahalanobis distance.
• If $w$ are the natural parameters of an exponential family distribution, and $f(w) = \log Z(w)$ is the log normalizer, then the Bregman divergence is the same as the Kullback-Leibler divergence, as we show in Section 5.1.8.1.
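The first two special cases can be checked directly from Equation (5.82). In this sketch the gradient is supplied by hand, and the matrix `Q` is a hypothetical symmetric positive-definite example chosen for illustration.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def bregman(f, grad_f, w, v):
    """Eq. (5.82): B_f(w || v) = f(w) - f(v) - (w - v)^T grad f(v)."""
    diff = [wi - vi for wi, vi in zip(w, v)]
    return f(w) - f(v) - dot(diff, grad_f(v))


def sq_norm(w):
    return dot(w, w)


def grad_sq_norm(w):
    return [2 * x for x in w]


# Hypothetical symmetric positive-definite Q for the Mahalanobis case.
Q = [[2.0, 0.5],
     [0.5, 1.0]]


def matvec(M, x):
    return [dot(row, x) for row in M]


def quad(w):
    return dot(w, matvec(Q, w))


def grad_quad(w):
    return [2 * x for x in matvec(Q, w)]  # grad of w^T Q w is 2 Q w when Q is symmetric


w, v = [1.0, 2.0], [0.5, -1.0]
print(bregman(sq_norm, grad_sq_norm, w, v))  # 9.25 = ||w - v||^2
print(bregman(quad, grad_quad, w, v))        # 11.0 = (w - v)^T Q (w - v)
```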


5.1.8.1 KL is a Bregman divergence 


Recall that the log partition function $A(\eta)$ is a convex function. We can therefore use it to define the Bregman divergence (Section 5.1.8) between the two distributions, $p$ and $q$, as follows:

\begin{align}
B_A(\eta_q \| \eta_p) &= A(\eta_q) - A(\eta_p) - (\eta_q - \eta_p)^\top \nabla_{\eta_p} A(\eta_p) \tag{5.85} \\
&= A(\eta_q) - A(\eta_p) - (\eta_q - \eta_p)^\top \mathbb{E}_p[T(x)] \tag{5.86} \\
&= D_{\mathbb{KL}}(p \,\|\, q) \tag{5.87}
\end{align}


where we exploited the fact that the gradient of the log partition function computes the expected 
sufficient statistics as shown in Section 2.3.3. 

In fact, the KL divergence is the only divergence that is both a Bregman divergence and an 
f-divergence (Section 2.7.1) [Ama09]. 
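A quick numerical check of Equations (5.85)-(5.87) for the Bernoulli family, assuming its usual natural parameterization $\eta = \log(p/(1-p))$, $T(x) = x$ and $A(\eta) = \log(1 + e^\eta)$. Note the argument order: the Bregman divergence anchored at $\eta_p$ gives $D_{\mathbb{KL}}(p \| q)$.

```python
import math


def log_partition(eta):
    """A(eta) = log(1 + e^eta) for the Bernoulli family with T(x) = x."""
    return math.log1p(math.exp(eta))


def grad_log_partition(eta):
    """A'(eta) = sigmoid(eta) = E_eta[T(x)] = P(x = 1)."""
    return 1.0 / (1.0 + math.exp(-eta))


def bregman_A(eta_q, eta_p):
    """Eq. (5.85): B_A(eta_q || eta_p)."""
    return (log_partition(eta_q) - log_partition(eta_p)
            - (eta_q - eta_p) * grad_log_partition(eta_p))


def kl_bernoulli(p, q):
    """Direct KL(Ber(p) || Ber(q)) in nats."""
    return p * math.log(p / q) + (1 - p) * math.log((1 - p) / (1 - q))


def logit(p):
    return math.log(p / (1 - p))


p, q = 0.3, 0.6  # illustrative success probabilities
print(bregman_A(logit(q), logit(p)), kl_bernoulli(p, q))  # the two should agree
```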
5.2 Entropy


In this section, we discuss the entropy of a distribution p, which is just a shifted and scaled version 
of the KL divergence between the probability distribution and the uniform distribution, as we will 
see. 


5.2.1 Definition 
The entropy of a discrete random variable X with distribution p over K states is defined by 


$$\mathbb{H}(X) \triangleq -\sum_{k=1}^{K} p(X = k) \log p(X = k) = -\mathbb{E}_X[\log p(X)] \tag{5.88}$$

We can use logarithms to any base, but we commonly use log base 2, in which case the units are called bits, or log base e, in which case the units are called nats, as we explained in Section 5.1.3.1.
The entropy is equivalent to a constant minus the KL divergence from the uniform distribution:

\begin{align}
\mathbb{H}(X) &= \log K - D_{\mathbb{KL}}(p(X) \,\|\, u(X)) \tag{5.89} \\
D_{\mathbb{KL}}(p(X) \,\|\, u(X)) &= \sum_{k=1}^{K} p(X = k) \log \frac{p(X = k)}{1/K} \tag{5.90} \\
&= \log K + \sum_{k=1}^{K} p(X = k) \log p(X = k) \tag{5.91}
\end{align}


If p is uniform, the KL is zero, and we see that the entropy achieves its maximal value of log K. 
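Equations (5.88)-(5.91) can be verified in a few lines. This sketch uses base-2 logarithms (bits) and a hypothetical four-state distribution chosen so that every log is exact.

```python
import math


def entropy(p, base=2.0):
    """H(p) = -sum_k p_k log p_k; zero-probability states contribute 0."""
    return -sum(pk * math.log(pk, base) for pk in p if pk > 0)


def kl(p, q, base=2.0):
    """D_KL(p || q) = sum_k p_k log(p_k / q_k)."""
    return sum(pk * math.log(pk / qk, base) for pk, qk in zip(p, q) if pk > 0)


p = [0.5, 0.25, 0.125, 0.125]  # hypothetical distribution over K = 4 states
K = len(p)
u = [1.0 / K] * K              # uniform distribution
print(entropy(p))               # 1.75 bits
print(math.log2(K) - kl(p, u))  # also 1.75 bits, as in Eq. (5.89)
```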
For the special case of binary random variables, $X \in \{0, 1\}$, we can write $p(X = 1) = \theta$ and $p(X = 0) = 1 - \theta$. Hence the entropy becomes

\begin{align}
\mathbb{H}(X) &= -[p(X = 1) \log p(X = 1) + p(X = 0) \log p(X = 0)] \tag{5.92} \\
&= -[\theta \log \theta + (1 - \theta) \log(1 - \theta)] \tag{5.93}
\end{align}

This is called the binary entropy function, and is also written $\mathbb{H}(\theta)$. We plot this in Figure 5.4. We see that the maximum value of 1 bit occurs when the distribution is uniform, $\theta = 0.5$. A fair coin requires a single yes/no question to determine its state.
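A minimal sketch of the binary entropy function (5.93) in bits, using the convention $0 \log 0 = 0$ at the endpoints; a grid search recovers the maximum at $\theta = 0.5$.

```python
import math


def binary_entropy(theta):
    """H(theta) = -[theta log2 theta + (1 - theta) log2(1 - theta)], in bits."""
    if theta in (0.0, 1.0):
        return 0.0  # by the convention 0 log 0 = 0
    return -(theta * math.log2(theta) + (1 - theta) * math.log2(1 - theta))


grid = [i / 100 for i in range(101)]
best = max(grid, key=binary_entropy)
print(best, binary_entropy(best))  # 0.5 1.0
```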


5.2.2 Differential entropy for continuous random variables 


If $X$ is a continuous random variable with pdf $p(x)$, we define the differential entropy as

$$h(X) \triangleq -\int dx\, p(x) \log p(x) \tag{5.94}$$

assuming this integral exists.
For example, one can show that the entropy of a $d$-dimensional Gaussian is

$$h(\mathcal{N}(\mu, \Sigma)) = \frac{1}{2} \log |2\pi e \Sigma| = \frac{1}{2} \log[(2\pi e)^d |\Sigma|] = \frac{d}{2} + \frac{d}{2} \log(2\pi) + \frac{1}{2} \log |\Sigma| \tag{5.95}$$
Figure 5.4: Entropy of a Bernoulli random variable as a function of $\theta$. The maximum entropy is $\log_2 2 = 1$. Generated by bernoulli_entropy_fig.ipynb.


In the 1d case, this becomes

$$h(\mathcal{N}(\mu, \sigma^2)) = \frac{1}{2} \log[2\pi e \sigma^2] \tag{5.96}$$

Note that, unlike the discrete case, differential entropy can be negative. This is because pdfs can be bigger than 1. For example, suppose $X \sim U(0, a)$. Then

$$h(X) = -\int_0^a dx\, \frac{1}{a} \log \frac{1}{a} = \log a \tag{5.97}$$

If we set $a = 1/8$, we have $h(X) = \log_2(1/8) = -3$ bits.
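To see a negative differential entropy numerically, this sketch approximates Equation (5.94) with a midpoint rule (an assumed, simple quadrature) and compares it with Equation (5.96) for a hypothetical narrow Gaussian, $\sigma = 0.1$, whose pdf exceeds 1 near the mode. It also recovers the $-3$ bits of $U(0, 1/8)$.

```python
import math


def differential_entropy(pdf, lo, hi, n=200_000, base=2.0):
    """Midpoint-rule approximation of h(X) = -int p(x) log p(x) dx over [lo, hi]."""
    dx = (hi - lo) / n
    total = 0.0
    for i in range(n):
        px = pdf(lo + (i + 0.5) * dx)
        if px > 0:
            total -= px * math.log(px, base) * dx
    return total


def gaussian_pdf(x, sigma):
    return math.exp(-0.5 * (x / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi))


sigma = 0.1  # hypothetical narrow Gaussian
numeric = differential_entropy(lambda x: gaussian_pdf(x, sigma), -1.0, 1.0)
closed_form = 0.5 * math.log2(2 * math.pi * math.e * sigma ** 2)  # Eq. (5.96), in bits
uniform = differential_entropy(lambda x: 8.0, 0.0, 1.0 / 8, n=1000)  # U(0, 1/8)
print(numeric, closed_form)  # both about -1.27 bits
print(uniform)               # about -3 bits, matching log2(1/8)
```

The integration range $[-1, 1]$ covers $\pm 10\sigma$, so the truncated tails are negligible.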

One way to understand differential entropy is to realize that all real-valued quantities can only be 
represented to finite precision. It can be shown [CT91, p228] that the entropy of an n-bit quantization 
of a continuous random variable $X$ is approximately $h(X) + n$. For example, suppose $X \sim U(0, 1/8)$. Then in a binary representation of $X$, the first 3 bits to the right of the binary point must be 0 (since the number is $< 1/8$). So to describe $X$ to $n$ bits of accuracy only requires $n - 3$ bits, which agrees with $h(X) = -3$ calculated above.
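The "$h(X) + n$" approximation can be checked by quantizing a Gaussian onto a grid of spacing $2^{-n}$ and computing the discrete entropy of the resulting bins. This is a sketch under stated assumptions: $X \sim \mathcal{N}(0, 1)$ (a hypothetical choice), bin probabilities from the normal CDF via `math.erf`, and tails beyond $\pm 10\sigma$ ignored.

```python
import math


def normal_cdf(x, sigma):
    return 0.5 * (1.0 + math.erf(x / (sigma * math.sqrt(2.0))))


def quantized_entropy_bits(sigma, n_bits, width=10.0):
    """Entropy (bits) of N(0, sigma^2) binned on a grid of spacing 2^-n_bits."""
    delta = 2.0 ** (-n_bits)
    m = math.ceil(width * sigma / delta)  # cover roughly +/- width standard deviations
    total = 0.0
    for i in range(-m, m):
        p = normal_cdf((i + 1) * delta, sigma) - normal_cdf(i * delta, sigma)
        if p > 0:
            total -= p * math.log2(p)
    return total


sigma, n_bits = 1.0, 8
h_bits = 0.5 * math.log2(2 * math.pi * math.e * sigma ** 2)  # differential entropy in bits
print(quantized_entropy_bits(sigma, n_bits), h_bits + n_bits)  # nearly equal
```

Each extra bit of precision halves the bin width and adds one bit to the discrete entropy, which is exactly the $+n$ term.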

The continuous entropy also lacks the reparameterization independence of KL divergence (Section 5.1.2.3). In particular, if we transform our random variable $y = f(x)$, the entropy transforms. To see this, note that the change of variables tells us that

$$p(y)\, dy = p(x)\, dx \implies p(y) = p(x) \left|\frac{dy}{dx}\right|^{-1} \tag{5.98}$$

Thus the continuous entropy transforms as follows:

$$h(X) = -\int dx\, p(x) \log p(x) = h(Y) - \int dy\, p(y) \log \left|\frac{dy}{dx}\right| \tag{5.99}$$


We pick up an additional term in the continuous entropy, namely the expected log of the absolute Jacobian determinant of the transformation. This changes the value of the continuous entropy even for a simple rescaling of the random variable, such as when we change units. For example, in Figure 5.5 we show the distribution of adult human heights (it is bimodal because, while both male and female heights are normally distributed, their means differ noticeably). The continuous entropy of this distribution depends on the units it is measured in. If measured in feet, the continuous entropy is 0.43 bits. Intuitively this is because human heights mostly span less than a foot. If measured in centimeters it is instead 5.4 bits. There are 30.48 centimeters in a foot, and $\log_2 30.48 \approx 4.9$, explaining the difference. If we measured the continuous entropy of the same distribution in meters, we would obtain -1.3 bits!

Figure 5.5: Distribution of adult heights. The continuous entropy of the distribution depends on its units of measurement. If heights are measured in feet, this distribution has a continuous entropy of 0.43 bits. If measured in centimeters it's 5.4 bits. If measured in meters it's -1.3 bits. Data taken from https://ourworldindata.org/human-height.
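For a pure rescaling $y = cx$, Equation (5.99) gives $|dy/dx| = c$, so $h(Y) = h(X) + \log c$ for any distribution. The height data behind Figure 5.5 is not reproduced here, so this sketch uses a hypothetical Gaussian stand-in with a made-up spread; the unit-dependent shift does not depend on that choice.

```python
import math


def gaussian_entropy_bits(sigma):
    """Differential entropy of N(mu, sigma^2) in bits, Eq. (5.96)."""
    return 0.5 * math.log2(2 * math.pi * math.e * sigma ** 2)


sigma_cm = 7.0                # hypothetical spread of heights, in centimeters
sigma_ft = sigma_cm / 30.48   # the same distribution, in feet
sigma_m = sigma_cm / 100.0    # the same distribution, in meters

h_cm, h_ft, h_m = map(gaussian_entropy_bits, (sigma_cm, sigma_ft, sigma_m))
print(h_cm - h_ft, math.log2(30.48))  # equal: about 4.93 bits
print(h_cm - h_m, math.log2(100.0))   # equal: about 6.64 bits
```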


5.2.3 Typical sets

The typical set of a probability distribution is the set whose elements have an information content that is close to the expected information content of random samples from the distribution. More precisely, for a distribution $p(x)$ with support $x \in \mathcal{X}$, the $\epsilon$-typical set $A_\epsilon^N \subseteq \mathcal{X}^N$ for $p(x)$ is the set of all length $N$ sequences such that

$$\mathbb{H}(p(x)) - \epsilon \leq -\frac{1}{N} \log p(x_1, \ldots, x_N) \leq \mathbb{H}(p(x)) + \epsilon \tag{5.100}$$

If we assume $p(x_1, \ldots, x_N) = \prod_{n=1}^{N} p(x_n)$, then we can interpret the term in the middle as the $N$-sample empirical estimate of the entropy. The asymptotic equipartition property or AEP states that this will converge (in probability) to the true entropy as $N \to \infty$ [CT06]. Thus the typical set has probability close to 1, and is thus a compact summary of what we can expect to be generated by $p(x)$.
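The AEP can be watched in simulation. This sketch assumes a hypothetical i.i.d. Bernoulli(0.2) source and a fixed seed; it shows the per-sequence estimate $-\frac{1}{N}\log_2 p(x_{1:N})$ settling near $\mathbb{H} \approx 0.72$ bits, and estimates how often length-1000 sequences fall inside the $\epsilon$-typical set of Equation (5.100).

```python
import math
import random


def binary_entropy_bits(theta):
    return -(theta * math.log2(theta) + (1 - theta) * math.log2(1 - theta))


def empirical_entropy(seq, theta):
    """-(1/N) log2 p(x_1, ..., x_N) for i.i.d. Bernoulli(theta) draws."""
    log_p = sum(math.log2(theta if x == 1 else 1 - theta) for x in seq)
    return -log_p / len(seq)


def sample(rng, theta, n):
    return [1 if rng.random() < theta else 0 for _ in range(n)]


rng = random.Random(0)
theta, eps = 0.2, 0.1  # hypothetical source and tolerance
H = binary_entropy_bits(theta)

# The per-sequence estimate concentrates around H as N grows (the AEP).
for n in (10, 100, 10_000):
    print(n, empirical_entropy(sample(rng, theta, n), theta), H)

# Fraction of length-1000 sequences that land in the eps-typical set.
seqs = [sample(rng, theta, 1000) for _ in range(200)]
frac = sum(abs(empirical_entropy(s, theta) - H) <= eps for s in seqs) / len(seqs)
print(frac)  # close to 1
```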


5.2.4 Cross entropy and perplexity 


A standard way to measure how close a model $q$ is to a true distribution $p$ is in terms of the KL divergence (Section 5.1), given by

$$D_{\mathbb{KL}}(p \,\|\, q) = \sum_x p(x) \log \frac{p(x)}{q(x)} = \mathbb{H}_{ce}(p, q) - \mathbb{H}(p) \tag{5.101}$$

where $\mathbb{H}_{ce}(p, q)$ is the cross entropy

$$\mathbb{H}_{ce}(p, q) = -\sum_x p(x) \log q(x) \tag{5.102}$$

and $\mathbb{H}(p) = \mathbb{H}_{ce}(p, p)$ is the entropy, which is a constant independent of the model.
In language modeling, it is common to report an alternative performance measure known as the 
perplexity. This is defined as 


$$\text{perplexity}(p, q) \triangleq 2^{\mathbb{H}_{ce}(p, q)} \tag{5.103}$$


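We can compute an empirical approximation to the cross entropy as follows. Suppose we approximate the true distribution with an empirical distribution based on data sampled from $p$:

$$p_{\mathcal{D}}(x|\mathcal{D}) = \frac{1}{N} \sum_{n=1}^{N} \mathbb{I}(x = x_n) \tag{5.104}$$

In this case, the cross entropy is given by

$$\mathbb{H}_{ce}(p_{\mathcal{D}}, p) = -\frac{1}{N} \sum_{n=1}^{N} \log p(x_n) = -\frac{1}{N} \log \prod_{n=1}^{N} p(x_n) \tag{5.105}$$

The corresponding perplexity is given by

\begin{align}
\text{perplexity}(p_{\mathcal{D}}, p) &= 2^{-\frac{1}{N} \log_2 \prod_{n=1}^{N} p(x_n)} = 2^{\log_2 \left(\prod_{n=1}^{N} p(x_n)\right)^{-1/N}} \tag{5.106} \\
&= \left(\prod_{n=1}^{N} p(x_n)\right)^{-1/N} = \sqrt[N]{\prod_{n=1}^{N} \frac{1}{p(x_n)}} \tag{5.107}
\end{align}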


In the case of language models, we usually condition on previous words when predicting the next word. For example, in a bigram model, we use a first-order Markov model of the form $p(x_n|x_{n-1})$. We define the branching factor of a language model as the number of possible words that can follow any given word. For example, suppose the model predicts that each word is equally likely, regardless of context, so $p(x_n|x_{n-1}) = 1/K$, where $K$ is the number of words in the vocabulary. Then the perplexity is $((1/K)^N)^{-1/N} = K$. If some symbols are more likely than others, and the model correctly reflects this, its perplexity will be lower than $K$. However, if $p^*$ denotes the true distribution, we have $\mathbb{H}(p^*) \leq \mathbb{H}_{ce}(p^*, p)$, so we can never reduce the perplexity below $2^{\mathbb{H}(p^*)}$.
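A minimal sketch of Equations (5.103)-(5.107). Perplexity is computed in log space, since the raw product of many token probabilities underflows. The token probabilities and the "true" distribution below are hypothetical.

```python
import math


def cross_entropy_bits(p, q):
    """H_ce(p, q) = -sum_x p(x) log2 q(x)."""
    return -sum(px * math.log2(qx) for px, qx in zip(p, q) if px > 0)


def perplexity_from_probs(token_probs):
    """(prod_n p(x_n))^(-1/N), via the average negative log2 probability (Eq. 5.105)."""
    avg_neg_log2 = -sum(math.log2(pr) for pr in token_probs) / len(token_probs)
    return 2.0 ** avg_neg_log2


K = 10
uniform_model = [1.0 / K] * 1000              # every token gets probability 1/K
print(perplexity_from_probs(uniform_model))   # about 10, i.e. K

p_true = [0.7, 0.1, 0.1, 0.1]                 # hypothetical true next-symbol distribution
for q in ([0.7, 0.1, 0.1, 0.1], [0.25] * 4):
    # 2^{H_ce(p, q)} is smallest, equal to 2^{H(p)}, when q = p
    print(q, 2.0 ** cross_entropy_bits(p_true, q))
```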
5.3 Mutual information


The KL divergence gave us a way to measure how similar two distributions were. How should we measure how dependent two random variables are? One thing we could do is turn the question
of measuring the dependence of two random variables into a question about the similarity of their 
distributions. This gives rise to the notion of mutual information (MI) between two random 
variables, which we define below. 


5.3.1 Definition 


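The mutual information between random variables $X$ and $Y$ is defined as follows:

$$\mathbb{I}(X; Y) \triangleq D_{\mathbb{KL}}(p(x, y) \,\|\, p(x) p(y)) = \sum_{y \in \mathcal{Y}} \sum_{x \in \mathcal{X}} p(x, y) \log \frac{p(x, y)}{p(x)\, p(y)} \tag{5.108}$$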


(We write I(X;Y) instead of I(X,Y), in case X and/or Y represent sets of variables; for example, we 
can write I(X;Y, Z) to represent the MI between X and (Y, Z).) For continuous random variables, 
we just replace sums with integrals. 

It is easy to see that MI is always non-negative, even for continuous random variables, since

$$\mathbb{I}(X; Y) = D_{\mathbb{KL}}(p(x, y) \,\|\, p(x) p(y)) \geq 0 \tag{5.109}$$

We achieve the bound of 0 iff $p(x, y) = p(x) p(y)$.
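Definition (5.108) translates directly into code for discrete variables. In this sketch the joint is a dictionary mapping $(x, y)$ pairs to probabilities, and the two example joints are hypothetical: an independent pair (MI 0) and a perfectly correlated pair of fair bits (MI 1 bit).

```python
import math


def mutual_information(joint):
    """I(X;Y) = KL(p(x,y) || p(x)p(y)) in bits; joint is a dict {(x, y): prob}."""
    px, py = {}, {}
    for (x, y), pr in joint.items():
        px[x] = px.get(x, 0.0) + pr
        py[y] = py.get(y, 0.0) + pr
    return sum(pr * math.log2(pr / (px[x] * py[y]))
               for (x, y), pr in joint.items() if pr > 0)


independent = {(x, y): 0.25 for x in (0, 1) for y in (0, 1)}
perfect_copy = {(0, 0): 0.5, (1, 1): 0.5}
print(mutual_information(independent))   # 0.0
print(mutual_information(perfect_copy))  # 1.0 bit
```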


5.3.2 Interpretation 


Knowing that the mutual information is a KL divergence between the joint and factored marginal distributions tells us that the MI measures the information gain if we update from a model that treats the two variables as independent, $p(x) p(y)$, to one that models their true joint density $p(x, y)$.

To gain further insight into the meaning of MI, it helps to re-express it in terms of joint and conditional entropies, as follows:

$$\mathbb{I}(X; Y) = \mathbb{H}(X) - \mathbb{H}(X|Y) = \mathbb{H}(Y) - \mathbb{H}(Y|X) \tag{5.110}$$


Thus we can interpret the MI between $X$ and $Y$ as the reduction in uncertainty about $X$ after observing $Y$, or, by symmetry, the reduction in uncertainty about $Y$ after observing $X$. Incidentally, this result gives an alternative proof that conditioning, on average, reduces entropy. In particular, we have $0 \leq \mathbb{I}(X; Y) = \mathbb{H}(X) - \mathbb{H}(X|Y)$, and hence $\mathbb{H}(X|Y) \leq \mathbb{H}(X)$.


We can also obtain a different interpretation. One can show that

$$\mathbb{I}(X; Y) = \mathbb{H}(X, Y) - \mathbb{H}(X|Y) - \mathbb{H}(Y|X) \tag{5.111}$$

Finally, one can show that

$$\mathbb{I}(X; Y) = \mathbb{H}(X) + \mathbb{H}(Y) - \mathbb{H}(X, Y) \tag{5.112}$$

See Figure 5.6 for a summary of these equations in terms of an information diagram. (Formally, this is a signed measure mapping set expressions to their information-theoretic counterparts [Yeu91a].)
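The three identities (5.110)-(5.112) can be checked on a small example. To avoid a circular check, this sketch computes $\mathbb{H}(X|Y)$ directly as $\sum_y p(y) \mathbb{H}(X|Y = y)$ from the conditional tables, rather than from the chain rule. The joint (a fair bit sent through a channel that flips it with probability 0.1) is hypothetical.

```python
import math


def H(dist):
    """Entropy in bits of a dict {outcome: prob}."""
    return -sum(pr * math.log2(pr) for pr in dist.values() if pr > 0)


def marginal(joint, axis):
    out = {}
    for key, pr in joint.items():
        out[key[axis]] = out.get(key[axis], 0.0) + pr
    return out


def conditional_entropy(joint, given_axis):
    """H(other | given) = sum_g p(g) H(other | given = g), from the conditional tables."""
    p_given = marginal(joint, given_axis)
    total = 0.0
    for g, pg in p_given.items():
        cond = {key: pr / pg for key, pr in joint.items() if key[given_axis] == g}
        total += pg * H(cond)
    return total


joint = {(x, y): 0.5 * (0.9 if x == y else 0.1) for x in (0, 1) for y in (0, 1)}
hx, hy, hxy = H(marginal(joint, 0)), H(marginal(joint, 1)), H(joint)
h_x_given_y = conditional_entropy(joint, 1)
h_y_given_x = conditional_entropy(joint, 0)

mi_a = hx - h_x_given_y                 # Eq. (5.110)
mi_b = hxy - h_x_given_y - h_y_given_x  # Eq. (5.111)
mi_c = hx + hy - hxy                    # Eq. (5.112)
print(mi_a, mi_b, mi_c)  # all about 0.531 bits
```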


Figure 5.6: The marginal entropy, joint entropy, conditional entropy and mutual information represented as 
information diagrams. Used with kind permission of Katie Everett. 


5.3.3 Data processing inequality 


Suppose we have an unknown variable X, and we observe a noisy function of it, call it Y. If we 
process the noisy observations in some way to create a new variable Z, it should be intuitively obvious 
that we cannot increase the amount of information we have about the unknown quantity, X. This is 
known as the data processing inequality. We now state this more formally, and then prove it. 


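Theorem 5.3.1. Suppose $X \to Y \to Z$ forms a Markov chain, so that $X \perp Z | Y$. Then $\mathbb{I}(X; Y) \geq \mathbb{I}(X; Z)$.

Proof. By the chain rule for mutual information we can expand the mutual information in two different ways:

\begin{align}
\mathbb{I}(X; Y, Z) &= \mathbb{I}(X; Z) + \mathbb{I}(X; Y|Z) \tag{5.113} \\
&= \mathbb{I}(X; Y) + \mathbb{I}(X; Z|Y) \tag{5.114}
\end{align}

Since $X \perp Z | Y$, we have $\mathbb{I}(X; Z|Y) = 0$, so

$$\mathbb{I}(X; Z) + \mathbb{I}(X; Y|Z) = \mathbb{I}(X; Y) \tag{5.115}$$

Since $\mathbb{I}(X; Y|Z) \geq 0$, we have $\mathbb{I}(X; Y) \geq \mathbb{I}(X; Z)$. Similarly one can prove that $\mathbb{I}(Y; Z) \geq \mathbb{I}(X; Z)$.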
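A small numerical illustration of Theorem 5.3.1, assuming a hypothetical chain of two binary symmetric channels: $X$ is a fair bit, $Y$ flips it with probability 0.1, and $Z$ flips $Y$ with probability 0.2. The end-to-end flip probability is $0.1 \cdot 0.8 + 0.9 \cdot 0.2 = 0.26$, so the further processing can only lose information about $X$.

```python
import math
from itertools import product


def mutual_information(joint):
    """I(A;B) in bits for a dict {(a, b): prob}."""
    pa, pb = {}, {}
    for (a, b), pr in joint.items():
        pa[a] = pa.get(a, 0.0) + pr
        pb[b] = pb.get(b, 0.0) + pr
    return sum(pr * math.log2(pr / (pa[a] * pb[b]))
               for (a, b), pr in joint.items() if pr > 0)


def flip(p_flip):
    """Binary symmetric channel: returns P(out | inp)."""
    return lambda out, inp: 1 - p_flip if out == inp else p_flip


def marginalize(joint, keep):
    """Sum out every coordinate not listed in keep."""
    out = {}
    for key, pr in joint.items():
        sub = tuple(key[i] for i in keep)
        out[sub] = out.get(sub, 0.0) + pr
    return out


# Hypothetical chain X -> Y -> Z built from two noisy bit flips.
p_y_given_x, p_z_given_y = flip(0.1), flip(0.2)
joint_xyz = {(x, y, z): 0.5 * p_y_given_x(y, x) * p_z_given_y(z, y)
             for x, y, z in product((0, 1), repeat=3)}

i_xy = mutual_information(marginalize(joint_xyz, (0, 1)))
i_xz = mutual_information(marginalize(joint_xyz, (0, 2)))
print(i_xy, i_xz)  # I(X;Y) >= I(X;Z)
```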
5.3.4 Sufficient statistics


An important consequence of the DPI is the following. Suppose we have the chain $\theta \to X \to s(X)$. Then

$$\mathbb{I}(\theta; s(X)) \leq \mathbb{I}(\theta; X) \tag{5.116}$$

If this holds with equality, then we say that $s(X)$ is a sufficient statistic of the data $X$ for the purposes of inferring $\theta$. In this case, we can equivalently write $\theta \to s(X) \to X$, since we can reconstruct the data from knowing $s(X)$ just as accurately as from knowing $\theta$.

An example of a sufficient statistic is the data itself, $s(X) = X$, but this is not very useful, since it doesn't summarize the data at all. Hence we define a minimal sufficient statistic $s(X)$ as one which is sufficient, and which contains no extra information about $\theta$; thus $s(X)$ maximally compresses the data $X$ without losing information which is relevant to predicting $\theta$. More formally, we say $s$ is a minimal sufficient statistic for $X$ if $s(X) = f(s'(X))$ for some function $f$ and all sufficient statistics $s'(X)$. We can summarize the situation as follows:

$$\theta \to s(X) \to s'(X) \to X \tag{5.117}$$

Here $s'(X)$ takes $s(X)$ and adds redundant information to it, thus creating a one-to-many mapping.

For example, a minimal sufficient statistic for a set of $N$ Bernoulli trials is simply $N$ and $N_1 = \sum_n \mathbb{I}(X_n = 1)$, i.e., the number of successes. In other words, we don't need to keep track of the entire sequence of heads and tails and their ordering; we only need to keep track of the total number of heads and tails. Similarly, for inferring the mean of a Gaussian distribution with known variance we only need to know the empirical mean and number of samples.
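The Bernoulli claim can be checked against the definition in Equation (5.116). This sketch assumes a hypothetical three-point prior over the coin bias and $N = 4$ tosses. It shows that $\mathbb{I}(\theta; N_1) = \mathbb{I}(\theta; X)$, while a statistic that discards information (only the first toss) has strictly lower MI.

```python
import math
from itertools import product


def mutual_information(joint):
    """I(A;B) in bits for a dict {(a, b): prob}."""
    pa, pb = {}, {}
    for (a, b), pr in joint.items():
        pa[a] = pa.get(a, 0.0) + pr
        pb[b] = pb.get(b, 0.0) + pr
    return sum(pr * math.log2(pr / (pa[a] * pb[b]))
               for (a, b), pr in joint.items() if pr > 0)


thetas = [0.2, 0.5, 0.8]  # hypothetical grid of coin biases
prior = [1 / 3] * 3
N = 4

joint_seq, joint_count, joint_first = {}, {}, {}
for i, (theta, w) in enumerate(zip(thetas, prior)):
    for x in product((0, 1), repeat=N):
        n1 = sum(x)
        pr = w * theta ** n1 * (1 - theta) ** (N - n1)  # p(theta) p(x | theta)
        joint_seq[(i, x)] = pr                                           # full data X
        joint_count[(i, n1)] = joint_count.get((i, n1), 0.0) + pr        # s(X) = N1
        joint_first[(i, x[0])] = joint_first.get((i, x[0]), 0.0) + pr    # first toss only

mi_seq = mutual_information(joint_seq)
mi_count = mutual_information(joint_count)
mi_first = mutual_information(joint_first)
print(mi_seq, mi_count, mi_first)  # first two equal; the third is smaller
```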

Earlier in Section 5.1.7 we motivated the exponential family of distributions as being the ones that are minimal in the sense that they contain no other information than constraints on some statistics of the data. It makes sense then that the statistics used to generate exponential family distributions are sufficient. It also hints at the more remarkable fact of the Pitman-Koopman-Darmois theorem, which says that for any distribution whose domain is fixed, it is only the exponential family that admits sufficient statistics with bounded dimensionality as the number of samples increases [Dia88].


5.3.5 Multivariate mutual information 


There are several ways to generalize the idea of mutual information to a set of random variables, as we discuss below.

5.3.5.1 Total correlation

The simplest way to define multivariate MI is to use the total correlation [Wat60] or multi-information [SV98], defined as

\begin{align}
\text{TC}(\{X_1, \ldots, X_D\}) &\triangleq D_{\mathbb{KL}}\left(p(x) \,\Big\|\, \prod_{d=1}^{D} p(x_d)\right) \tag{5.118} \\
&= \sum_x p(x) \log \frac{p(x)}{\prod_{d=1}^{D} p(x_d)} = \sum_d \mathbb{H}(x_d) - \mathbb{H}(x) \tag{5.119}
\end{align}
Figure 5.7: Illustration of multivariate mutual information between three random variables. From https://en.wikipedia.org/wiki/Mutual_information. Used with kind permission of Wikipedia author PAR.


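For example, for 3 variables, this becomes

$$\text{TC}(X, Y, Z) = \mathbb{H}(X) + \mathbb{H}(Y) + \mathbb{H}(Z) - \mathbb{H}(X, Y, Z) \tag{5.120}$$

where $\mathbb{H}(X, Y, Z)$ is the joint entropy

$$\mathbb{H}(X, Y, Z) = -\sum_x \sum_y \sum_z p(x, y, z) \log p(x, y, z) \tag{5.121}$$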


One can show that the multi-information is always non-negative, and is zero iff $p(x) = \prod_d p(x_d)$. However, this means the quantity is non-zero even if only a pair of variables interact. For example, if $p(X, Y, Z) = p(X, Y) p(Z)$ with $X$ and $Y$ dependent, then the total correlation will be non-zero, even though there is no 3-way interaction. This motivates the alternative definition in Section 5.3.5.2.
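The pairwise-only example above can be made concrete with Equation (5.120). In this sketch the distribution is hypothetical: $Y$ is an exact copy of a fair bit $X$, and $Z$ is an independent fair bit. The total correlation comes out at 1 bit even though no three-way interaction exists.

```python
import math
from itertools import product


def entropy_bits(dist):
    return -sum(pr * math.log2(pr) for pr in dist.values() if pr > 0)


def marginalize(joint, keep):
    """Sum out every coordinate not listed in keep."""
    out = {}
    for key, pr in joint.items():
        sub = tuple(key[i] for i in keep)
        out[sub] = out.get(sub, 0.0) + pr
    return out


def total_correlation(joint, n_vars):
    """Eq. (5.119): sum_d H(X_d) - H(X_1, ..., X_D)."""
    marginals = sum(entropy_bits(marginalize(joint, (d,))) for d in range(n_vars))
    return marginals - entropy_bits(joint)


# p(X, Y, Z) = p(X, Y) p(Z): Y copies X, Z is an independent fair bit.
joint = {(x, x, z): 0.25 for x, z in product((0, 1), repeat=2)}
fully_independent = {xyz: 0.125 for xyz in product((0, 1), repeat=3)}
print(total_correlation(joint, 3))              # 1.0 bit
print(total_correlation(fully_independent, 3))  # 0.0
```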


5.3.5.2 Interaction information (co-information) 


The conditional mutual information can be used to give an inductive definition of the multivariate 
mutual information (MMI) as follows: 


$$\mathbb{I}(X_1; \ldots; X_D) = \mathbb{I}(X_1; \ldots; X_{D-1}) - \mathbb{I}(X_1; \ldots; X_{D-1} | X_D) \tag{5.122}$$


This is called the multiple mutual information [Yeu91b], or the co-information [Bel03]. This 
definition is equivalent, up to a sign change, to the interaction information [McG54; Han80; JB03; 
Bro09]. 

For 3 variables, the MMI is given by 


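\begin{align}
\mathbb{I}(X; Y; Z) &= \mathbb{I}(X; Y) - \mathbb{I}(X; Y|Z) \tag{5.123} \\
&= \mathbb{I}(X; Z) - \mathbb{I}(X; Z|Y) \tag{5.124} \\
&= \mathbb{I}(Y; Z) - \mathbb{I}(Y; Z|X) \tag{5.125}
\end{align}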
This can be interpreted as the change in mutual information between a pair of variables when conditioning on the third. Note that this quantity is symmetric in its arguments.
By the definition of conditional mutual information, we have 


$$\mathbb{I}(X; Z|Y) = \mathbb{I}(Z; X, Y) - \mathbb{I}(Y; Z) \tag{5.126}$$

Hence we can rewrite Equation (5.124) as follows:

$$\mathbb{I}(X; Y; Z) = \mathbb{I}(X; Z) + \mathbb{I}(Y; Z) - \mathbb{I}(X, Y; Z) \tag{5.127}$$


This tells us that the MMI is the difference between how much we learn about Z given X and Y 
individually vs jointly (see also Section 5.3.5.3). 

The 3-way MMI is illustrated in the information diagram in Figure 5.7. The way to interpret such diagrams when we have multiple variables is as follows: the area of a shaded region that includes circles $A, B, C, \ldots$ and excludes circles $F, G, H, \ldots$ represents $\mathbb{I}(A; B; C; \ldots | F, G, H, \ldots)$; if $B = C = \emptyset$, this is just $\mathbb{H}(A | F, G, H, \ldots)$; if $F = G = H = \emptyset$, this is just $\mathbb{I}(A; B; C; \ldots)$.


5.3.5.3 Synergy and redundancy 


The MMI is $\mathbb{I}(X; Y; Z) = \mathbb{I}(X; Z) + \mathbb{I}(Y; Z) - \mathbb{I}(X, Y; Z)$. We see that this can be positive, zero or negative. If some of the information about $Z$ that is provided by $X$ is also provided by $Y$, then there is some redundancy between $X$ and $Y$ (wrt $Z$). In this case, $\mathbb{I}(X; Z) + \mathbb{I}(Y; Z) > \mathbb{I}(X, Y; Z)$, so (from Equation (5.127)) we see that the MMI will be positive. If, by contrast, we learn more about $Z$ when we see $X$ and $Y$ together, we say there is some synergy between them. In this case, $\mathbb{I}(X; Z) + \mathbb{I}(Y; Z) < \mathbb{I}(X, Y; Z)$, so the MMI will be negative.


5.3.5.4 MMI and causality

The sign of the MMI can be used to distinguish between different kinds of directed graphical models, which can sometimes be interpreted causally (see Chapter 36 for a general discussion of causality). For example, consider a model of the form $X \leftarrow Z \rightarrow Y$, where $Z$ is a “cause” of $X$ and $Y$. Concretely, suppose $X$ represents the event it is raining, $Y$ represents the event that the sky is dark, and $Z$ represents the event that the sky is cloudy. Conditioning on the common cause $Z$ renders the children $X$ and $Y$ independent, since if I know it is cloudy, noticing that the sky is dark does not change my beliefs about whether it will rain or not. Consequently $\mathbb{I}(X; Y|Z) < \mathbb{I}(X; Y)$, so $\mathbb{I}(X; Y; Z) > 0$.

Now consider the case where $Z$ is a common effect, $X \rightarrow Z \leftarrow Y$. In this case, conditioning on $Z$ makes $X$ and $Y$ dependent, due to the explaining away phenomenon (see Section 4.2.4.2). For example, if $X$ and $Y$ are independent random bits, and $Z$ is the XOR of $X$ and $Y$, then observing $Z = 1$ means that $p(X \neq Y | Z = 1) = 1$, so $X$ and $Y$ are now dependent (information-theoretically, not causally), even though they were a priori independent. Consequently $\mathbb{I}(X; Y|Z) > \mathbb{I}(X; Y)$, so $\mathbb{I}(X; Y; Z) < 0$.

Finally, consider a Markov chain, $X \rightarrow Y \rightarrow Z$. We have $\mathbb{I}(X; Z|Y) = 0 \leq \mathbb{I}(X; Z)$, and so the MMI must be non-negative.
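The two sign examples can be computed exactly. This sketch evaluates $\mathbb{I}(X; Y; Z) = \mathbb{I}(X; Y) - \mathbb{I}(X; Y|Z)$ (Equation (5.123)), writing both terms via entropies. The distributions are hypothetical stand-ins: for the common cause, $X$ and $Y$ are exact copies of a fair bit $Z$; for the common effect, $Z$ is the XOR of two independent fair bits.

```python
import math
from itertools import product


def entropy_bits(dist):
    return -sum(pr * math.log2(pr) for pr in dist.values() if pr > 0)


def marginalize(joint, keep):
    """Sum out every coordinate not listed in keep."""
    out = {}
    for key, pr in joint.items():
        sub = tuple(key[i] for i in keep)
        out[sub] = out.get(sub, 0.0) + pr
    return out


def mmi3(joint):
    """I(X;Y;Z) = I(X;Y) - I(X;Y|Z), with both terms written via entropies."""
    def h(keep):
        return entropy_bits(marginalize(joint, keep))
    i_xy = h((0,)) + h((1,)) - h((0, 1))
    i_xy_given_z = h((0, 2)) + h((1, 2)) - h((2,)) - h((0, 1, 2))
    return i_xy - i_xy_given_z


common_cause = {(z, z, z): 0.5 for z in (0, 1)}                              # X <- Z -> Y
common_effect = {(x, y, x ^ y): 0.25 for x, y in product((0, 1), repeat=2)}  # X -> Z <- Y
print(mmi3(common_cause))   # 1.0 bit (redundancy)
print(mmi3(common_effect))  # -1.0 bit (synergy)
```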
5.3.5.5 MMI and entropy


We can also write the MMI in terms of entropies. Specifically, we know that 


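$$\mathbb{I}(X; Y) = \mathbb{H}(X) + \mathbb{H}(Y) - \mathbb{H}(X, Y) \tag{5.128}$$

and

$$\mathbb{I}(X; Y|Z) = \mathbb{H}(X, Z) + \mathbb{H}(Y, Z) - \mathbb{H}(Z) - \mathbb{H}(X, Y, Z) \tag{5.129}$$

Hence we can rewrite Equation (5.123) as follows:

$$\mathbb{I}(X; Y; Z) = [\mathbb{H}(X) + \mathbb{H}(Y) + \mathbb{H}(Z)] - [\mathbb{H}(X, Y) + \mathbb{H}(X, Z) + \mathbb{H}(Y, Z)] + \mathbb{H}(X, Y, Z) \tag{5.130}$$

Contrast this to Equation (5.120).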
More generally, we have 


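$$\mathbb{I}(X_1; \ldots; X_D) = -\sum_{T \subseteq \{1, \ldots, D\}} (-1)^{|T|} \mathbb{H}(T) \tag{5.131}$$

For sets of size 1, 2 and 3 this expands as follows:

\begin{align}
I_1 &= H_1 \tag{5.132} \\
I_{12} &= H_1 + H_2 - H_{12} \tag{5.133} \\
I_{123} &= H_1 + H_2 + H_3 - H_{12} - H_{13} - H_{23} + H_{123} \tag{5.134}
\end{align}

We can use the Möbius inversion formula to derive the following dual relationship:

$$\mathbb{H}(S) = -\sum_{T \subseteq S} (-1)^{|T|} \mathbb{I}(T) \tag{5.135}$$

for sets of variables $S$.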
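Equations (5.131) and (5.135) can be implemented by enumerating subsets. In this sketch the empty set is skipped, since it contributes $\mathbb{H}(\emptyset) = 0$. On the XOR distribution it recovers $I_{123} = -1$ bit and, through the Möbius dual, the joint entropy of 2 bits.

```python
import math
from itertools import combinations, product


def entropy_bits(dist):
    return -sum(pr * math.log2(pr) for pr in dist.values() if pr > 0)


def marginalize(joint, keep):
    """Sum out every coordinate not listed in keep."""
    out = {}
    for key, pr in joint.items():
        sub = tuple(key[i] for i in keep)
        out[sub] = out.get(sub, 0.0) + pr
    return out


def mmi(joint, variables):
    """Eq. (5.131): I(X_S) = -sum over nonempty T subset of S of (-1)^|T| H(X_T)."""
    total = 0.0
    for r in range(1, len(variables) + 1):
        for subset in combinations(variables, r):
            total -= (-1) ** r * entropy_bits(marginalize(joint, subset))
    return total


def joint_entropy_from_mmi(joint, variables):
    """Eq. (5.135): H(X_S) = -sum over nonempty T subset of S of (-1)^|T| I(X_T)."""
    total = 0.0
    for r in range(1, len(variables) + 1):
        for subset in combinations(variables, r):
            total -= (-1) ** r * mmi(joint, subset)
    return total


xor = {(x, y, x ^ y): 0.25 for x, y in product((0, 1), repeat=2)}
print(mmi(xor, (0, 1, 2)))                     # -1.0, i.e. I_123 in Eq. (5.134)
print(joint_entropy_from_mmi(xor, (0, 1, 2)))  # 2.0 = H(X, Y, Z)
```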
Using the chain rule for entropy, we can also derive the following expression for the 3-way MMI: 


$$\mathbb{I}(X; Y; Z) = \mathbb{H}(Z) - \mathbb{H}(Z|X) - \mathbb{H}(Z|Y) + \mathbb{H}(Z|X, Y) \tag{5.136}$$


5.3.6 Variational bounds on mutual information 


In this section, we discuss methods for computing upper and lower bounds on MI that use variational 
approximations to the intractable distributions. This can be useful for representation learning 
(Chapter 32). This approach was first suggested in [BA03]. For a more detailed overview of
variational bounds on mutual information, see Poole et al. [Poo+19b]. 


5.3.6.1 Upper bound 


Suppose that the joint p(x, y) is intractable to evaluate, but that we can sample from p(x) and 
evaluate the conditional distribution p(y|x). Furthermore, suppose we approximate p(y) by q(y). 
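Then we can compute an upper bound on the MI as follows:

\begin{align}
\mathbb{I}(x; y) &= \mathbb{E}_{p(x, y)}\left[\log \frac{p(y|x)\, q(y)}{p(y)\, q(y)}\right] \tag{5.137} \\
&= \mathbb{E}_{p(x, y)}\left[\log \frac{p(y|x)}{q(y)}\right] - D_{\mathbb{KL}}(p(y) \,\|\, q(y)) \tag{5.138} \\
&\leq \mathbb{E}_{p(x)}\left[\mathbb{E}_{p(y|x)}\left[\log \frac{p(y|x)}{q(y)}\right]\right] \tag{5.139} \\
&= \mathbb{E}_{p(x)}\left[D_{\mathbb{KL}}(p(y|x) \,\|\, q(y))\right] \tag{5.140}
\end{align}

This bound is tight if $q(y) = p(y)$.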


What’s happening here is that $\mathbb{I}(Y; X) = \mathbb{H}(Y) - \mathbb{H}(Y|X)$, and we’ve assumed we know $p(y|x)$ and so can estimate $\mathbb{H}(Y|X)$ well. While we don’t know $\mathbb{H}(Y)$, we can upper bound it using some model $q(y)$. Our model can never do better than $p(y)$ itself (the non-negativity of KL), so our entropy estimate errs too large, and hence our MI estimate will be an upper bound.
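The upper bound (5.140) is easy to evaluate exactly on a toy discrete problem. In this sketch $p(x)$ and $p(y|x)$ are hypothetical, and $q(y)$ ranges over a few guesses. The bound is tight at $q = p(y)$, and its slack is exactly $D_{\mathbb{KL}}(p(y) \| q(y))$, matching Equation (5.138).

```python
import math


def kl_bits(p, q):
    return sum(pi * math.log2(pi / qi) for pi, qi in zip(p, q) if pi > 0)


# Hypothetical toy problem: x uniform on {0, 1}; p(y | x) is a noisy copy of x.
p_x = [0.5, 0.5]
p_y_given_x = [[0.9, 0.1], [0.2, 0.8]]  # rows: x, columns: y
p_y = [sum(p_x[x] * p_y_given_x[x][y] for x in range(2)) for y in range(2)]


def mi_upper_bound(q_y):
    """Eq. (5.140): E_{p(x)}[KL(p(y|x) || q(y))] >= I(x; y) for any q(y)."""
    return sum(p_x[x] * kl_bits(p_y_given_x[x], q_y) for x in range(2))


true_mi = mi_upper_bound(p_y)  # the bound is tight when q(y) = p(y)
for q_y in ([0.5, 0.5], [0.8, 0.2], p_y):
    print(q_y, mi_upper_bound(q_y), ">=", true_mi)
```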


5.3.6.2 BA lower bound 


Suppose that the joint $p(x, y)$ is intractable to evaluate, but that we can evaluate $p(x)$. Furthermore, suppose we approximate $p(x|y)$ by $q(x|y)$. Then we can derive the following variational lower bound on the mutual information:

\begin{align}
\mathbb{I}(x; y) &= \mathbb{E}_{p(x, y)}\left[\log \frac{p(x|y)}{p(x)}\right] \tag{5.141} \\
&= \mathbb{E}_{p(x, y)}\left[\log \frac{q(x|y)}{p(x)}\right] + \mathbb{E}_{p(y)}\left[D_{\mathbb{KL}}(p(x|y) \,\|\, q(x|y))\right] \tag{5.142} \\
&\geq \mathbb{E}_{p(x, y)}[\log q(x|y)] + h(x) \tag{5.143}
\end{align}

where $h(x)$ is the differential entropy of $x$. This is called the BA lower bound, after the authors Barber and Agakov [BA03].

5.3.6.3 NWJ lower bound

The BA lower bound requires a tractable normalized distribution $q(x|y)$ that we can evaluate pointwise. If we reparameterize this distribution in a clever way, we can generate a lower bound that does not require a normalized distribution. Let's write:

$$q(x|y) = \frac{p(x)\, e^{f(x, y)}}{Z(y)} \tag{5.144}$$

with $Z(y) = \mathbb{E}_{p(x)}[e^{f(x, y)}]$ the normalization constant or partition function. Plugging this into the BA lower bound above we obtain:


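\begin{align}
\mathbb{E}_{p(x, y)}\left[\log \frac{p(x)\, e^{f(x, y)}}{p(x)\, Z(y)}\right] &= \mathbb{E}_{p(x, y)}[f(x, y)] - \mathbb{E}_{p(y)}[\log Z(y)] \tag{5.145} \\
&= \mathbb{E}_{p(x, y)}[f(x, y)] - \mathbb{E}_{p(y)}\left[\log \mathbb{E}_{p(x)}\left[e^{f(x, y)}\right]\right] \tag{5.146} \\
&\triangleq I_{DV}(X; Y) \tag{5.147}
\end{align}

This is the Donsker-Varadhan lower bound [DV75].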
We can construct a more tractable version of this by using the fact that the log function can be upper bounded by a straight line using

$$\log x \leq \frac{x}{a} + \log a - 1 \tag{5.148}$$

If we set $a = e$, we get

$$\mathbb{I}(X; Y) \geq \mathbb{E}_{p(x, y)}[f(x, y)] - e^{-1}\, \mathbb{E}_{p(y)}[Z(y)] \triangleq I_{NWJ}(X; Y) \tag{5.149}$$


This is called the NWJ lower bound (after the authors Nguyen, Wainwright, and Jordan [NWJ10a]), or the f-GAN KL [NCT16a], or the MINE-f score [Bel+18].

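The BA, DV and NWJ bounds can be compared exactly on a small discrete problem where the true MI is known. This sketch assumes a hypothetical joint over $\{0, 1, 2\}^2$ in which $y$ is a noisy copy of $x$, plus an untuned critic and a uniform $q(x|y)$. With a poor critic each bound sits below the true MI (and NWJ below DV); with the critic $f = \log \frac{p(x,y)}{p(x)p(y)}$ (plus 1 for NWJ) both become tight. Natural logs are used throughout.

```python
import math

# Hypothetical joint p(x, y) over x, y in {0, 1, 2}: y is a noisy copy of x.
X = Y = range(3)
p_xy = {(x, y): (0.8 if x == y else 0.1) / 3 for x in X for y in Y}
p_x = {x: sum(p_xy[x, y] for y in Y) for x in X}
p_y = {y: sum(p_xy[x, y] for x in X) for y in Y}
true_mi = sum(p * math.log(p / (p_x[x] * p_y[y])) for (x, y), p in p_xy.items())


def ba_bound(q):
    """Eq. (5.143): E_{p(x,y)}[log q(x|y)] + H(X), with q[(x, y)] = q(x | y)."""
    h_x = -sum(p * math.log(p) for p in p_x.values())
    return sum(p * math.log(q[x, y]) for (x, y), p in p_xy.items()) + h_x


def Z(f, y):
    """Z(y) = E_{p(x)}[exp f(x, y)]."""
    return sum(p_x[x] * math.exp(f(x, y)) for x in X)


def dv_bound(f):
    """Eqs. (5.145)-(5.147): E_{p(x,y)}[f] - E_{p(y)}[log Z(y)]."""
    return (sum(p * f(x, y) for (x, y), p in p_xy.items())
            - sum(p_y[y] * math.log(Z(f, y)) for y in Y))


def nwj_bound(f):
    """Eq. (5.149): E_{p(x,y)}[f] - e^{-1} E_{p(y)}[Z(y)]."""
    return (sum(p * f(x, y) for (x, y), p in p_xy.items())
            - sum(p_y[y] * Z(f, y) for y in Y) / math.e)


def log_ratio(x, y):
    return math.log(p_xy[x, y] / (p_x[x] * p_y[y]))


def crude(x, y):
    return 1.0 if x == y else 0.0  # a hypothetical, untuned critic


q_uniform = {(x, y): 1 / 3 for x in X for y in Y}  # a poor q(x | y)

print(true_mi)
print(ba_bound(q_uniform), dv_bound(crude), nwj_bound(crude))           # all below true_mi
print(dv_bound(log_ratio), nwj_bound(lambda x, y: 1 + log_ratio(x, y)))  # both equal true_mi
```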
5.3.6.4 InfoNCE lower bound 


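If we instead explore a multi-sample extension to the DV bound above, we can generate the following lower bound (see [Poo+19b] for the derivation):

\begin{align}
I_{NCE} &= \mathbb{E}\left[\frac{1}{K} \sum_{i=1}^{K} \log \frac{e^{f(x_i, y_i)}}{\frac{1}{K} \sum_{j=1}^{K} e^{f(x_i, y_j)}}\right] \tag{5.150} \\
&= \log K - \mathbb{E}\left[\frac{1}{K} \sum_{i=1}^{K} \log\left(1 + \sum_{j \neq i} e^{f(x_i, y_j) - f(x_i, y_i)}\right)\right] \tag{5.151}
\end{align}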


where the expectation is over paired samples from the joint p(X, Y). The quantity in Equation (5.151) 
is called the InfoNCE estimate, and was proposed in [OLV18a; Hen+19a]. (NCE stands for “noise 
contrastive estimation”, and is discussed in Section 24.4.) 

The intuition here is that mutual information is a divergence between the joint p(x, y) and the 
product of the marginals, p(x)p(y). In other words, mutual information measures how different sampling (x, y) pairs jointly is from sampling xs and ys independently. The InfoNCE bound
provides a lower bound on the true mutual information by attempting to train a model to distinguish 
between these two situations. 

Although this is a valid lower bound, we may need to use a large batch size K to estimate the 
MI if the MI is large, since $I_{NCE} \leq \log K$. (Recently [SE20a] proposed to use a multi-label classifier,
rather than a multi-class classifier, to overcome this limitation.) 
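The $\log K$ ceiling is easy to observe in simulation. This is a sketch under stated assumptions: a hypothetical 8-symbol noisy channel whose true MI is about 1.56 nats, the critic fixed to $f(x, y) = \log p(y|x)$ rather than learned, and a Monte Carlo average of Equation (5.150) over batches of size $K$ drawn with a fixed seed. With $K = 4$ the estimate cannot exceed $\log 4 \approx 1.39$; with $K = 64$ it moves much closer to the true MI.

```python
import math
import random

# Hypothetical channel: x is uniform on M symbols; y = x with probability P_SAME,
# otherwise y is one of the other M - 1 symbols chosen uniformly.
M, P_SAME = 8, 0.9
LOG_SAME = math.log(P_SAME)
LOG_DIFF = math.log((1 - P_SAME) / (M - 1))


def sample_pair(rng):
    x = rng.randrange(M)
    if rng.random() < P_SAME:
        return x, x
    return x, rng.choice([v for v in range(M) if v != x])


def critic(x, y):
    """Assumed critic f(x, y) = log p(y | x); in practice f is a learned model."""
    return LOG_SAME if x == y else LOG_DIFF


def logsumexp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def infonce(rng, K, n_batches=100):
    """Monte Carlo estimate of Eq. (5.150), in nats, from batches of K joint samples."""
    total = 0.0
    for _ in range(n_batches):
        pairs = [sample_pair(rng) for _ in range(K)]
        for xi, yi in pairs:
            scores = [critic(xi, yj) for _, yj in pairs]  # f(x_i, y_j), j = 1..K
            total += critic(xi, yi) - logsumexp(scores) + math.log(K)
    return total / (n_batches * K)


# I(X;Y) = log M - H(Y|X) for this channel, in nats.
true_mi = math.log(M) + P_SAME * LOG_SAME + (1 - P_SAME) * LOG_DIFF

rng = random.Random(0)
for K in (4, 64):
    print(K, infonce(rng, K), min(math.log(K), true_mi))
```

Each per-sample term is at most $\log K$ because the denominator includes the positive pair itself, so the ceiling holds for every batch, not only on average.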
Figure 5.8: Subset of size 16242 x 100 of the 20-newsgroups data. We only show 1000 rows, for clarity. Each row is a document (represented as a bag-of-words bit vector), each column is a word. The red lines separate the 4 classes, which are (in descending order) comp, rec, sci, talk (these are the titles of USENET groups). We can see that there are subsets of words whose presence or absence is indicative of the class. The data is available from http://cs.nyu.edu/~roweis/data.html. Generated by newsgroups_visualize.ipynb.


Figure 5.9: Part of a relevance network constructed from the 20 newsgroup data shown in Figure 5.8. We show edges whose mutual information is greater than or equal to 20% of the maximum pairwise MI. For clarity, the graph has been cropped, so we only show a subset of the nodes and edges. Generated by relevance_network_newsgroup_demo.ipynb.


5.3.7 Relevance networks


If we have a set of related variables, we can compute a relevance network, in which we add an $i - j$ edge if the pairwise mutual information $\mathbb{I}(X_i; X_j)$ is above some threshold. In the Gaussian case, $\mathbb{I}(X_i; X_j) = -\frac{1}{2} \log(1 - \rho_{ij}^2)$, where $\rho_{ij}$ is the correlation coefficient, and the resulting graph is called a covariance graph (Section 4.5.5.1). However, we can also apply it to discrete random variables.


Relevance networks are quite popular in systems biology [Mar+06], where they are used to visualize the interaction between genes. But they can also be applied to other kinds of datasets. For example, Figure 5.9 visualizes the MI between words in the 20 newsgroup dataset shown in Figure 5.8. The results seem intuitively reasonable.
However, relevance networks suffer from a major problem: the graphs are usually very dense, since most variables are dependent on most other variables, even after thresholding the MIs. For example, suppose $X_1$ directly influences $X_2$, which directly influences $X_3$ (e.g., these form components of a signalling cascade, $X_1 - X_2 - X_3$). Then $X_1$ has non-zero MI with $X_3$ (and vice versa), so there will be a $1 - 3$ edge as well as the $1 - 2$ and $2 - 3$ edges; thus the graph may be fully connected, depending on the threshold.
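The density problem can be reproduced with the Gaussian MI formula. This sketch assumes a hypothetical standardized linear-Gaussian chain $X_1 \to X_2 \to X_3$, for which $\rho_{13} = \rho_{12}\rho_{23}$. The indirect $1 - 3$ pair still clears a low MI threshold and only disappears once the threshold is raised. Nodes are 0-indexed in the code.

```python
import math
from itertools import combinations


def gaussian_mi(rho):
    """I(X_i; X_j) = -1/2 log(1 - rho_ij^2) for a jointly Gaussian pair (nats)."""
    return -0.5 * math.log(1 - rho ** 2)


def relevance_network(corr, threshold):
    """Edges (i, j) whose pairwise MI exceeds the threshold."""
    n = len(corr)
    return [(i, j) for i, j in combinations(range(n), 2)
            if gaussian_mi(corr[i][j]) > threshold]


# Hypothetical chain X1 -> X2 -> X3 with unit variances: rho_13 = rho_12 * rho_23.
r12, r23 = 0.8, 0.8
corr = [[1.0, r12, r12 * r23],
        [r12, 1.0, r23],
        [r12 * r23, r23, 1.0]]
print(relevance_network(corr, threshold=0.1))  # [(0, 1), (0, 2), (1, 2)]
print(relevance_network(corr, threshold=0.3))  # [(0, 1), (1, 2)]
```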

A solution to this is to learn a probabilistic graphical model, which represents conditional independence, rather than dependence. In the chain example, there will not be a $1 - 3$ edge, since $X_1 \perp X_3 | X_2$. Consequently graphical models are usually much sparser than relevance networks. See Chapter 30 for details.


5.4 Data compression (source coding) 


Data compression, also known as source coding, is at the heart of information theory. It is also 
related to probabilistic machine learning. The reason for this is as follows: if we can model the 
probability of different kinds of data samples, then we can assign short code words to the most 
frequently occurring ones, reserving longer encodings for the less frequent ones. This is similar to
the situation in natural language, where common words (such as “a”, “the”, “and”) are generally 
much shorter than rare words. Thus the ability to compress data requires an ability to discover 
the underlying patterns, and their relative frequencies, in the data. This has led Marcus Hutter 
to propose that compression be used as an objective way to measure performance towards general 
purpose AI. More precisely, he is offering 50,000 Euros to anyone who can compress the first 100MB 
of (English) Wikipedia better than some baseline. This is known as the Hutter prize.¹

In this section, we give a brief summary of some of the key ideas in data compression. For details, 
see e.g., [Mac03; CT06; YMT22]. 


5.4.1 Lossless compression 


Discrete data, such as natural language, can always be compressed in such a way that we can uniquely 
recover the original data. This is called lossless compression. 

Claude Shannon proved that the expected number of bits needed to losslessly encode some data coming from distribution p is at least H(p). This is known as the source coding theorem. Achieving this lower bound requires coming up with good probability models, as well as good ways to design codes based on those models. Because of the non-negativity of the KL divergence, H_ce(p, q) ≥ H(p), so if we use any model q other than the true model p to compress the data, it will take some excess bits. The number of excess bits is exactly D_KL(p ‖ q).
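The following sketch checks the decomposition H_ce(p, q) = H(p) + D_KL(p ‖ q) numerically. It uses a small, made-up source distribution p and a uniform model q, and measures everything in bits (base-2 logarithms). The helper names are our own.

```python
import numpy as np

def entropy_bits(p):
    """H(p) in bits; terms with p = 0 contribute nothing."""
    p = np.asarray(p, dtype=float)
    nz = p > 0
    return -np.sum(p[nz] * np.log2(p[nz]))

def cross_entropy_bits(p, q):
    """H_ce(p, q): expected code length in bits when coding samples of p with a code built for q."""
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    nz = p > 0
    return -np.sum(p[nz] * np.log2(q[nz]))

def kl_bits(p, q):
    """Excess bits from using the wrong model: D_KL(p || q) = H_ce(p, q) - H(p)."""
    return cross_entropy_bits(p, q) - entropy_bits(p)

p = [0.5, 0.25, 0.125, 0.125]   # illustrative true source distribution
q = [0.25, 0.25, 0.25, 0.25]    # mismatched model used to design the code
print(entropy_bits(p), cross_entropy_bits(p, q), kl_bits(p, q))  # 1.75 2.0 0.25
```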

Common techniques for realizing lossless codes include Huffman coding, arithmetic coding and 
asymmetric numeral systems [Dud13]. The input to these algorithms is a probability distribution 
over strings (which is where ML comes in). This distribution is often represented using a latent 
variable model (see e.g., [TBB19; KAH19]). 
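To make this concrete, here is a minimal Huffman coder. It is a sketch that works on a small symbol alphabet rather than on whole strings, and the probabilities are illustrative. For this dyadic distribution the expected code length equals the entropy, 1.75 bits per symbol, so the source coding bound is met exactly.

```python
import heapq
from itertools import count

def huffman_code(probs):
    """Build a binary prefix code for a dict {symbol: probability}."""
    tiebreak = count()  # avoids comparing dicts when probabilities tie
    heap = [(p, next(tiebreak), {s: ""}) for s, p in probs.items()]
    heapq.heapify(heap)
    if len(heap) == 1:  # degenerate single-symbol alphabet
        return {s: "0" for s in probs}
    while len(heap) > 1:
        p0, _, c0 = heapq.heappop(heap)  # the two least probable subtrees
        p1, _, c1 = heapq.heappop(heap)
        merged = {s: "0" + w for s, w in c0.items()}
        merged.update({s: "1" + w for s, w in c1.items()})
        heapq.heappush(heap, (p0 + p1, next(tiebreak), merged))
    return heap[0][2]

probs = {"a": 0.5, "b": 0.25, "c": 0.125, "d": 0.125}
code = huffman_code(probs)
avg_len = sum(probs[s] * len(w) for s, w in code.items())
print(code)
print(avg_len)  # 1.75 bits per symbol, equal to H(p) here
```

Frequent symbols get short codewords and rare ones get long codewords. For non-dyadic distributions, Huffman codes can lose up to about one bit per symbol relative to H(p), which is one motivation for arithmetic coding and ANS.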


5.4.2 Lossy compression and the rate-distortion tradeoff 


To encode real-valued signals, such as images and sound, as a digital signal, we first have to quantize 
the signal into a sequence of symbols. A simple way to do this is to use vector quantization. We can 
then compress this discrete sequence of symbols using lossless coding methods. However, when we 
uncompress, we lose some information. Hence this approach is called lossy compression. 
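A minimal sketch of vector quantization follows. It assumes a codebook learned with plain k-means (Lloyd's algorithm) on toy 2d Gaussian data; the data, codebook sizes and iteration count are our own illustrative choices. Sending the index of a codeword costs log₂ K bits per vector before any lossless entropy coding, and larger codebooks trade a higher rate for lower squared-error distortion.

```python
import numpy as np

def vq_encode(x, codebook):
    """Index of the nearest codeword for each row of x (squared Euclidean distance)."""
    d2 = ((x[:, None, :] - codebook[None, :, :]) ** 2).sum(axis=-1)
    return d2.argmin(axis=1)

def fit_codebook(x, k, iters=20, seed=0):
    """Plain k-means (Lloyd's algorithm) to learn a k-entry codebook."""
    rng = np.random.default_rng(seed)
    codebook = x[rng.choice(len(x), size=k, replace=False)].copy()
    for _ in range(iters):
        idx = vq_encode(x, codebook)
        for j in range(k):
            members = x[idx == j]
            if len(members) > 0:  # leave empty clusters where they are
                codebook[j] = members.mean(axis=0)
    return codebook

rng = np.random.default_rng(1)
x = rng.normal(size=(1000, 2))  # toy real-valued 2d "signal"
results = {}
for k in [2, 4, 16]:
    cb = fit_codebook(x, k)
    x_hat = cb[vq_encode(x, cb)]  # decoding is a codebook lookup
    distortion = np.mean(np.sum((x - x_hat) ** 2, axis=1))
    results[k] = (np.log2(k), distortion)  # (rate in bits per vector, MSE)
    print(f"K={k:2d}  rate={np.log2(k):.0f} bits  distortion={distortion:.3f}")
```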


1. For details, see http://prize.hutter1.net.
In this section, we quantify this tradeoff between the size of the representation (number of symbols we use) and the resulting error. We will use the terminology of the variational information bottleneck discussed in Section 5.6.2 (except here we are in the unsupervised setting). In particular, we assume we have a stochastic encoder e(z|x), a stochastic decoder d(x|z) and a prior marginal m(z).

We define the distortion of an encoder-decoder pair (as in Section 5.6.2) as follows: 


$$ D = -\int dx\, p(x) \int dz\, e(z|x) \log d(x|z) \tag{5.152} $$


If the decoder is a deterministic model plus Gaussian noise, d(x|z) = N(x | f_d(z), σ²), and the encoder is deterministic, e(z|x) = δ(z − f_e(x)), then, up to an additive constant, this becomes

$$ D = \frac{1}{2\sigma^2} \mathbb{E}_{p(x)}\left[ \| f_d(f_e(x)) - x \|_2^2 \right] \tag{5.153} $$

This is just the expected reconstruction error that occurs if we (deterministically) encode and then decode the data using f_e and f_d.
We define the rate of our model as follows: 


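$$ R = \int dx\, p(x) \int dz\, e(z|x) \log \frac{e(z|x)}{m(z)} \tag{5.154} $$

$$ = \mathbb{E}_{p(x)}\left[ D_{\mathrm{KL}}(e(z|x) \,\|\, m(z)) \right] \tag{5.155} $$

$$ = \int dx\, dz\, p(x, z) \log \frac{p(x, z)}{p(x)\, m(z)} \geq I(x; z) \tag{5.156} $$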


This is just the average KL between our encoding distribution and the marginal. If we use m(z) to design an optimal code, then the rate is the excess number of bits we need to pay to encode our data using m(z) rather than the true aggregate posterior p(z) = ∫dx p(x)e(z|x).

There is a fundamental tradeoff between the rate and distortion. To see why, note that a trivial encoding scheme would set e(z|x) = δ(z − x), which simply uses x as its own best representation. This would incur 0 distortion (and hence maximize the likelihood), but it would incur a high rate, since each e(z|x) distribution would be unique, and far from m(z). In other words, there would be no compression. Conversely, if e(z|x) = δ(z − 0), the encoder would ignore the input. In this case, the rate would be 0, but the distortion would be high.


We can characterize the tradeoff more precisely using the variational lower and upper bounds on the mutual information from Section 5.3.6. From that section, we know that

$$ H - D \leq I(x; z) \leq R \tag{5.157} $$


where H is the (differential) entropy 


$$ H = -\int dx\, p(x) \log p(x) \tag{5.158} $$


For discrete data, all probabilities are bounded above by 1, and hence H ≥ 0 and D ≥ 0. In addition, the rate is always non-negative, R ≥ 0, since it is the average of a KL divergence. (This is true for either discrete or continuous encodings z.) Consequently, we can plot the set of achievable values of R and D as shown in Figure 5.10. This is known as a rate distortion curve.
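To see Equation (5.157) in action, the sketch below computes H, D, R and the true I(x; z) for a small discrete example. The encoder, decoder and marginal were made up for illustration. It then replaces the decoder by the exact p(x|z) and the marginal by the aggregate posterior p(z), which makes both bounds tight.

```python
import numpy as np

def rate_distortion(p_x, enc, dec, m):
    """Discrete H, D, R and I(x; z), all in nats.

    enc[i, j] = e(z=j | x=i), dec[j, i] = d(x=i | z=j), m[j] = m(z=j).
    """
    joint = p_x[:, None] * enc                      # p(x, z) = p(x) e(z|x)
    p_z = joint.sum(axis=0)                         # aggregate posterior p(z)
    H = -np.sum(p_x * np.log(p_x))                  # entropy of the data, Eq. (5.158)
    D = -np.sum(joint * np.log(dec.T))              # distortion, Eq. (5.152)
    R = np.sum(joint * np.log(enc / m[None, :]))    # rate, Eq. (5.154)
    I = np.sum(joint * np.log(enc / p_z[None, :]))  # true mutual information
    return H, D, R, I

p_x = np.full(4, 0.25)
enc = np.array([[0.9, 0.1], [0.8, 0.2], [0.2, 0.8], [0.1, 0.9]])
dec = np.array([[0.4, 0.4, 0.1, 0.1], [0.1, 0.1, 0.4, 0.4]])
m = np.array([0.7, 0.3])
H, D, R, I = rate_distortion(p_x, enc, dec, m)
print(f"H - D = {H - D:.4f} <= I = {I:.4f} <= R = {R:.4f}")

# Using the exact posterior p(x|z) as decoder and p(z) as marginal makes both bounds tight.
joint = p_x[:, None] * enc
p_z = joint.sum(axis=0)
H2, D2, R2, I2 = rate_distortion(p_x, enc, (joint / p_z).T, p_z)
print(f"H - D = {H2 - D2:.4f}, I = {I2:.4f}, R = {R2:.4f}")
```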
Figure 5.10: Illustration of the rate-distortion tradeoff. See text for details. From Figure 1 of [Ale+18]. Used with kind permission of Alex Alemi.


The bottom horizontal line corresponds to the zero distortion setting, D = 0, in which we can perfectly encode and decode our data. This can be achieved by using the trivial encoder where e(z|x) = δ(z − x). Shannon's source coding theorem tells us that the minimum number of bits we need to use to encode data in this setting is the entropy of the data, so R ≥ H when D = 0. If we use a suboptimal marginal distribution m(z) for coding, we will increase the rate without affecting the distortion.

The left vertical line corresponds to the zero rate setting, R = 0, in which the latent code is independent of x. In this case, the decoder d(x|z) is independent of z. However, we can still learn a joint probability model p(x) which does not use latent variables, e.g., this could be an autoregressive model. The minimal distortion such a model could achieve is again the entropy of the data, D ≥ H.

The black diagonal line illustrates solutions that satisfy D = H − R, where the upper and lower bounds are tight. In practice, we cannot achieve points on the diagonal, since that requires the bounds to be tight, and therefore assumes our models e(z|x) and d(x|z) are perfect. This is called the "non-parametric limit". In the finite data setting, we will always incur additional error, so the RD plot will trace a curve which is shifted up, as shown in Figure 5.10.

We can generate different solutions along this curve by minimizing the following objective: 


$$ J = D + \beta R = \int dx\, p(x) \int dz\, e(z|x) \left[ -\log d(x|z) + \beta \log \frac{e(z|x)}{m(z)} \right] \tag{5.159} $$

If we set β = 1, and define q(z|x) = e(z|x), p(x|z) = d(x|z), and p(z) = m(z), this exactly matches the VAE objective in Section 21.2. To see this, note that the ELBO from Section 10.1.2 can be written as

$$ \mathcal{L} = -(D + R) = \mathbb{E}_{p(x)}\left[ \mathbb{E}_{e(z|x)}[\log d(x|z)] - \mathbb{E}_{e(z|x)}\left[ \log \frac{e(z|x)}{m(z)} \right] \right] \tag{5.160} $$

which we recognize as the expected reconstruction log-likelihood minus the KL term D_KL(e(z|x) ‖ m(z)).
If we allow β ≠ 1, we recover the β-VAE objective discussed in Section 21.3.1. Note, however, that the β-VAE model cannot distinguish between different solutions on the diagonal line, all of which have β = 1. This is because all such models have the same marginal likelihood (and hence same ELBO), although they differ radically in terms of whether they learn an interesting latent representation or not. Thus likelihood is not a sufficient metric for comparing the quality of unsupervised representation learning methods, as discussed in Section 21.3.1.
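Below is a minimal sketch of the objective J = D + βR. It makes three assumptions: a diagonal Gaussian encoder, a standard normal marginal m(z) = N(0, I) (so the rate has a closed form), and a hypothetical linear decoder with Gaussian noise. D is estimated with a single reparameterized sample, and all parameter values are placeholders rather than a trained model.

```python
import numpy as np

def kl_to_std_normal(mu, log_var):
    """Closed-form KL(N(mu, diag(exp(log_var))) || N(0, I)), summed over latent dims."""
    return 0.5 * np.sum(mu**2 + np.exp(log_var) - 1.0 - log_var, axis=-1)

def rd_objective(x, mu, log_var, decode, beta, sigma=1.0, seed=0):
    """Batch-averaged Monte Carlo estimate of J = D + beta * R (Eq. 5.159).

    Encoder e(z|x) = N(mu, diag(exp(log_var))), decoder d(x|z) = N(x | decode(z), sigma^2 I),
    marginal m(z) = N(0, I). D uses one reparameterized sample per example.
    """
    rng = np.random.default_rng(seed)
    z = mu + np.exp(0.5 * log_var) * rng.standard_normal(mu.shape)
    x_hat = decode(z)
    nll = (0.5 * np.sum((x - x_hat) ** 2, axis=-1) / sigma**2
           + 0.5 * x.shape[-1] * np.log(2 * np.pi * sigma**2))  # -log d(x|z)
    D = np.mean(nll)
    R = np.mean(kl_to_std_normal(mu, log_var))
    return D + beta * R, D, R

W = np.array([[1.0, 0.5, -0.3]])  # hypothetical 1d -> 3d linear decoder weights

def decode(z):
    return z @ W

x = np.array([[0.9, 0.4, -0.2], [-1.1, -0.6, 0.4]])
mu = np.array([[0.8], [-1.0]])        # hypothetical encoder means
log_var = np.array([[-2.0], [-2.0]])  # hypothetical encoder log-variances
for beta in [0.0, 1.0, 4.0]:
    J, D, R = rd_objective(x, mu, log_var, decode, beta)
    print(f"beta={beta}: J={J:.3f}  D={D:.3f}  R={R:.3f}")
```

With β = 1 this is the negative ELBO of Equation (5.160). Changing β only reweights the rate term, so it moves the optimum along the rate-distortion curve.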

For further discussion on the inherent conflict between rate, distortion and perception, see Blau 
and Michaeli [BM19]. For techniques for evaluating rate distortion curves for models see Huang, Cao, 
and Grosse [HCG20]. 


5.4.3 Bits back coding 


In the previous section we penalized the rate of our code using the average KL divergence, E_{p(x)}[R(x)], where

$$ R(x) \triangleq \int dz\, p(z|x) \log \frac{p(z|x)}{m(z)} = H_{ce}(p(z|x), m(z)) - H(p(z|x)) \tag{5.161} $$
The first term is the cross entropy, which is the expected number of bits we need to encode z; the second term is the entropy, which is the minimum number of bits. Thus we are penalizing the excess number of bits required to communicate the code to a receiver. How come we don't have to "pay for" the actual (total) number of bits we use, which is the cross entropy?

The reason is that we could in principle get the bits needed by the optimal code given back to us; this is called bits back coding [HC93; FH97]. The argument goes as follows. Imagine Alice is trying to (losslessly) communicate some data, such as an image x, to Bob. Before they went their separate ways, both Alice and Bob decided to share their encoder p(z|x), marginal m(z) and decoder distributions d(x|z). To communicate an image, Alice will use a two part code. First, she will sample a code z ∼ p(z|x) from her encoder, and communicate that to Bob over a channel designed to efficiently encode samples from the marginal m(z); this costs −log₂ m(z) bits. Next Alice will use her decoder d(x|z) to compute the residual error, and losslessly send that to Bob at the cost of −log₂ d(x|z) bits. The expected total number of bits required here is what we naively expected:

$$ \mathbb{E}_{p(z|x)}\left[ -\log_2 d(x|z) - \log_2 m(z) \right] = D + H_{ce}(p(z|x), m(z)) \tag{5.162} $$


We see that this is the distortion plus cross entropy, not distortion plus rate. So how do we get the bits back, to convert the cross entropy to a rate term?

The trick is that Bob actually receives more information than we suspected. Bob can use the code z and the residual error to perfectly reconstruct x. However, Bob also knows what specific code Alice sent, z, as well as what encoder she used, p(z|x). When Alice drew the sample code z ∼ p(z|x), she had to use some kind of entropy source in order to generate the random sample. Suppose she did it by picking words sequentially from a compressed copy of Moby Dick, in order to generate a stream of random bits. On Bob's end, he can reverse engineer all of the sampling bits, and thus recover the compressed copy of Moby Dick! Thus Alice can use the extra randomness in the choice of z to share more information.
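The accounting can be checked on a toy discrete example for a single image x. We chose the encoder, marginal and decoder probabilities (all powers of 2) only so that every code length is a whole number of bits. Subtracting the bits Bob recovers, which is the entropy of p(z|x), from the naive two-part cost of Equation (5.162) leaves exactly distortion plus rate.

```python
import numpy as np

def bits_back_costs(enc_x, m, dec_x):
    """Expected code lengths (bits) for sending one fixed x with a two-part code.

    enc_x[j] = p(z=j | x), m[j] = m(z=j), dec_x[j] = d(x | z=j).
    """
    naive = np.sum(enc_x * (-np.log2(m) - np.log2(dec_x)))  # Eq. (5.162)
    recovered = -np.sum(enc_x * np.log2(enc_x))              # entropy of p(z|x): bits back
    distortion = -np.sum(enc_x * np.log2(dec_x))
    rate = np.sum(enc_x * np.log2(enc_x / m))                # KL(p(z|x) || m(z))
    return naive, recovered, naive - recovered, distortion + rate

enc_x = np.array([0.5, 0.25, 0.25])
m = np.array([0.25, 0.25, 0.5])
dec_x = np.array([0.5, 0.125, 0.25])
naive, recovered, net, d_plus_r = bits_back_costs(enc_x, m, dec_x)
print(naive, recovered, net, d_plus_r)  # 3.5 1.5 2.0 2.0
```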


While in the original formulation the bits-back argument was largely theoretical, offering a thought experiment for why we should penalize our models with the KL instead of the cross entropy, several practical real-world algorithms have recently been developed that actually achieve the bits-back goal. These include [HHLMF18; AT20; TBB19; YBM20; HLA19].
5.5. ERROR-CORRECTING CODES (CHANNEL CODING)


[Figure 5.11 panels: (a) graphical model over sent bits x1, x2, x3 and received bits y1, y2, y3; (b) bar plot of p(x | y = (1,0,0)) over codewords]


Figure 5.11: (a) A simple error-correcting code DPGM. x_i are the sent bits, y_i are the received bits. x3 is an even parity check bit computed from x1 and x2. (b) Posterior over codewords given that y = (1,0,0); the probability of a bit flip is 0.2. Generated by error_correcting_code_demo.ipynb.


5.5 Error-correcting codes (channel coding) 


The idea behind error-correcting codes is to add redundancy to a signal x (which is the result of encoding the original data), such that when it is sent over to the receiver via a noisy transmission line (such as a cell phone connection), the receiver can recover from any corruptions that might occur to the signal. This is called channel coding.

In more detail, let x ∈ {0,1}^m be the source message, where m is called the block length. Let y be the result of sending x over a noisy channel. This is a corrupted version of the message. For example, each message bit may get flipped independently with probability α, in which case p(y|x) = ∏_{i=1}^m p(y_i|x_i), where p(y_i|x_i = 0) = [1 − α, α] and p(y_i|x_i = 1) = [α, 1 − α]. Alternatively, we may add Gaussian noise, so p(y_i|x_i = b) = N(y_i|μ_b, σ²). The receiver's goal is to infer the true message from the noisy observations, i.e., to compute argmax_x p(x|y).

A common way to increase the chance of being able to recover the original signal is to add parity 
check bits to it before sending it. These are deterministic functions of the original signal, which 
specify if the sum of the input bits is odd or even. This provides a form of redundancy, so that if 
one bit is corrupted, we can still infer its value, assuming the other bits are not flipped. (This is 
reasonable since we assume the bits are corrupted independently at random, so it is less likely that 
multiple bits are flipped than just one bit.) 

For example, suppose we have two original message bits, and we add one parity bit. This can 
be modeled using a directed graphical model as shown in Figure 5.11(a). This graph encodes the 
following joint probability distribution: 


$$ p(x, y) = p(x_1)\, p(x_2)\, p(x_3 | x_1, x_2) \prod_{i=1}^{3} p(y_i | x_i) \tag{5.163} $$


The priors p(x1) and p(x2) are uniform. The conditional term p(x3|x1, x2) is deterministic, and computes the parity of (x1, x2). In particular, we have p(x3 = 1|x1, x2) = 1 if the total number of 1s in the block x1:2 is odd. The likelihood terms p(y_i|x_i) represent a bit flipping noisy channel model, with noise level α = 0.2.
Suppose we observe y = (1,0,0). We know that this cannot be what the sender sent, since this violates the parity constraint (if x1 = 1 and x2 = 0, then we know x3 = 1). Instead, the 3 posterior modes for x are 000 (first bit was flipped), 110 (second bit was flipped), and 101 (third bit was flipped). The only other configuration with non-zero support in the posterior is 011, which corresponds to the much less likely hypothesis that three bits were flipped (see Figure 5.11(b)). All other hypotheses (001, 010, 100 and 111) are inconsistent with the deterministic method used to create codewords. (See Section 9.2.3.2 for further discussion of this point.)
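The posterior in Figure 5.11(b) can be reproduced by brute-force enumeration of all 2³ bit strings under the model of Equation (5.163). This is our own sketch, not the notebook that generated the figure, and the function name is hypothetical.

```python
import itertools

def codeword_posterior(y, alpha):
    """Posterior p(x | y) over all 3-bit strings x = (x1, x2, x3).

    Assumes uniform p(x1), p(x2), a deterministic even-parity bit x3 = x1 XOR x2,
    and independent bit flips with probability alpha.
    """
    post = {}
    for x1, x2, x3 in itertools.product([0, 1], repeat=3):
        prior = 0.25 * (x3 == (x1 ^ x2))  # p(x1) p(x2) p(x3 | x1, x2)
        lik = 1.0
        for xi, yi in zip((x1, x2, x3), y):
            lik *= alpha if xi != yi else 1 - alpha  # p(y_i | x_i)
        post[f"{x1}{x2}{x3}"] = prior * lik
    total = sum(post.values())
    return {k: v / total for k, v in post.items()}

post = codeword_posterior(y=(1, 0, 0), alpha=0.2)
for x, p in post.items():
    print(x, round(p, 4))
```

The three single-flip codewords each get probability 0.128/0.392 ≈ 0.327. The codeword 011 gets about 0.020, and every string that violates the parity constraint gets zero.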

In practice, we use more complex coding schemes that are more efficient, in the sense that they add fewer redundant bits to the message, but still guarantee that errors can be corrected. For details, see Section 9.3.8.


5.6 The information bottleneck 


In this section, we discuss discriminative models p(y|x) that use a stochastic bottleneck between the input x and the output y to prevent overfitting, and improve robustness and calibration.


5.6.1 Vanilla IB 


We say that z is a representation of x if z is a (possibly stochastic) function of x, and hence can be described by the conditional p(z|x). We say that a representation z of x is sufficient for task y if y ⊥ x | z, or equivalently, if I(z; y) = I(x; y), i.e., H(y|z) = H(y|x). We say that a representation is a minimal sufficient statistic if z is sufficient and there is no other z with smaller I(z; x) value. Thus we would like to find a representation z that maximizes I(z; y) while minimizing I(z; x). That is, we would like to optimize the following objective:


$$ \min \; \beta I(z; x) - I(z; y) \tag{5.164} $$

where β > 0, and we optimize wrt the distributions p(z|x) and p(y|z). This is called the information bottleneck principle [TPB99]. This generalizes the concept of minimal sufficient statistic to take into account that there is a tradeoff between sufficiency and minimality, which is captured by the Lagrange multiplier β > 0.
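For discrete variables, the IB objective of Equation (5.164) can be evaluated exactly. The sketch below uses a made-up joint p(x, y) in which y depends only on whether x ≥ 2. It compares three hand-picked encoders: one that copies x, one that keeps only the label-relevant part of x, and one that ignores x. For a small β the compressed encoder wins, while for a large β the constant encoder does.

```python
import numpy as np

def mutual_info(joint):
    """I(a; b) in nats from a 2d joint probability table."""
    pa = joint.sum(axis=1, keepdims=True)
    pb = joint.sum(axis=0, keepdims=True)
    nz = joint > 0  # use the convention 0 log 0 = 0
    return np.sum(joint[nz] * np.log(joint[nz] / (pa @ pb)[nz]))

def ib_objective(p_xy, enc, beta):
    """beta * I(z; x) - I(z; y) for a discrete encoder enc[x, z] = p(z | x).

    Assumes the Markov structure Z <- X -> Y, so p(z, y) = sum_x p(z | x) p(x, y).
    """
    p_x = p_xy.sum(axis=1)
    p_xz = p_x[:, None] * enc
    p_zy = enc.T @ p_xy
    return beta * mutual_info(p_xz) - mutual_info(p_zy)

# Toy joint: x uniform on {0, 1, 2, 3}, y = 1 if x >= 2 else 0.
p_xy = np.array([[0.25, 0.0], [0.25, 0.0], [0.0, 0.25], [0.0, 0.25]])
encoders = {
    "copy x": np.eye(4),
    "keep x >= 2": np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]),
    "ignore x": np.ones((4, 1)),
}
scores = {}
for beta in [0.5, 2.0]:
    for name, enc in encoders.items():
        scores[(beta, name)] = ib_objective(p_xy, enc, beta)
        print(f"beta={beta}  {name:12s}  objective={scores[(beta, name)]:.3f}")
```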


This principle is illustrated in Figure 5.12. We assume Z is a function of X, but is conditionally independent of Y given X, i.e., we assume the graphical model Z ← X → Y. This corresponds to the following joint distribution:

$$ p(x, y, z) = p(z|x)\, p(y|x)\, p(x) \tag{5.165} $$


Thus Z can capture any amount of information about X that it wants, but cannot contain information that is unique to Y, as illustrated in Figure 5.12a. The optimal representation only captures information about X that is useful for Y; to prevent us "wasting capacity" and fitting irrelevant details of the input, Z should also minimize information about X, as shown in Figure 5.12b.


If all the random variables are discrete, and z = e(x) is a deterministic function of x, then the algorithm of [TPB99] can be used to minimize the IB objective in Section 5.6. The objective can also be solved analytically if all variables are jointly Gaussian [Che+05] (the resulting method can be viewed as a form of supervised PCA). But in general, it is intractable to solve this problem exactly. We discuss a tractable approximation in Section 5.6.2. (More details can be found in e.g., [SZ22].)
5.6. THE INFORMATION BOTTLENECK

Figure 5.12: Information diagrams for information bottleneck. (a) Z can contain any amount of information about X (whether it is useful for predicting Y or not), but it cannot contain information about Y that is not shared with X. (b) The optimal representation for Z maximizes I(Z; Y) and minimizes I(Z; X). Used with kind permission of Katie Everett.


5.6.2 Variational IB 


In this section, we derive a variational upper bound on Equation (5.164), leveraging ideas from 
Section 5.3.6. This is called the variational IB or VIB method [Ale+16]. The key trick will be to 
use the non-negativity of the KL divergence to write 


$$ \int dx\, p(x) \log p(x) \geq \int dx\, p(x) \log q(x) \tag{5.166} $$


for any distribution q. (Note that both p and q may be conditioned on other variables.) 

To explain the method in more detail, let us define the following notation. Let e(z|x) = p(z|x) represent the encoder, b(z|y) ≈ p(z|y) represent the backwards encoder, d(y|z) ≈ p(y|z) represent the classifier (decoder), and m(z) ≈ p(z) represent the marginal. (Note that we get to choose p(z|x), but the other distributions are derived by approximations of the corresponding marginals and conditionals of the exact joint p(x, y, z).) Also, let ⟨·⟩ represent expectations wrt the relevant terms from the p(x, y, z) joint.

With this notation, we can derive a lower bound on I(z; y) as follows: 


$$ I(z; y) = \int dy\, dz\, p(y, z) \log \frac{p(y, z)}{p(y)\, p(z)} \tag{5.167} $$

$$ = \int dy\, dz\, p(y, z) \log p(y|z) - \int dy\, dz\, p(y, z) \log p(y) \tag{5.168} $$

$$ = \int dy\, dz\, p(z)\, p(y|z) \log p(y|z) + \text{const} \tag{5.169} $$

$$ \geq \int dy\, dz\, p(y, z) \log d(y|z) \tag{5.170} $$

$$ = \langle \log d(y|z) \rangle \tag{5.171} $$

where we exploited the fact that H(p(y)) is a constant that is independent of our representation.
Note that we can approximate the expectations by sampling from

$$ p(y, z) = \int dx\, p(x)\, p(y|x)\, e(z|x) = \int dx\, p(x, y)\, e(z|x) \tag{5.172} $$

This is just the empirical distribution "pushed through" the encoder.
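Similarly, we can derive an upper bound on I(z; x) as follows:

$$ I(z; x) = \int dz\, dx\, p(x, z) \log \frac{p(z, x)}{p(x)\, p(z)} \tag{5.173} $$

$$ = \int dz\, dx\, p(x, z) \log p(z|x) - \int dz\, p(z) \log p(z) \tag{5.174} $$

$$ \leq \int dz\, dx\, p(x, z) \log p(z|x) - \int dz\, p(z) \log m(z) \tag{5.175} $$

$$ = \int dz\, dx\, p(x, z) \log \frac{e(z|x)}{m(z)} \tag{5.176} $$

$$ = \langle \log e(z|x) \rangle - \langle \log m(z) \rangle \tag{5.177} $$

Note that we can approximate the expectations by sampling from p(x, z) = p(x)e(z|x). Putting it all together, we get the following upper bound on the IB objective:

$$ \beta I(x; z) - I(z; y) \leq \beta \left( \langle \log e(z|x) \rangle - \langle \log m(z) \rangle \right) - \langle \log d(y|z) \rangle \tag{5.178} $$

Thus the VIB objective is

$$ \mathcal{L}_{\mathrm{VIB}} = \beta\, \mathbb{E}_{p_{\mathcal{D}}(x) e(z|x)}\left[ \log e(z|x) - \log m(z) \right] - \mathbb{E}_{p_{\mathcal{D}}(x, y) e(z|x)}\left[ \log d(y|z) \right] \tag{5.179} $$

$$ = -\mathbb{E}_{p_{\mathcal{D}}(x, y) e(z|x)}\left[ \log d(y|z) \right] + \beta\, \mathbb{E}_{p_{\mathcal{D}}(x)}\left[ D_{\mathrm{KL}}(e(z|x) \,\|\, m(z)) \right] \tag{5.180} $$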


We can now take stochastic gradients of this objective and minimize it (wrt the parameters of the encoder, decoder and marginal) using SGD. (We assume the distributions are reparameterizable, as discussed in Section 6.5.4.) For the encoder e(z|x), we often use a conditional Gaussian, and for the decoder d(y|z), we often use a softmax classifier. For the marginal, m(z), we should use a flexible model, such as a mixture of Gaussians, since it needs to approximate the aggregated posterior p(z) = ∫dx p(x)e(z|x), which is a mixture of N Gaussians (assuming p(x) is an empirical distribution with N samples, and e(z|x) is a Gaussian).
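Below is a minimal NumPy sketch of a single-sample estimate of the VIB objective (5.180). It assumes a diagonal Gaussian encoder and a softmax classifier. For the marginal it uses a standard normal m(z) = N(0, I), which is a simplification the paragraph above advises against, chosen here only so that the KL term has a closed form. In practice one would compute gradients with an autodiff library; the encoder outputs and classifier weights below are hypothetical placeholders.

```python
import numpy as np

def log_softmax(logits):
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    return logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))

def vib_loss(mu, log_var, labels, W, bias, beta, rng):
    """One-sample, batch-averaged estimate of the VIB objective (Eq. 5.180).

    Encoder e(z|x) = N(mu, diag(exp(log_var))); classifier d(y|z) = softmax(z @ W + bias);
    marginal m(z) = N(0, I), a simplifying assumption that gives a closed-form KL.
    """
    z = mu + np.exp(0.5 * log_var) * rng.standard_normal(mu.shape)  # reparameterized sample
    log_probs = log_softmax(z @ W + bias)
    nll = -log_probs[np.arange(len(labels)), labels]  # -log d(y|z) at the observed label
    kl = 0.5 * np.sum(mu**2 + np.exp(log_var) - 1.0 - log_var, axis=-1)
    return np.mean(nll + beta * kl), np.mean(nll), np.mean(kl)

mu = np.array([[2.0, 0.0], [-2.0, 0.0], [0.0, 2.0]])  # hypothetical 2d encoder means
log_var = np.full_like(mu, -1.0)                       # hypothetical log-variances
labels = np.array([0, 1, 2])
W = np.array([[2.0, -2.0, 0.0], [0.0, 0.0, 2.0]])     # hypothetical classifier weights
bias = np.zeros(3)
loss, nll, kl = vib_loss(mu, log_var, labels, W, bias, beta=1e-3,
                         rng=np.random.default_rng(0))
print(f"loss={loss:.4f}  nll={nll:.4f}  kl={kl:.4f}")
```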


We illustrate this in Figure 5.13, where we fit an MLP model to MNIST. We use a 2d bottleneck layer before passing to the softmax. In panel a, we show the embedding learned by a deterministic encoder. We see that each image gets mapped to a point, and there is little overlap between classes, or between instances. In panels b-c, we show the embedding learned by a stochastic encoder. Each image gets mapped to a Gaussian distribution; we show the mean and the covariance separately. The classes are still well separated, but individual instances of a class are no longer distinguishable, since such information is not relevant for prediction purposes.


5.6.3 Conditional entropy bottleneck 


The IB tries to maximize I(Z; Y) while minimizing I(Z; X). We can write this objective as

$$ \min \; I(x; z) - \lambda I(y; z) \tag{5.181} $$

Draft of "Probabilistic Machine Learning: Advanced Topics". August 15, 2022


Figure 5.13: 2d embeddings of MNIST digits created by an MLP classifier. (a) Deterministic model. (b-c) VIB model, means and covariances. Generated by vib_demo.ipynb. Used with kind permission of Alex Alemi.




Figure 5.14: Conditional entropy bottleneck (CEB) chooses a representation Z that maximizes I(Z; Y) and minimizes I(X; Z|Y). Used with kind permission of Katie Everett.


for λ > 0. However, we see from the information diagram in Figure 5.12b that I(Z; X) contains some information that is relevant to Y. A sensible alternative objective is to minimize the residual mutual information, I(X; Z|Y). This gives rise to the following objective:

$$ \min \; I(x; z|y) - \lambda' I(y; z) \tag{5.182} $$

for λ' > 0. This is known as the conditional entropy bottleneck or CEB [Fis20]. See Figure 5.14 for an illustration.

Since I(x; z|y) = I(x; z) − I(y; z), we see that the CEB is equivalent to standard IB with λ = λ' + 1. However, it is easier to upper bound I(x; z|y) than I(x; z), since we are conditioning on y, which provides information about z. In particular, we have


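$$ I(x; z|y) = I(x; z) - I(y; z) \tag{5.183} $$

$$ = H(z) - H(z|x) - \left[ H(z) - H(z|y) \right] \tag{5.184} $$

$$ = -H(z|x) + H(z|y) \tag{5.185} $$

$$ = \int dz\, dx\, p(x, z) \log p(z|x) - \int dz\, dy\, p(z, y) \log p(z|y) \tag{5.186} $$

$$ \leq \int dz\, dx\, p(x, z) \log e(z|x) - \int dz\, dy\, p(z, y) \log b(z|y) \tag{5.187} $$

$$ = \langle \log e(z|x) \rangle - \langle \log b(z|y) \rangle \tag{5.188} $$

Putting it all together, we get the final CEB objective:

$$ \min \; \beta \left( \langle \log e(z|x) \rangle - \langle \log b(z|y) \rangle \right) - \langle \log d(y|z) \rangle \tag{5.189} $$

where β = 1/λ' (obtained by dividing Equation (5.182) by λ').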


Note that it is generally easier to learn the conditional backwards encoder b(z|y) than the unconditional marginal m(z). Also, we know that the tightest upper bound occurs when I(x; z|y) = I(x; z) − I(y; z) = 0. The corresponding value of β corresponds to an optimal representation. By contrast, it is not clear how to measure distance from optimality when using IB.
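By analogy with the VIB case, a single-sample estimate of the CEB objective (5.189) might look as follows. We assume a diagonal Gaussian encoder, a class-conditional backwards encoder b(z|y) = N(z | μ_y, I) with hypothetical class means μ_y, and a softmax classifier; the text does not prescribe any of these choices. Note that the marginal m(z) no longer appears, and that a poorly fitted b(z|y) inflates the residual term.

```python
import numpy as np

def diag_gauss_logpdf(z, mean, log_var):
    """log N(z | mean, diag(exp(log_var))), summed over the last axis."""
    return -0.5 * np.sum(log_var + np.log(2 * np.pi)
                         + (z - mean) ** 2 / np.exp(log_var), axis=-1)

def ceb_loss(mu, log_var, labels, class_means, W, bias, beta, rng):
    """One-sample estimate of beta*(<log e(z|x)> - <log b(z|y)>) - <log d(y|z)>.

    e(z|x) = N(mu, diag(exp(log_var))), b(z|y) = N(class_means[y], I) and
    d(y|z) = softmax(z @ W + bias); all three forms are assumptions.
    """
    z = mu + np.exp(0.5 * log_var) * rng.standard_normal(mu.shape)
    log_e = diag_gauss_logpdf(z, mu, log_var)
    log_b = diag_gauss_logpdf(z, class_means[labels], np.zeros_like(z))
    logits = z @ W + bias
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    log_d = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    nll = -log_d[np.arange(len(labels)), labels]
    residual = log_e - log_b  # per-example estimate of the I(x; z | y) bound
    return np.mean(beta * residual + nll), np.mean(residual), np.mean(nll)

mu = np.array([[2.0, 0.0], [-2.0, 0.0], [0.0, 2.0]])  # hypothetical encoder means
log_var = np.full_like(mu, -1.0)
labels = np.array([0, 1, 2])
W = np.array([[2.0, -2.0, 0.0], [0.0, 0.0, 2.0]])     # hypothetical classifier
bias = np.zeros(3)
near = mu.copy()   # backwards encoder centred on each class's codes
far = mu + 3.0     # a poorly fitted backwards encoder
loss_near, res_near, nll_near = ceb_loss(mu, log_var, labels, near, W, bias, 0.5,
                                         np.random.default_rng(0))
loss_far, res_far, nll_far = ceb_loss(mu, log_var, labels, far, W, bias, 0.5,
                                      np.random.default_rng(0))
print(f"well-fitted b(z|y): residual={res_near:.3f}, loss={loss_near:.3f}")
print(f"poorly fitted b(z|y): residual={res_far:.3f}, loss={loss_far:.3f}")
```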
6 Optimization


6.1 Introduction 


In this chapter, we consider solving optimization problems of various forms. Abstractly these can 
all be written as 


$$ \theta^* \in \arg\min_{\theta \in \Theta} \mathcal{L}(\theta) \tag{6.1} $$

where L : Θ → R is the objective or loss function, and Θ is the parameter space we are optimizing
over. However, this abstraction hides many details, such as whether the problem is constrained 
or unconstrained, discrete or continuous, convex or non-convex, etc. In the prequel to this book, 
[Mur22], we discussed some simple optimization algorithms for some common problems that arise in 
machine learning. In this chapter, we discuss some more advanced methods. For more details on 
optimization, please consult some of the many excellent textbooks, such as [KW19b; BV04; NW06; 
Ber15; Ber16] as well as various review articles, such as [BCN18; Sun+19b; PPS18; Pey20]. 


6.2 Automatic differentiation 


This section was written by Roy Frostig. 


This section is concerned with computing (partial) derivatives of complicated functions in an 
automatic manner. By “complicated” we mean those expressed as a composition of an arbitrary 
number of more basic operations, such as in deep neural networks. This task is known as automatic 
differentiation (AD), or autodiff. AD is an essential component in optimization and deep learning, 
and is also used in several other fields across science and engineering. See e.g. Baydin et al. [Bay+15] 
for a review focused on machine learning and Griewank and Walther [GW08] for a classical textbook. 


6.2.1 Differentiation in functional form 


Before covering automatic differentiation, it is useful to review the mathematics of differentiation. 
We will use a particular functional notation for partial derivatives, rather than the typical one used 
throughout much of this book. We will refer to the latter as the named variable notation for the 
moment. Named variable notation relies on associating function arguments with names. For instance, 
given a function f : R² → R, the partial derivative of f with respect to its first scalar argument, at a point a = (a1, a2), might be written:

$$ \left. \frac{\partial f}{\partial x_1} \right|_{x = a} \tag{6.2} $$

This notation is not entirely self-contained. It refers to a name x = (x1, x2), implicit or inferred from context, suggesting the argument of f. An alternative expression is:

$$ \frac{\partial}{\partial a_1} f(a_1, a_2) \tag{6.3} $$

where now a1 serves both as an argument name (or a symbol in an expression) and as a particular evaluation point. Tracking names can become an increasingly complicated endeavor as we compose many functions together, each possibly taking several arguments.

A functional notation instead defines derivatives as operators on functions. If a function has 
multiple arguments, they are identified by position rather than by name, alleviating the need for 
auxiliary variable definitions. Some of the following definitions draw on those in Spivak’s Calculus 
on Manifolds [Spi71], in Sussman and Wisdom’s Functional Differential Geometry [SW13], and 
generally appear more regularly in accounts of differential calculus and geometry. These texts are 
recommended for a more formal treatment, and a more mathematically general view, of the material 
briefly covered in this section. 

Besides notation, we will rely on some basic multivariable calculus concepts. This includes the notion of (partial) derivatives, the differential or Jacobian of a function at a point, its role as a linear approximation local to the point, and various properties of linear maps, matrices, and transposition. We will focus on a finite-dimensional setting and write {e1, ..., en} for the standard basis in R^n.


Linear and multilinear functions. We use F : R^n ⊸ R^m to denote a function F : R^n → R^m that is linear, and by F[x] its application to x ∈ R^n. Recall that such a linear map corresponds to a matrix in R^{m×n} whose columns are F[e1], ..., F[en]; both interpretations will prove useful. Conveniently, function composition and matrix multiplication expressions look similar: to compose two linear maps F and G we can write F ∘ G or, barely abusing notation, consider the matrix FG. Every linear map F : R^n ⊸ R^m has a transpose F^T : R^m ⊸ R^n, which is another linear map identified with transposing the corresponding matrix.

Repeatedly using the linear arrow symbol, we can denote by:

$$ T : \underbrace{\mathbb{R}^n \multimap \cdots \multimap \mathbb{R}^n}_{k \text{ times}} \multimap \mathbb{R}^m \tag{6.4} $$

a multilinear, or more specifically k-linear, map:

$$ T : \underbrace{\mathbb{R}^n \times \cdots \times \mathbb{R}^n}_{k \text{ times}} \to \mathbb{R}^m \tag{6.5} $$

which corresponds to an array (or tensor) in R^{m×n×···×n}. We denote by T[x1, ..., xk] ∈ R^m the application of such a k-linear map to vectors x1, ..., xk ∈ R^n.
6.2. AUTOMATIC DIFFERENTIATION

The derivative operator. For an open set U ⊆ R^n and a differentiable function f : U → R^m, denote its derivative function:

$$ \partial f : U \to (\mathbb{R}^n \multimap \mathbb{R}^m) \tag{6.6} $$

or equivalently ∂f : U → R^{m×n}. This function maps a point x ∈ U to the Jacobian of all partial derivatives evaluated at x. The symbol ∂ itself denotes the derivative operator, a function mapping functions to their derivative functions. When m = 1, the map ∂f(x) recovers the standard gradient ∇f(x) at any x ∈ U, by considering the matrix view of the former. Indeed, the nabla symbol ∇ is sometimes described as an operator as well, such that ∇f is a function. When n = m = 1, the Jacobian is scalar-valued, and ∂f is the familiar derivative f′.

In the expression ∂f(x)[v], we will sometimes refer to the argument x as the linearization point for the Jacobian, and to v as the perturbation. We call the map:

$$ (x, v) \mapsto \partial f(x)[v] \tag{6.7} $$

over linearization points x ∈ U and input perturbations v ∈ R^n the Jacobian-vector product (JVP). We similarly call its transpose:

$$ (x, u) \mapsto \partial f(x)^\top[u] \tag{6.8} $$

over linearization points x ∈ U and output perturbations u ∈ R^m the vector-Jacobian product (VJP).
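As a concrete illustration, take the hypothetical function f(x) = (x0 x1, sin x0) from R² to R². The sketch below is hand-written and is not how an AD system is implemented. It codes the JVP and VJP directly as maps, without ever forming the Jacobian, and checks them against the explicit matrix. The last line verifies the defining property of the transpose, ⟨u, ∂f(x)[v]⟩ = ⟨∂f(x)^T[u], v⟩.

```python
import numpy as np

def f(x):
    """Example f: R^2 -> R^2, f(x) = (x0 * x1, sin(x0))."""
    return np.array([x[0] * x[1], np.sin(x[0])])

def jacobian(x):
    """Hand-derived Jacobian of f at x (the matrix view of the derivative)."""
    return np.array([[x[1], x[0]],
                     [np.cos(x[0]), 0.0]])

def jvp(x, v):
    """(x, v) -> J(x) v, computed without materializing the Jacobian."""
    return np.array([v[0] * x[1] + x[0] * v[1], np.cos(x[0]) * v[0]])

def vjp(x, u):
    """(x, u) -> J(x)^T u, the transposed linear map."""
    return np.array([u[0] * x[1] + u[1] * np.cos(x[0]), u[0] * x[0]])

x = np.array([0.7, -1.3])
v = np.array([0.2, 0.5])
u = np.array([1.0, -2.0])
print(np.allclose(jvp(x, v), jacobian(x) @ v))    # True: JVP = J v
print(np.allclose(vjp(x, u), jacobian(x).T @ u))  # True: VJP = J^T u
print(np.isclose(u @ jvp(x, v), vjp(x, u) @ v))   # True: <u, J v> = <J^T u, v>
```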

Thinking about maps instead of matrices can help us define higher-order derivatives recursively, as we proceed to do below. It separately suggests how the action of a Jacobian is commonly written in code. When we consider writing ∂f(x) in a program for a fixed x, we often implement it as a function that carries out multiplication by the Jacobian matrix, i.e. v ↦ ∂f(x)[v], instead of explicitly representing it as a matrix of numbers in memory. Going a step further, for that matter, we often implement ∂f as an entire JVP at once, i.e. over any linearization point x and perturbation v. As a toy example with scalars, consider the cosine:

$$ (x, v) \mapsto \partial \cos(x)\, v = -v \sin(x) \tag{6.9} $$

If we express this at once in code, we can, say, avoid computing sin(x) whenever v = 0.¹
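The cosine JVP of Equation (6.9), written as a single function of (x, v), might look like the sketch below. The short-circuit on v = 0 is the saving mentioned above. The finite-difference check and the function name are our own additions for illustration.

```python
import math

def cos_jvp(x, v):
    """JVP of cos at linearization point x with perturbation v: -v * sin(x).

    Implementing the whole map (x, v) -> d cos(x)[v] at once lets us skip
    evaluating sin(x) when the perturbation is zero.
    """
    if v == 0.0:
        return 0.0
    return -v * math.sin(x)

# Compare with a central finite difference of cos at x in direction v.
x, v, h = 0.3, 2.0, 1e-6
fd = (math.cos(x + h * v) - math.cos(x - h * v)) / (2 * h)
print(cos_jvp(x, v), fd)
```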


Higher-order derivatives. Suppose the function f above remains arbitrarily differentiable over its domain U ⊆ R^n. To take another derivative, we write:

$$ \partial^2 f : U \to (\mathbb{R}^n \multimap \mathbb{R}^n \multimap \mathbb{R}^m) \tag{6.10} $$

where ∂²f(x) is a bilinear map representing all second-order partial derivatives. In named variable notation, one might write ∂²f(x)/∂x_i∂x_j to refer to ∂²f(x)[e_i, e_j], for example.


1. This example ignores that such an optimization might be done (best) by a compiler. Then again, for more complex examples, implementing (x, v) ↦ ∂f(x)[v] as a single subroutine can help guide compiler optimizations all the same.
The second derivative function ∂²f can be treated coherently as the outcome of applying the derivative operator twice. That is, it makes sense to say that ∂² = ∂ ∘ ∂. This observation extends recursively to cover arbitrary higher-order derivatives. For k > 1:

$$ \partial^k f : U \to (\underbrace{\mathbb{R}^n \multimap \cdots \multimap \mathbb{R}^n}_{k \text{ times}} \multimap \mathbb{R}^m) \tag{6.11} $$

is such that ∂^k f(x) is a k-linear map.

With m = 1, the map ∂²f(x) corresponds to the Hessian matrix at any x ∈ U. Although Jacobians and Hessians suffice to make sense of many machine learning techniques, arbitrary higher-order derivatives are not hard to come by either (e.g. [Kel+20]). As an example, they appear when writing down something as basic as a function's Taylor series approximation, which we can express with our derivative operator as:

$$ f(x + v) = f(x) + \partial f(x)[v] + \frac{1}{2!} \partial^2 f(x)[v, v] + \cdots + \frac{1}{k!} \partial^k f(x)[v, \ldots, v] + \cdots \tag{6.12} $$
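To see the expansion (6.12) at work, the sketch below truncates it after the second-order term for the hypothetical function f(x) = exp(x0) sin(x1). The gradient and Hessian (the m = 1 matrix views of ∂f(x) and ∂²f(x)) are derived by hand. Halving the step size should shrink the truncation error by roughly a factor of 2³ = 8, since the first omitted term is cubic.

```python
import numpy as np

def f(x):
    return np.exp(x[0]) * np.sin(x[1])

def grad(x):
    """Gradient of f, i.e. the matrix view of the first derivative when m = 1."""
    return np.array([np.exp(x[0]) * np.sin(x[1]), np.exp(x[0]) * np.cos(x[1])])

def hessian(x):
    """Hessian of f; the bilinear form d2f(x)[v, v] equals v^T H v."""
    e, s, c = np.exp(x[0]), np.sin(x[1]), np.cos(x[1])
    return np.array([[e * s, e * c], [e * c, -e * s]])

x = np.array([0.2, 0.9])
v = np.array([1.0, -0.5])
errors = []
for t in [0.1, 0.05, 0.025]:
    dv = t * v
    approx = f(x) + grad(x) @ dv + 0.5 * dv @ hessian(x) @ dv  # Eq. (6.12) truncated at k = 2
    errors.append(abs(f(x + dv) - approx))
ratios = [errors[i] / errors[i + 1] for i in range(2)]
print(ratios)  # each roughly 8, since the error shrinks like t^3
```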


Multiple inputs. Now consider a function of two arguments:

$$ g : U \times V \to \mathbb{R}^m \tag{6.13} $$

where U ⊆ R^{n1} and V ⊆ R^{n2}. For our purposes, a product domain like U × V mainly serves to suggest a convenient partitioning of a function's input components. It is isomorphic to a subset of R^{n1+n2}, corresponding to a single-input function. The latter tells us how the derivative functions of g ought to look, based on previous definitions, and we will swap between the two views with little warning. Multiple inputs tend to arise in the context of computational circuits and programs: many functions in code are written to accept multiple arguments, and many basic operations (such as +) do the same.

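With multiple inputs, we can denote by ∂_i g the derivative function with respect to the ith argument:

$$ \partial_1 g : \mathbb{R}^{n_1} \times \mathbb{R}^{n_2} \to (\mathbb{R}^{n_1} \multimap \mathbb{R}^m), \text{ and} \tag{6.14} $$

$$ \partial_2 g : \mathbb{R}^{n_1} \times \mathbb{R}^{n_2} \to (\mathbb{R}^{n_2} \multimap \mathbb{R}^m). \tag{6.15} $$

Under the matrix view, the function ∂1g maps a pair of points x ∈ R^{n1} and y ∈ R^{n2} to the matrix of all partial derivatives of g with respect to its first argument, evaluated at (x, y). We take ∂g with no subscript to simply mean the concatenation of ∂1g and ∂2g:

$$ \partial g : \mathbb{R}^{n_1} \times \mathbb{R}^{n_2} \to (\mathbb{R}^{n_1} \times \mathbb{R}^{n_2} \multimap \mathbb{R}^m) \tag{6.16} $$

where, for every linearization point (x, y) ∈ U × V and perturbations ẋ ∈ R^{n1}, ẏ ∈ R^{n2}:

$$ \partial g(x, y)[\dot{x}, \dot{y}] = \partial_1 g(x, y)[\dot{x}] + \partial_2 g(x, y)[\dot{y}] \tag{6.17} $$

Alternatively, taking the matrix view:

$$ \partial g(x, y) = \begin{pmatrix} \partial_1 g(x, y) & \partial_2 g(x, y) \end{pmatrix} \tag{6.18} $$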


This convention will simplify our chain rule statement below. When n1 = n2 = m = 1, both sub-matrices are scalar, and ∂1g(x, y) recovers the partial derivative that might otherwise be written in named variable notation as:

$$ \frac{\partial}{\partial x} g(x, y) \tag{6.19} $$

However, the expression ∂1g bears a meaning on its own (as a function) whereas the expression ∂g/∂x may be ambiguous without further context. Again composing operators lets us write higher-order derivatives. For instance, ∂2∂1g(x, y) ∈ R^{m×n1×n2}, and if m = 1, the Hessian of g at (x, y) is:

$$ \begin{pmatrix} \partial_1 \partial_1 g(x, y) & \partial_1 \partial_2 g(x, y) \\ \partial_2 \partial_1 g(x, y) & \partial_2 \partial_2 g(x, y) \end{pmatrix} \tag{6.20} $$


Composition and fan-out. If $f = g \circ h$ for some $h : \mathbb{R}^n \to \mathbb{R}^p$ and $g : \mathbb{R}^p \to \mathbb{R}^m$, then the chain rule of calculus observes that:

$$\partial f(x) = \partial g(h(x)) \circ \partial h(x) \quad \text{for all } x \in \mathbb{R}^n. \tag{6.21}$$

How does this interact with our notation for multi-argument functions? For one, it can lead us to consider expressions with fan-out, where several sub-expressions are functions of the same input. For instance, assume two functions $a : \mathbb{R}^n \to \mathbb{R}^{m_1}$ and $b : \mathbb{R}^n \to \mathbb{R}^{m_2}$, and that:

$$f(x) = g(a(x), b(x)) \tag{6.22}$$

for some function $g$. Abbreviating $h(x) = (a(x), b(x))$ so that $f(x) = g(h(x))$, Equations (6.16) and (6.21) tell us that:

\begin{align}
\partial f(x) &= \partial g(h(x)) \circ \partial h(x) \tag{6.23} \\
&= \partial_1 g(a(x), b(x)) \circ \partial a(x) + \partial_2 g(a(x), b(x)) \circ \partial b(x) \tag{6.24}
\end{align}


Note that $+$ is meant pointwise here. It also follows from the above that if instead:

$$f(x, y) = g(a(x), b(y)), \tag{6.25}$$

in other words, if we write multiple arguments but exhibit no fan-out, then:

\begin{align}
\partial_1 f(x, y) &= \partial_1 g(a(x), b(y)) \circ \partial a(x), \text{ and} \tag{6.26} \\
\partial_2 f(x, y) &= \partial_2 g(a(x), b(y)) \circ \partial b(y). \tag{6.27}
\end{align}

Composition and fan-out rules for derivatives are what let us break down a complex derivative calculation into simpler ones. This is what automatic differentiation techniques rely on when processing the sort of elaborate numerical computations that turn up in modern machine learning and numerical programming.


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 




6.2.2 Differentiating chains, circuits, and programs 


The purpose of automatic differentiation is to compute derivatives of arbitrary functions provided as input. Given a function $f : U \subseteq \mathbb{R}^n \to \mathbb{R}^m$ and a linearization point $x \in U$, AD computes either:

• the JVP $\partial f(x)[v]$ for an input perturbation $v \in \mathbb{R}^n$, or
• the VJP $\partial f(x)^\top[u]$ for an output perturbation $u \in \mathbb{R}^m$.

In other words, JVPs and VJPs capture the two essential tasks of AD.²

Deciding what functions f to handle as input, and how to represent them, is perhaps the most 
load-bearing aspect of this setup. Over what language of functions should we operate? By a language, 
we mean some formal way of describing functions by composing a set of basic primitive operations. For 
primitives, we can think of various differentiable array operations (elementwise arithmetic, reductions, 
contractions, indexing and slicing, concatenation, etc.), but we will largely consider primitives and 
their derivatives as a given, and focus on how elaborately we can compose them. AD becomes 
increasingly challenging with increasingly expressive languages. Considering this, we introduce it in 
stages. 


6.2.2.1 Chain compositions and the chain rule 


To start, take only functions that are chain compositions of basic operations. Chains are a 
convenient class of function representations because derivatives decompose along the same structure 
according to the aptly-named chain rule. 

As a toy example, consider $f : \mathbb{R}^n \to \mathbb{R}^m$ composed of three operations in sequence:

$$f = c \circ b \circ a. \tag{6.28}$$

By the chain rule, its derivatives are given by

$$\partial f(x) = \partial c(b(a(x))) \circ \partial b(a(x)) \circ \partial a(x). \tag{6.29}$$

Now consider the JVP against an input perturbation $v \in \mathbb{R}^n$:

$$\partial f(x)[v] = \partial c(b(a(x)))\big[\partial b(a(x))\big[\partial a(x)[v]\big]\big]. \tag{6.30}$$

This expression's bracketing highlights a right-to-left evaluation order that corresponds to forward-mode automatic differentiation. Namely, to carry out this JVP, it makes sense to compute prefixes of the original chain:

$$x, \; a(x), \; b(a(x)) \tag{6.31}$$

alongside the partial JVPs, because each is then immediately used as a subsequent linearization point, respectively:

$$\partial a(x), \; \partial b(a(x)), \; \partial c(b(a(x))). \tag{6.32}$$

Extending this idea to arbitrary chain compositions gives Algorithm 1.


2. Materializing the Jacobian as a numerical array, as is commonly required in an optimization context, is a special case of computing a JVP or VJP against the standard basis vectors in $\mathbb{R}^n$ or $\mathbb{R}^m$ respectively.


Algorithm 1: Forward-mode automatic differentiation (JVP) on chains
  input: $f : \mathbb{R}^n \to \mathbb{R}^m$ as a chain composition $f = f_T \circ \cdots \circ f_1$
  input: linearization point $x \in \mathbb{R}^n$ and input perturbation $v \in \mathbb{R}^n$
  $x_0, v_0 := x, v$
  for $t := 1, \ldots, T$ do
      $x_t := f_t(x_{t-1})$
      $v_t := \partial f_t(x_{t-1})[v_{t-1}]$
  output: $x_T$, equal to $f(x)$
  output: $v_T$, equal to $\partial f(x)[v]$
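Algorithm 1 translates almost line for line into Python. The sketch below is our own illustration rather than code from the book's notebooks: it assumes each primitive is supplied as a pair of plain functions (its value and its JVP), uses NumPy rather than JAX, and the three-step chain (elementwise sine, a fixed linear map `W`, squared norm) is hypothetical.

```python
import numpy as np

# Hypothetical chain f = f3 ∘ f2 ∘ f1; each primitive is (value_fn, jvp_fn),
# where jvp_fn(x, v) returns ∂f_t(x)[v].
W = np.array([[1.0, 2.0, 0.0],
              [0.0, -1.0, 3.0]])

chain = [
    (np.sin, lambda x, v: np.cos(x) * v),                    # f1: elementwise sine
    (lambda x: W @ x, lambda x, v: W @ v),                   # f2: linear map
    (lambda x: np.sum(x ** 2), lambda x, v: 2.0 * (x @ v)),  # f3: squared norm
]


def forward_mode_jvp(chain, x, v):
    """Algorithm 1: push the pair (x_t, v_t) through the chain in a single pass."""
    for f, jvp in chain:
        # The right-hand side is evaluated first, so jvp linearizes at x_{t-1}.
        x, v = f(x), jvp(x, v)
    return x, v  # f(x) and ∂f(x)[v]


def f(x):
    for g, _ in chain:
        x = g(x)
    return x


x0 = np.array([0.3, -1.2, 0.7])
v0 = np.array([1.0, 0.5, -2.0])
fx, jvp_val = forward_mode_jvp(chain, x0, v0)

# Compare with a central finite difference along v0.
eps = 1e-6
fd = (f(x0 + eps * v0) - f(x0 - eps * v0)) / (2 * eps)
```

Note that no prefix is stored beyond the current one: each $x_t$ is used as a linearization point immediately and then overwritten.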


By contrast, we can transpose Equation (6.29) to consider a VJP against an output perturbation $u \in \mathbb{R}^m$:

$$\partial f(x)^\top[u] = \partial a(x)^\top\big[\partial b(a(x))^\top\big[\partial c(b(a(x)))^\top[u]\big]\big]. \tag{6.33}$$

Transposition reverses the Jacobian maps relative to their order in Equation (6.29), and now the bracketed evaluation corresponds to reverse-mode automatic differentiation. To carry out this VJP, we can compute the original chain prefixes $x$, $a(x)$, and $b(a(x))$ first, and then read them in reverse as successive linearization points:

$$\partial c(b(a(x)))^\top, \; \partial b(a(x))^\top, \; \partial a(x)^\top. \tag{6.34}$$

Extending this idea to arbitrary chain compositions gives Algorithm 2.


Algorithm 2: Reverse-mode automatic differentiation (VJP) on chains
  input: $f : \mathbb{R}^n \to \mathbb{R}^m$ as a chain composition $f = f_T \circ \cdots \circ f_1$
  input: linearization point $x \in \mathbb{R}^n$ and output perturbation $u \in \mathbb{R}^m$
  $x_0 := x$
  for $t := 1, \ldots, T$ do
      $x_t := f_t(x_{t-1})$
  $u_T := u$
  for $t := T, \ldots, 1$ do
      $u_{t-1} := \partial f_t(x_{t-1})^\top[u_t]$
  output: $x_T$, equal to $f(x)$
  output: $u_0$, equal to $\partial f(x)^\top[u]$
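A corresponding sketch of Algorithm 2, again our own illustration in NumPy with hypothetical primitives supplied as (value, VJP) pairs. It stores every prefix on the way forward and reads them back in reverse. Because the last primitive is scalar valued, a single VJP with $u = 1$ returns the whole gradient, which we compare with a hand-derived formula.

```python
import numpy as np

# Hypothetical chain f = f3 ∘ f2 ∘ f1; each primitive is (value_fn, vjp_fn),
# where vjp_fn(x, u) returns ∂f_t(x)^T[u].
W = np.array([[1.0, 2.0, 0.0],
              [0.0, -1.0, 3.0]])

chain = [
    (np.sin, lambda x, u: np.cos(x) * u),                  # diagonal Jacobian: its own transpose
    (lambda x: W @ x, lambda x, u: W.T @ u),               # transpose of a linear map
    (lambda x: np.sum(x ** 2), lambda x, u: 2.0 * x * u),  # scalar output, so u is a scalar
]


def reverse_mode_vjp(chain, x, u):
    """Algorithm 2: store every prefix, then sweep backwards."""
    xs = [x]                       # xs[t] holds x_t
    for f, _ in chain:
        xs.append(f(xs[-1]))
    for t in reversed(range(len(chain))):
        _, vjp = chain[t]          # chain[t] is f_{t+1} in the 1-based notation
        u = vjp(xs[t], u)          # linearize at its input x_t
    return xs[-1], u               # f(x) and ∂f(x)^T[u]


x0 = np.array([0.3, -1.2, 0.7])
fx, grad = reverse_mode_vjp(chain, x0, 1.0)

# Hand-derived gradient of ||W sin(x)||^2 for comparison.
expected = np.cos(x0) * (W.T @ (2.0 * (W @ np.sin(x0))))
```

The list `xs` is exactly the memory cost discussed below: all prefixes must be alive before the backward loop starts.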


Although chain compositions impose a very specific structure, they already capture some deep 
neural network models, such as multi-layer perceptrons (provided matrix multiplication is a primitive 
operation), as covered in this book’s prequel [Mur22, Ch.13]. 

Figure 6.1: A circuit for a function $f$ over three primitives, and its decomposition into two circuits without fan-out. Input nodes are drawn in green.

Reverse-mode AD is faster than forward-mode when the output is scalar valued (as often arises in deep learning, where the output is a loss function): for $f : \mathbb{R}^n \to \mathbb{R}$, a single VJP with $u = 1$ yields the full gradient, whereas forward mode needs $n$ JVPs, one per input direction. However, reverse-mode AD stores all chain prefixes before its backward traversal, so it consumes more memory than forward-mode. There are ways to combat this memory requirement in special-case scenarios, such as when the chained operations are each reversible [MDA15; Gom+17; KKL20]. One can also trade off memory for computation by discarding some prefixes and re-computing them as needed.
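To make the cost comparison concrete, here is a small sketch that materializes a Jacobian both ways. The function $f(x) = A\tanh(x)$ and its hand-written JVP and VJP are hypothetical. $n$ JVPs against the standard basis of $\mathbb{R}^n$ produce the columns, while $m$ VJPs against the standard basis of $\mathbb{R}^m$ produce the rows.

```python
import numpy as np

# Hypothetical f(x) = A @ tanh(x), with A of shape (m, n); JVP and VJP are hand-coded.
rng = np.random.default_rng(0)
m, n = 2, 5
A = rng.normal(size=(m, n))


def jvp(x, v):   # ∂f(x)[v]
    return A @ ((1.0 - np.tanh(x) ** 2) * v)


def vjp(x, u):   # ∂f(x)^T[u]
    return (1.0 - np.tanh(x) ** 2) * (A.T @ u)


x = rng.normal(size=n)
# Forward mode: n JVPs against the standard basis of R^n give the columns.
J_fwd = np.stack([jvp(x, e) for e in np.eye(n)], axis=1)
# Reverse mode: m VJPs against the standard basis of R^m give the rows.
J_rev = np.stack([vjp(x, e) for e in np.eye(m)], axis=0)
```

With $m = 1$ (a scalar loss) the reverse route makes a single call regardless of $n$, which is the sense in which reverse mode is cheaper for gradients.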


6.2.2.2 From chains to circuits 


When primitives can accept multiple inputs, we can naturally extend chains to circuits: directed acyclic graphs over primitive operations, sometimes also called computation graphs. To set up for this section, we will distinguish between (i) input nodes of a circuit, which symbolize a function's arguments, and (ii) primitive nodes, each of which is labeled by a primitive operation. We assume that input nodes have no incoming edges and (without loss of generality) exactly one outgoing edge each, and that the graph has exactly one sink node. The overall function of the circuit is the composition of operations from the input nodes to the sink, where the output of each operation is input to others according to its outgoing edges.


What made AD work in Section 6.2.2.1 is the fact that derivatives decompose along chains thanks to the aptly-named chain rule. When moving from chains to directed acyclic graphs, do we need some sort of "graph rule" in order to decompose our calculation along the circuit's structure? Circuits introduce two new features: fan-in and fan-out. In graphical terms, fan-in simply refers to multiple edges incoming to a node, and fan-out refers to multiple edges outgoing.

What do these mean in functional terms? Fan-in happens when a primitive operation accepts multiple arguments. We observed in Section 6.2.1 that multiple arguments can be treated as one, and how the chain rule then applies. Fan-out requires slightly more care, specifically for reverse-mode differentiation.

The gist of an answer can be illustrated with a small example. Consider the circuit in Figure 6.1a. The operation $a$ precedes $b$ and $c$ topologically, with an outgoing edge to each of them. We can cut $a$ away from $\{b, c\}$ to produce two new circuits, shown in Figure 6.1b. The first corresponds to $a$ and the second corresponds to the remaining computation, given by:

$$f_{\{b,c\}}(x_1, x_2) = c(x_1, b(x_2)). \tag{6.35}$$


We can recover the complete function $f$ from $a$ and $f_{\{b,c\}}$ with the help of a function dup given by:

$$\mathrm{dup}(x) = (x, x) = \begin{pmatrix} I \\ I \end{pmatrix} x, \tag{6.36}$$

so that $f$ can be written as a chain composition:

$$f = f_{\{b,c\}} \circ \mathrm{dup} \circ a. \tag{6.37}$$

The circuit for $f_{\{b,c\}}$ contains no fan-out, and composition rules such as Equation (6.25) tell us its derivatives in terms of $b$, $c$, and their derivatives, all via the chain rule. Meanwhile, the chain rule applied to Equation (6.37) says that:

\begin{align}
\partial f(x) &= \partial f_{\{b,c\}}(\mathrm{dup}(a(x))) \circ \partial\,\mathrm{dup}(a(x)) \circ \partial a(x) \tag{6.38} \\
&= \partial f_{\{b,c\}}(a(x), a(x)) \circ \begin{pmatrix} I \\ I \end{pmatrix} \circ \partial a(x). \tag{6.39}
\end{align}

The above expression suggests calculating a JVP of $f$ by right-to-left evaluation. It is similar to the JVP calculation suggested by Equation (6.30), but with a duplication operation $\begin{pmatrix} I \\ I \end{pmatrix}$ in the middle that arises from the Jacobian of dup.

Transposing the derivative of $f$ at $x$:

$$\partial f(x)^\top = \partial a(x)^\top \circ \begin{pmatrix} I & I \end{pmatrix} \circ \partial f_{\{b,c\}}(a(x), a(x))^\top. \tag{6.40}$$

Considering right-to-left evaluation, this too is similar to the VJP calculation suggested by Equation (6.33), but with a summation operation $\begin{pmatrix} I & I \end{pmatrix}$ in the middle that arises from the transposed Jacobian of dup. The lesson of using dup in this small example is that, more generally, in order to handle fan-out in reverse-mode AD, we can process operations in topological order (first forward and then in reverse) and then sum partial VJPs along multiple outgoing edges.
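The sketch below hand-codes reverse mode for the circuit of Figure 6.1a, using scalar primitives chosen purely for illustration ($a(x) = x^2$, $b(z) = \sin z$, $c(p, q) = pq$). The line that adds two cotangents for `ax` is the transposed Jacobian of dup at work.

```python
import math


# Hypothetical scalar primitives for the circuit of Figure 6.1a.
def a(x):
    return x * x


def b(z):
    return math.sin(z)


def c(p, q):
    return p * q


def value_and_grad(x):
    # Forward pass in topological order, keeping the intermediates.
    ax = a(x)
    bx = b(ax)
    out = c(ax, bx)
    # Backward pass, starting from the output cotangent u = 1.
    u = 1.0
    u_c1 = u * bx                 # cotangent for c's first argument (edge a -> c): q * u
    u_c2 = u * ax                 # cotangent for c's second argument (b's output): p * u
    u_b_in = u_c2 * math.cos(ax)  # cotangent for b's input (edge a -> b)
    # ax fans out to c and to b: the transposed Jacobian of dup adds the two cotangents.
    u_ax = u_c1 + u_b_in
    u_x = u_ax * 2.0 * x          # ∂a(x)^T [u_ax]
    return out, u_x


val, grad = value_and_grad(0.8)
```

Here $f(x) = x^2 \sin(x^2)$, so the result can be checked against $f'(x) = 2x\sin(x^2) + 2x^3\cos(x^2)$.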


Algorithm 3: Forward-mode circuit differentiation (JVP)
  input: $f : \mathbb{R}^n \to \mathbb{R}^m$ composing $f_1, \ldots, f_T$ in topological order, where $f_1$ is identity
  input: linearization point $x \in \mathbb{R}^n$ and perturbation $v \in \mathbb{R}^n$
  $x_1, v_1 := x, v$
  for $t := 2, \ldots, T$ do
      let $[q_1, \ldots, q_r] = \mathrm{Pa}(t)$
      $x_t := f_t(x_{q_1}, \ldots, x_{q_r})$
      $v_t := \sum_{i=1}^{r} \partial_i f_t(x_{q_1}, \ldots, x_{q_r})[v_{q_i}]$
  output: $x_T$, equal to $f(x)$
  output: $v_T$, equal to $\partial f(x)[v]$
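One common way to realize Algorithm 3 in a host language is operator overloading with dual numbers, where Python's own evaluation order plays the role of the topological order. This is a minimal scalar-only sketch; the `Dual` class is a hypothetical helper, not a library API. Each operation returns its value together with the sum of its partial JVPs, mirroring the update for $v_t$.

```python
import math
from dataclasses import dataclass


@dataclass
class Dual:
    """A node value x_t together with its tangent v_t (hypothetical helper)."""
    val: float
    tan: float

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other, 0.0)
        # Sum of partial JVPs of (p, q) -> p + q: v_p + v_q.
        return Dual(self.val + other.val, self.tan + other.tan)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other, 0.0)
        # Sum of partial JVPs of (p, q) -> p * q: q * v_p + p * v_q.
        return Dual(self.val * other.val, other.val * self.tan + self.val * other.tan)

    __rmul__ = __mul__


def sin(d):
    return Dual(math.sin(d.val), math.cos(d.val) * d.tan)


def f(x):
    ax = x * x             # a(x); x feeds both arguments of the product
    return ax * sin(ax)    # c(ax, b(ax)): ax fans out to two consumers


out = f(Dual(0.8, 1.0))    # perturbation v = 1 gives df/dx
```

Fan-out needs no special handling in forward mode: a shared node simply passes the same tangent to every consumer.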


Algorithms 3 and 4 give a complete description of forward- and reverse-mode differentiation on circuits. For brevity they assume a single argument to the entire circuit function. Nodes are indexed $1, \ldots, T$. The first is the input node, and the remaining $T - 1$ are labeled by their operation $f_2, \ldots, f_T$. We take $f_1$ to be the identity. For each $t$, if $f_t$ takes $r$ arguments, let $\mathrm{Pa}(t)$ be the ordered list of $r$ indices of its parent nodes (possibly containing duplicates, due to fan-out), and let $\mathrm{Ch}(t)$ be the indices of its children (again possibly duplicate). Algorithm 4 takes a few more conventions: that $f_T$ is the identity, that node $T$ has $T - 1$ as its only parent, and that the child of node 1 is node 2.

Algorithm 4: Reverse-mode circuit differentiation (VJP)
  input: $f : \mathbb{R}^n \to \mathbb{R}^m$ composing $f_1, \ldots, f_T$ in topological order, where $f_1, f_T$ are identity
  input: linearization point $x \in \mathbb{R}^n$ and perturbation $u \in \mathbb{R}^m$
  $x_1 := x$
  for $t := 2, \ldots, T$ do
      let $[q_1, \ldots, q_r] = \mathrm{Pa}(t)$
      $x_t := f_t(x_{q_1}, \ldots, x_{q_r})$
  $u_{(T-1) \to T} := u$
  for $t := T - 1, \ldots, 2$ do
      let $[q_1, \ldots, q_r] = \mathrm{Pa}(t)$
      $u_t := \sum_{c \in \mathrm{Ch}(t)} u_{t \to c}$
      $u_{q_i \to t} := \partial_i f_t(x_{q_1}, \ldots, x_{q_r})^\top[u_t]$ for $i = 1, \ldots, r$
  output: $x_T$, equal to $f(x)$
  output: $u_{1 \to 2}$, equal to $\partial f(x)^\top[u]$
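The following is a small tape-based sketch in the spirit of Algorithm 4, for scalar-valued nodes only. The `Tape` class and its methods are hypothetical, written for illustration. It differs cosmetically from the pseudocode: instead of storing each edge cotangent $u_{q \to t}$ and summing over $\mathrm{Ch}(q)$ later, it adds each contribution into an accumulator for $q$ as soon as it is produced, which yields the same sum. Node indices are assigned in evaluation order, so sweeping indices downward is a reverse topological order. For convenience the example uses two input nodes.

```python
import math


class Tape:
    """Records a circuit of scalar primitives in topological (evaluation) order."""

    def __init__(self):
        self.values = []    # x_t for every node t
        self.parents = []   # Pa(t): parent indices, duplicates allowed
        self.partials = []  # ∂_i f_t evaluated at the parents, one per parent

    def _node(self, value, parents=(), partials=()):
        self.values.append(value)
        self.parents.append(list(parents))
        self.partials.append(list(partials))
        return len(self.values) - 1

    def variable(self, value):
        return self._node(value)

    def add(self, i, j):
        return self._node(self.values[i] + self.values[j], (i, j), (1.0, 1.0))

    def mul(self, i, j):
        xi, xj = self.values[i], self.values[j]
        return self._node(xi * xj, (i, j), (xj, xi))

    def sin(self, i):
        xi = self.values[i]
        return self._node(math.sin(xi), (i,), (math.cos(xi),))

    def vjp(self, out, u=1.0):
        """Reverse sweep from node `out`; returns one cotangent per node."""
        adj = [0.0] * len(self.values)
        adj[out] = u
        for t in range(out, -1, -1):
            for q, d in zip(self.parents[t], self.partials[t]):
                # Scalar nodes: the transposed partial is multiplication by d.
                # Fan-out shows up as several += into the same adj[q].
                adj[q] += d * adj[t]
        return adj


tape = Tape()
x = tape.variable(0.8)
y = tape.variable(-1.3)
out = tape.add(tape.mul(x, y), tape.sin(x))   # f(x, y) = x*y + sin(x); x fans out
adj = tape.vjp(out)
```

Storing local partials during the forward pass is one design choice; an alternative is to keep only the inputs of each node and recompute partials during the backward sweep.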
Fan-out is a feature of graphs, but arguably not an essential feature of functions. One can always remove all fan-out from a circuit representation by duplicating nodes. Our interest in fan-out is precisely to avoid this, allowing for an efficient representation and, in turn, efficient memory use in Algorithms 3 and 4.


Reverse-mode AD on circuits has appeared under various names and formulations over the years. The algorithm is precisely the backpropagation algorithm in neural networks, a term introduced in the 1980s [RHW86b; RHW86a], and has separately come up in the context of control theory and sensitivity, as summarized in historical notes by Goodfellow, Bengio, and Courville [GBC16, Section 6.6].


6.2.2.3 From circuits to programs


Graphs are useful for introducing AD algorithms, and they might align well enough with neural network applications. But computer scientists have spent decades formalizing and studying various "languages for expressing functions compositionally." Simply put, this is what programming languages are for! Can we automatically differentiate numerical functions expressed in, say, Python, Haskell, or some variant of the lambda calculus? These offer a far more widespread (and intuitively more expressive) way to describe an input function.³

In the previous sections, our approach to AD became more complex as we allowed for more complex graph structure. Something similar happens when we introduce grammatical constructs in a programming language. How do we adapt AD to handle a language with loops, conditionals, and recursive calls? What about parallel programming constructs? We have partial answers to questions like these today, although they invite a deeper dive into language details such as type systems and implementation concerns [Yu+18; Inn20; Pas+21b].

3. In Python, what the language calls a "function" does not always describe a pure function of the arguments listed in its syntactic definition; its behavior may rely on side effects or global state, as allowed by the language. Here, we specifically mean a Python function that is pure and functional. JAX's documentation details this restriction [Bra+18].

One example language construct that we already know how to handle, due to Section 6.2.2.2, 
is a standard let expression. In languages with a means of name or variable binding, multiple 
appearances of the same variable are analogous to fan-out in a circuit. Figure 6.1a corresponds to a
function f that we could write in a functional language as: 


f(x) = 
let ax = a(x) 
in c(ax, b(ax)) 


in which ax indeed appears twice after it is bound. 

Understanding the interaction between language capacity and automatic differentiability is an 
ongoing topic of computer science research [PS08a; AP19; Vyt+19; BMP19; MP21]. In the meantime, 
functional languages have proven quite effective in recent AD systems, both widely-used and experi- 
mental. Systems such as JAX, Dex, and others are designed around pure functional programming 
models, and internally rely on functional program representations for differentiation [Mac+15; BPS16; 
Sha+19; FJL18; Bra+18; Mac+19; Dex; Fro+21; Pas+21a]. 


6.3 Stochastic gradient descent 


In this section, we consider optimizers for unconstrained differentiable objectives. We consider gradient-based solvers which perform iterative updates of the following form:

$$\theta_{t+1} = \theta_t - \eta_t C_t g_t, \tag{6.41}$$

where $g_t = \nabla\mathcal{L}(\theta_t)$ is the gradient of the loss, $\eta_t$ is the step size (learning rate), and $C_t$ is an optional conditioning matrix.

If we set $C_t = I$, the method is known as steepest descent or gradient descent. If we set $C_t = H_t^{-1}$, where $H_t = \nabla^2\mathcal{L}(\theta_t)$ is the Hessian, we get Newton's method. There are many variants of Newton's method that are either more numerically stable, or more computationally efficient, or both.

In many problems, we cannot compute the exact gradient, either because the loss is stochastic (e.g., due to random factors in the environment), or because we approximate the loss by randomly subsampling the data. In such cases, we can modify the update in Equation (6.41) to use an unbiased approximation of the gradient. For example, suppose the loss is a finite-sum objective from a supervised learning problem:

$$\mathcal{L}(\theta) = \frac{1}{N}\sum_{n=1}^{N}\ell(y_n, f(x_n; \theta)) = \frac{1}{N}\sum_{n=1}^{N}\mathcal{L}_n(\theta). \tag{6.42}$$

We can approximate the gradient using a minibatch $\mathcal{B}_t$ of size $B = |\mathcal{B}_t|$ as follows:

$$g_t \approx \hat{g}_t = \frac{1}{B}\sum_{n \in \mathcal{B}_t}\nabla\mathcal{L}_n(\theta_t). \tag{6.43}$$

Since the minibatches are randomly sampled, this is a stochastic, but unbiased, estimate of the gradient. If we insert $\hat{g}_t$ into Equation (6.41), the method is called stochastic gradient descent or SGD.
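A minimal sketch of minibatch SGD on a hypothetical one-dimensional least-squares problem (the data, model, step size and batch size are all our own choices). It uses $C_t = I$ and a constant $\eta_t = 0.1$; each step draws a random minibatch, forms $\hat{g}_t$ as in Equation (6.43), and plugs it into Equation (6.41).

```python
import random

# Hypothetical data: y = 2x - 0.5 plus small Gaussian noise.
rng = random.Random(0)
N = 200
xs = [rng.uniform(-1.0, 1.0) for _ in range(N)]
ys = [2.0 * x - 0.5 + rng.gauss(0.0, 0.1) for x in xs]


def minibatch_grad(w, b, batch):
    """ĝ_t: average gradient of L_n = 0.5 * (w * x_n + b - y_n)^2 over the batch."""
    gw = gb = 0.0
    for n in batch:
        r = w * xs[n] + b - ys[n]   # residual
        gw += r * xs[n]
        gb += r
    return gw / len(batch), gb / len(batch)


w, b = 0.0, 0.0
lr, B = 0.1, 10                     # constant step size; C_t = I
for t in range(2000):
    batch = rng.sample(range(N), B)  # random minibatch B_t
    gw, gb = minibatch_grad(w, b, batch)
    w, b = w - lr * gw, b - lr * gb
```

Each step touches only 10 of the 200 examples, yet the iterates settle near the least-squares fit.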
Figure 6.2: Changing the mean of a Gaussian by a fixed amount (from solid to dotted curve) can have more impact when the (shared) variance is small (as in a) compared to when the variance is large (as in b). Hence the impact (in terms of prediction accuracy) of a change to $\mu$ depends on where the optimizer is in $(\mu, \sigma)$ space. From Figure 3 of [Hon+10], reproduced from [Val00]. Used with kind permission of Antti Honkela.


6.4 Natural gradient descent 


In this section, we discuss natural gradient descent (NGD) [Ama98], which is a second order method for optimizing the parameters of (conditional) probability distributions $p_\theta(y|x)$. The key idea is to compute parameter updates by measuring distances between the induced distributions, rather than comparing parameter values directly.

For example, consider comparing two Gaussians, $p_\theta = p(y|\mu, \sigma)$ and $p_{\theta'} = p(y|\mu', \sigma')$. The (squared) Euclidean distance between the parameter vectors decomposes as $\|\theta - \theta'\|^2 = (\mu - \mu')^2 + (\sigma - \sigma')^2$. However, the predictive distribution has the form $\exp(-\frac{1}{2\sigma^2}(y - \mu)^2)$, so changes in $\mu$ need to be measured relative to $\sigma$. This is illustrated in Figure 6.2(a-b), which shows two univariate Gaussian distributions (dotted and solid lines) whose means differ by $\delta$. In Figure 6.2(a), they share the same small variance $\sigma^2$, whereas in Figure 6.2(b), they share the same large variance. It is clear that the value of $\delta$ matters much more (in terms of the effect on the distribution) when the variance is small. Thus we see that the two parameters interact with each other, which the Euclidean distance cannot capture. This problem gets much worse when we consider more complex models, such as deep neural networks. By modeling such correlations, NGD can converge much faster than other gradient methods.
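A quick numerical version of Figure 6.2's point, using the standard closed-form KL divergence between univariate Gaussians (the two values of $\sigma$ are our own choices). The same mean shift $\delta$, and hence the same Euclidean parameter distance, gives very different KL divergences.

```python
import math


def kl_gauss(mu1, sigma1, mu2, sigma2):
    """KL( N(mu1, sigma1^2) || N(mu2, sigma2^2) ) for univariate Gaussians."""
    return (math.log(sigma2 / sigma1)
            + (sigma1 ** 2 + (mu1 - mu2) ** 2) / (2 * sigma2 ** 2) - 0.5)


delta = 1.0
# Same Euclidean parameter distance |θ - θ'| = delta in both cases...
kl_small_var = kl_gauss(0.0, 0.5, delta, 0.5)   # small σ, as in Figure 6.2(a)
kl_large_var = kl_gauss(0.0, 2.0, delta, 2.0)   # large σ, as in Figure 6.2(b)
# ...but with a shared σ the KL is delta^2 / (2 σ^2): 2.0 versus 0.125.
```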


6.4.1 Defining the natural gradient 


The key to NGD is to measure the notion of distance between two probability distributions in terms of the KL divergence. As we show in Section 2.4.4, this can be approximated in terms of the Fisher information matrix (FIM). In particular, for any given input $x$, we have

$$D_{\mathrm{KL}}\big(p_\theta(y|x) \,\|\, p_{\theta+\delta}(y|x)\big) \approx \frac{1}{2}\delta^\top F_x \delta, \tag{6.44}$$

where $F_x$ is the FIM

$$F_x(\theta) = -\mathbb{E}_{p_\theta(y|x)}\big[\nabla^2 \log p_\theta(y|x)\big] = \mathbb{E}_{p_\theta(y|x)}\big[(\nabla \log p_\theta(y|x))(\nabla \log p_\theta(y|x))^\top\big]. \tag{6.45}$$
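As a sanity check of Equation (6.44), the sketch below compares the exact KL between two univariate Gaussians with the quadratic form $\frac{1}{2}\delta^\top F\delta$. It assumes the $(\mu, \sigma)$ parameterization, for which the FIM is the standard result $\mathrm{diag}(1/\sigma^2, 2/\sigma^2)$; the parameter values are arbitrary.

```python
import numpy as np


def kl_gauss(mu1, s1, mu2, s2):
    """KL( N(mu1, s1^2) || N(mu2, s2^2) )."""
    return np.log(s2 / s1) + (s1 ** 2 + (mu1 - mu2) ** 2) / (2 * s2 ** 2) - 0.5


def fisher_mu_sigma(sigma):
    """FIM of N(y | μ, σ^2) with respect to (μ, σ)."""
    return np.diag([1.0 / sigma ** 2, 2.0 / sigma ** 2])


theta = np.array([0.0, 0.7])      # (μ, σ), arbitrary
delta = np.array([1e-3, -2e-3])   # a small parameter change
exact = kl_gauss(theta[0], theta[1], theta[0] + delta[0], theta[1] + delta[1])
quadratic = 0.5 * delta @ fisher_mu_sigma(theta[1]) @ delta
```

The two agree to within a fraction of a percent; the gap shrinks as $\delta$ does, since the approximation drops third-order terms.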
We can compute the average KL between the current and updated distributions using $\frac{1}{2}\delta^\top F\delta$, where $F$ is the averaged FIM:

$$F(\theta) = \mathbb{E}_{p_{\mathcal{D}}(x)}[F_x(\theta)]. \tag{6.46}$$

NGD uses the inverse FIM as a preconditioning matrix, i.e., we perform updates of the following form:

$$\theta_{t+1} = \theta_t - \eta_t F(\theta_t)^{-1} g_t. \tag{6.47}$$

The term

$$F^{-1} g_t = F^{-1}\nabla\mathcal{L}(\theta_t) \triangleq \tilde{\nabla}\mathcal{L}(\theta_t) \tag{6.48}$$

is called the natural gradient.


6.4.2 Interpretations of NGD 
6.4.2.1 NGD as a trust region method 


In Supplementary Section 6.1.3.1 we show that we can interpret standard gradient descent as optimizing a linear approximation to the objective subject to a penalty on the $\ell_2$ norm of the change in parameters, i.e., if $\theta_{t+1} = \theta_t + \delta$, then we optimize

$$M_t(\delta) = \mathcal{L}(\theta_t) + g_t^\top\delta + \frac{1}{2\eta}\|\delta\|_2^2. \tag{6.49}$$

(Setting $\nabla_\delta M_t(\delta) = g_t + \delta/\eta$ to zero recovers the gradient descent step $\delta = -\eta g_t$.) Now let us replace the squared distance with the squared FIM-based distance, $\|\delta\|_F^2 = \delta^\top F\delta$. This is equivalent to squared Euclidean distance in the whitened coordinate system $\phi = F^{1/2}\theta$, since

$$\|\phi_{t+1} - \phi_t\|_2^2 = \|F^{1/2}(\theta_t + \delta) - F^{1/2}\theta_t\|_2^2 = \|F^{1/2}\delta\|_2^2 = \|\delta\|_F^2. \tag{6.50}$$

The new objective becomes

$$M_t(\delta) = \mathcal{L}(\theta_t) + g_t^\top\delta + \frac{1}{2\eta}\delta^\top F\delta. \tag{6.51}$$

Solving $\nabla_\delta M_t(\delta) = 0$ gives the update

$$\delta_t = -\eta F^{-1} g_t. \tag{6.52}$$

This is the same as the natural gradient direction. Thus we can view NGD as a trust region method, where we use a first-order approximation to the objective, and use FIM-distance in the constraint.

In the above derivation, we assumed $F$ was a constant matrix. In most problems, it will change at each point in space, since we are optimizing in a curved space known as a Riemannian manifold. For certain models, we can compute the FIM efficiently, allowing us to capture curvature information, even though we use a first-order approximation to the objective.
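A small numerical check of this derivation, assuming a penalty weight of $1/(2\eta)$ (the weight for which the stationary point of the quadratic model is exactly $-\eta F^{-1}g_t$). The matrix `F` and gradient `g` are made up, and the sketch whitens with a Cholesky factor $F = LL^\top$ rather than $F^{1/2}$; any square root gives the same step, because only $LL^\top = F$ enters.

```python
import numpy as np

# Made-up symmetric positive definite F and gradient g.
F = np.array([[4.0, 1.0],
              [1.0, 0.5]])
g = np.array([1.0, -2.0])
eta = 0.1

# Stationary point of M(δ) = L + g^T δ + δ^T F δ / (2 η): δ = -η F^{-1} g.
delta_ngd = -eta * np.linalg.solve(F, g)

# Whitened coordinates φ = chol^T θ with F = chol chol^T, so ||Δφ||^2 = δ^T F δ.
chol = np.linalg.cholesky(F)
grad_phi = np.linalg.solve(chol, g)                   # θ = chol^{-T} φ, hence ∇_φ = chol^{-1} g
delta_phi = -eta * grad_phi                           # ordinary gradient step in φ space
delta_from_phi = np.linalg.solve(chol.T, delta_phi)   # back to θ: δ = chol^{-T} Δφ
```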


6.4.2.2 NGD as a Gauss-Newton method 


If $p(y|x, \theta)$ is an exponential family distribution with natural parameters computed by $\eta = f(x, \theta)$, then one can show [Hes00; PB14] that NGD is identical to the generalized Gauss-Newton (GGN) method (Section 17.3.2). Furthermore, in the online setting, these methods are equivalent to performing sequential Bayesian inference using the extended Kalman filter, as shown in [Oll18].
6.4.3 Benefits of NGD

The use of the FIM as a preconditioning matrix, rather than the Hessian, has two advantages. First, $F$ is always positive semidefinite (it is the covariance of the score), whereas $H$ can have negative eigenvalues at saddle points, which are prevalent in high dimensional spaces. Second, it is easy to approximate $F$ online from minibatches, since it is an expectation (wrt the empirical distribution) of outer products of gradient vectors. This is in contrast to Hessian-based methods [Byr+16; Liu+18a], which are much more sensitive to noise introduced by the minibatch approximation.

In addition, the connection with trust region optimization makes it clear that NGD updates parameters in the ways that matter most for prediction, which allows the method to take larger steps in uninformative regions of parameter space, which can help avoid getting stuck on plateaus. This can also help with issues that arise when the parameters are highly correlated.


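For example, consider a 2d Gaussian with an unusual, highly coupled parameterization, proposed in [SD12]:

$$p(x; \theta) = \frac{1}{2\pi}\exp\left[-\frac{1}{2}\left(x_1 - \left[3\theta_1 + \frac{1}{3}\theta_2\right]\right)^2 - \frac{1}{2}\left(x_2 - \left[\frac{1}{3}\theta_1\right]\right)^2\right]. \tag{6.53}$$

The objective is the cross entropy loss:

$$\mathcal{L}(\theta) = -\mathbb{E}_{p^*(x)}[\log p(x; \theta)]. \tag{6.54}$$

The gradient of this objective is given by

$$\nabla_\theta\mathcal{L}(\theta) = -\begin{pmatrix} \mathbb{E}_{p^*(x)}\left[3\left(x_1 - \left[3\theta_1 + \frac{1}{3}\theta_2\right]\right) + \frac{1}{3}\left(x_2 - \left[\frac{1}{3}\theta_1\right]\right)\right] \\ \mathbb{E}_{p^*(x)}\left[\frac{1}{3}\left(x_1 - \left[3\theta_1 + \frac{1}{3}\theta_2\right]\right)\right] \end{pmatrix}. \tag{6.55}$$

Suppose that $p^*(x) = p(x; [0, 0])$. Then the Fisher matrix is a constant matrix, given by

$$F = \begin{pmatrix} 3^2 + \frac{1}{3^2} & 1 \\ 1 & \frac{1}{3^2} \end{pmatrix}. \tag{6.56}$$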
32 


Figure 6.3 compares steepest descent in $\theta$ space with the natural gradient method, which is equivalent to steepest descent in $\phi$ space. Both methods start at $\theta = (1, -1)$. The global optimum is at $\theta = (0, 0)$. We see that the NG method (blue dots) converges much faster to this optimum and takes the shortest path, whereas steepest descent takes a very circuitous route. We also see that the gradient field in the whitened parameter space is more "spherical", which makes descent much simpler and faster.
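The comparison in Figure 6.3 can be reproduced in spirit with a few lines. This is not the book's nat_grad_demo.ipynb but our own sketch. It uses the fact that, with $p^*(x) = p(x; [0, 0])$ and unit covariance, the loss reduces to a constant plus $\frac{1}{2}\|J\theta\|^2$, where $J$ is the (symmetric) matrix of coefficients in Equation (6.53); hence $\nabla\mathcal{L} = F\theta$ with $F = J^\top J$. The step size 0.1 and the 100 iterations are arbitrary choices.

```python
import numpy as np

# The mean of the coupled parameterization in Eq. (6.53) is J @ θ.
J = np.array([[3.0, 1.0 / 3.0],
              [1.0 / 3.0, 0.0]])
F = J.T @ J   # unit covariance, so the Fisher matrix is J^T J


def grad(theta):
    # With p* = p(x; [0, 0]): L(θ) = const + 0.5 * ||J θ||^2, so ∇L = F θ.
    return F @ theta


def run(precondition, eta=0.1, steps=100):
    theta = np.array([1.0, -1.0])
    for _ in range(steps):
        g = grad(theta)
        step = np.linalg.solve(F, g) if precondition else g
        theta = theta - eta * step
    return theta


theta_sd = run(precondition=False)   # steepest descent in θ space
theta_ng = run(precondition=True)    # natural gradient descent
```

Because the eigenvalues of $F$ differ by nearly four orders of magnitude, any step size that keeps steepest descent stable along the steep direction makes almost no progress along the flat one, whereas each natural gradient step simply contracts $\theta$ by the factor $1 - \eta$.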


Finally, note that since NGD is invariant to how we parameterize the distribution, we will get the same results even for a standard parameterization of the Gaussian. This is particularly useful if our probability model is more complex, such as a DNN (see e.g., [SSE18]).


6.4.4 Approximating the natural gradient

The main drawback of NGD is the computational cost of computing (the inverse of) the Fisher Information Matrix (FIM). To speed this up, several methods make assumptions about the form of $F$, so it can be inverted efficiently. For example, [LeC+98] uses a diagonal approximation for neural net training; [RMB08] uses a low-rank plus block diagonal approximation; and [GS15] assumes the covariance of the gradients can be modeled by a directed Gaussian graphical model with low treewidth (i.e., the Cholesky factorization of $F$ is sparse).

[Figure 6.3 panels: "Parameter trajectories" (legend: Steepest descent, NG descent); "KL divergence vs. update step"]
Figure 6.3: Illustration of the benefits of natural gradient vs steepest descent on a 2d problem. (a) Trajectories of the two methods in parameter space (red = steepest descent, blue = NG). They both start in the bottom right, at $(1, -1)$. (b) Objective vs number of iterations. (c) Gradient field in the $\theta$ parameter space. (d) Gradient field in the whitened $\phi = F^{1/2}\theta$ parameter space used by NG. Generated by nat_grad_demo.ipynb.

[MG15] propose the KFAC method, which stands for “Kronecker-Factored approximate curvature”; 
this approximates the FIM of a DNN as a block diagonal matrix, where each block is a Kronecker 
product of two small matrices. This method has shown good results on supervised learning of 
neural nets [GM16; BGM17; Geo+18; Osa+19b] as well as reinforcement learning of neural policy 
networks [Wu+17]. The KFAC approximation can be justified using the mean field analysis of 
[AKO18]. In addition, [ZMG19] prove that KFAC will converge to the global optimum of a DNN if 
it is overparameterized (i.e., acts like an interpolator). 
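The reason Kronecker-factored blocks are cheap to invert is the identity $(A \otimes B)^{-1} = A^{-1} \otimes B^{-1}$. The sketch below shows only this linear algebra, on made-up factors `A` and `B` and a made-up gradient `G` for one weight matrix; how KFAC actually estimates the factors is not shown. With NumPy's row-major `ravel`, $(A \otimes B)\,\mathrm{vec}(X) = \mathrm{vec}(A X B^\top)$, so a solve against the $12 \times 12$ block reduces to one $3 \times 3$ solve and one $4 \times 4$ solve.

```python
import numpy as np

rng = np.random.default_rng(0)


def random_spd(k):
    M = rng.normal(size=(k, k))
    return M @ M.T + k * np.eye(k)   # well-conditioned symmetric positive definite matrix


A, B = random_spd(3), random_spd(4)   # hypothetical Kronecker factors
G = rng.normal(size=(3, 4))           # hypothetical gradient for a 3 x 4 weight matrix

# Dense route: build the 12 x 12 block and solve against the flattened gradient.
dense = np.linalg.solve(np.kron(A, B), G.ravel())

# Kronecker route: X = A^{-1} G B^{-T}, using only the small factors.
Y = np.linalg.solve(A, G)                  # A^{-1} G
kron_route = np.linalg.solve(B, Y.T).T     # (B^{-1} Y^T)^T = Y B^{-T}
```

For a layer with $p$ inputs and $q$ outputs, the dense route costs on the order of $(pq)^3$ operations, versus $p^3 + q^3$ for the factored one.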

A simpler approach is to approximate the FIM by replacing the model's distribution with the empirical distribution. In particular, define $p_{\mathcal{D}}(x, y) = \frac{1}{N}\sum_{n=1}^{N}\delta_{x_n}(x)\delta_{y_n}(y)$, $p_{\mathcal{D}}(x) = \frac{1}{N}\sum_{n=1}^{N}\delta_{x_n}(x)$, and $p_\theta(x, y) = p_{\mathcal{D}}(x)p(y|x, \theta)$. Then we can compute the empirical Fisher [Mar16] as follows:

\begin{align}
F &= \mathbb{E}_{p_\theta(x,y)}\left[\nabla\log p(y|x, \theta)\nabla\log p(y|x, \theta)^\top\right] \tag{6.57} \\
&\approx \mathbb{E}_{p_{\mathcal{D}}(x,y)}\left[\nabla\log p(y|x, \theta)\nabla\log p(y|x, \theta)^\top\right] \tag{6.58} \\
&= \frac{1}{|\mathcal{D}|}\sum_{(x,y)\in\mathcal{D}}\nabla\log p(y|x, \theta)\nabla\log p(y|x, \theta)^\top \tag{6.59}
\end{align}

This approximation is widely used, since it is simple to compute. In particular, we can compute a diagonal approximation using the squared gradient vector. (This is similar to ADAGRAD, but only uses the current gradient instead of a running statistic of past squared gradients; the latter is a better approach when performing stochastic optimization.)

Unfortunately, the empirical Fisher does not work as well as the true Fisher [KBH19; Tho+19]. 
To see why, note that when we reach a flat part of parameter space where the gradient vector goes 
to zero, the empirical Fisher will become singular, and hence the algorithm will get stuck on this 
plateau. However, the true Fisher takes expectations over the outputs, i.e., it marginalizes out y. 
This will allow it to detect small changes in the output if we change the parameters. This is why the 
natural gradient method can “escape” plateaus better than standard gradient methods. 
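A tiny illustration of this failure mode, using a hypothetical one-parameter model $y \sim \mathcal{N}(\theta x, 1)$ whose score is $(y - \theta x)x$. When every observed residual is exactly zero, each per-example gradient vanishes, so the empirical Fisher of Equation (6.59) is zero. The true Fisher, which averages over $y$ drawn from the model, is the mean of $x_n^2$.

```python
# Hypothetical model: y ~ N(θ x, 1), so the score is (y - θ x) x.
xs = [0.5, -1.0, 2.0]
theta = 1.5
ys = [theta * x for x in xs]   # the data are fit exactly: every residual is zero


def score(y, x, th):
    return (y - th * x) * x


# Empirical Fisher, Eq. (6.59): average squared score at the observed labels.
emp_fisher = sum(score(y, x, theta) ** 2 for x, y in zip(xs, ys)) / len(xs)

# True Fisher: for each x, E_{y ~ N(θx, 1)}[score^2] = x^2 * Var[y] = x^2.
true_fisher = sum(x ** 2 for x in xs) / len(xs)
```

Here the empirical Fisher is exactly 0 and cannot be inverted, while the true Fisher equals 1.75.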

An alternative strategy is to use exact computation of $F$, but solve for $F^{-1}g$ approximately using truncated conjugate gradient (CG) methods, where each CG step uses efficient methods for Hessian-vector products [Pea94]. This is called Hessian-free optimization [Mar10a]. However, this approach can be slow, since it may take many CG iterations to compute a single parameter update.
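A minimal sketch of the inner solver: conjugate gradient that accesses $F$ only through a function computing $Fv$. In Hessian-free optimization that function would be built from Hessian- or Fisher-vector products; here an explicit random symmetric positive definite matrix stands in for it, purely for illustration. The `iters` argument plays the role of the truncation.

```python
import numpy as np


def conjugate_gradient(matvec, b, iters=50, tol=1e-10):
    """Approximately solve F x = b for symmetric positive definite F,
    touching F only through matrix-vector products matvec(v) = F v."""
    x = np.zeros_like(b)
    r = b - matvec(x)      # residual
    p = r.copy()           # search direction
    rs = r @ r
    for _ in range(iters):
        Fp = matvec(p)
        alpha = rs / (p @ Fp)
        x = x + alpha * p
        r = r - alpha * Fp
        rs_new = r @ r
        if np.sqrt(rs_new) < tol:
            break
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


# Stand-in "Fisher" that we pretend is only available through matvecs.
rng = np.random.default_rng(0)
M = rng.normal(size=(6, 6))
F = M @ M.T + 0.1 * np.eye(6)
g = rng.normal(size=6)
nat_grad = conjugate_gradient(lambda v: F @ v, g)
```

Each iteration costs one matvec, so the total cost of an update is the number of CG iterations times the cost of a Fisher-vector product, which is the slowness the text refers to.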


6.4.5 Natural gradients for the exponential family 


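In this section, we assume $\mathcal{L}$ is an expected loss of the following form:

$$\mathcal{L}(\mu) = \mathbb{E}_{q_\mu(\theta)}[\ell(\theta)], \tag{6.60}$$

where $q_\mu(\theta)$ is an exponential family distribution with moment parameters $\mu$. This is the basis of variational optimization (discussed in Section 6.7.3) and natural evolutionary strategies (discussed in Section 6.9.6).

It turns out the gradient wrt the moment parameters is the same as the natural gradient wrt the natural parameters $\lambda$. This follows from the chain rule:

$$\frac{d}{d\lambda}\mathcal{L}(\lambda) = \frac{d\mu}{d\lambda}\frac{d}{d\mu}\mathcal{L}(\mu) = F(\lambda)\nabla_\mu\mathcal{L}(\mu), \tag{6.61}$$

where $\mathcal{L}(\mu) = \mathcal{L}(\lambda(\mu))$, and where we used Equation (2.215) to write

$$F(\lambda) = \nabla_\lambda\mu(\lambda) = \nabla_\lambda^2 A(\lambda). \tag{6.62}$$

Hence

$$\tilde{\nabla}_\lambda\mathcal{L}(\lambda) = F(\lambda)^{-1}\nabla_\lambda\mathcal{L}(\lambda) = \nabla_\mu\mathcal{L}(\mu). \tag{6.63}$$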
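Equation (6.63) can be checked numerically on the simplest exponential family. The sketch assumes a Bernoulli $q$ with natural parameter $\lambda$ (the logit), moment parameter $\mu = \mathrm{sigmoid}(\lambda)$, log partition $A(\lambda) = \log(1 + e^\lambda)$ and hence $F(\lambda) = \mu(1 - \mu)$; the loss values $\ell(0), \ell(1)$ are arbitrary, and derivatives are taken by central finite differences.

```python
import math

ell = {0: 0.3, 1: 2.0}   # arbitrary loss values ℓ(0), ℓ(1)


def sigmoid(lam):
    return 1.0 / (1.0 + math.exp(-lam))


def loss_mu(mu):         # L(μ) = E_q[ℓ(θ)] = μ ℓ(1) + (1 - μ) ℓ(0)
    return mu * ell[1] + (1.0 - mu) * ell[0]


def loss_lam(lam):       # the same loss viewed as a function of λ
    return loss_mu(sigmoid(lam))


lam, h = 0.4, 1e-6
mu = sigmoid(lam)
grad_lam = (loss_lam(lam + h) - loss_lam(lam - h)) / (2 * h)   # ∇_λ L
grad_mu = (loss_mu(mu + h) - loss_mu(mu - h)) / (2 * h)        # ∇_μ L
fisher = mu * (1.0 - mu)                                       # F(λ) = A''(λ)
nat_grad_lam = grad_lam / fisher                               # F(λ)^{-1} ∇_λ L
```

Both `nat_grad_lam` and `grad_mu` come out as $\ell(1) - \ell(0) = 1.7$, as Equation (6.63) predicts.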


It remains to compute the (regular) gradient wrt the moment parameters. The details on how to do this will depend on the form of $q$ and the form of $\mathcal{L}(\lambda)$. We discuss some approaches to this problem below.
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6.4. NATURAL GRADIENT DESCENT 


6.4.5.1 Analytic computation for the Gaussian case 


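In this section, we assume that $q(\theta) = \mathcal{N}(\theta|m, V)$. We now show how to compute the relevant gradients analytically.

Following Section 2.3.2.5, the natural parameters of $q$ are

$$\lambda^{(1)} = V^{-1}m, \quad \lambda^{(2)} = -\frac{1}{2}V^{-1}, \tag{6.64}$$

and the moment parameters are

$$\mu^{(1)} = m, \quad \mu^{(2)} = V + mm^\top. \tag{6.65}$$

For simplicity, we derive the result for the scalar case. Let $m = \mu^{(1)}$ and $v = \mu^{(2)} - (\mu^{(1)})^2$. By using the chain rule, the gradients wrt the moment parameters are

\begin{align}
\frac{\partial\mathcal{L}}{\partial\mu^{(1)}} &= \frac{\partial\mathcal{L}}{\partial m}\frac{\partial m}{\partial\mu^{(1)}} + \frac{\partial\mathcal{L}}{\partial v}\frac{\partial v}{\partial\mu^{(1)}} = \frac{\partial\mathcal{L}}{\partial m} - 2\frac{\partial\mathcal{L}}{\partial v}m \tag{6.66} \\
\frac{\partial\mathcal{L}}{\partial\mu^{(2)}} &= \frac{\partial\mathcal{L}}{\partial m}\frac{\partial m}{\partial\mu^{(2)}} + \frac{\partial\mathcal{L}}{\partial v}\frac{\partial v}{\partial\mu^{(2)}} = \frac{\partial\mathcal{L}}{\partial v} \tag{6.67}
\end{align}

It remains to compute the derivatives wrt $m$ and $v$. If $\theta \sim \mathcal{N}(m, V)$, then from Bonnet's theorem [Bon64] we have

$$\frac{\partial}{\partial m_i}\mathbb{E}[\ell(\theta)] = \mathbb{E}\left[\frac{\partial}{\partial\theta_i}\ell(\theta)\right]. \tag{6.68}$$

And from Price's theorem [Pri58] we have

$$\frac{\partial}{\partial V_{ij}}\mathbb{E}[\ell(\theta)] = c_{ij}\,\mathbb{E}\left[\frac{\partial^2}{\partial\theta_i\partial\theta_j}\ell(\theta)\right], \tag{6.69}$$

where $c_{ij} = \frac{1}{2}$ if $i = j$ and $c_{ij} = 1$ otherwise. (See gradient_expected_value_gaussian.ipynb for a "proof by example" of these claims.)

In the multivariate case, the result is as follows [OA09; KR21a]:

\begin{align}
\nabla_{\mu^{(1)}}\mathbb{E}_{q(\theta)}[\ell(\theta)] &= \nabla_m\mathbb{E}_{q(\theta)}[\ell(\theta)] - 2\nabla_V\mathbb{E}_{q(\theta)}[\ell(\theta)]\,m \tag{6.70} \\
&= \mathbb{E}_{q(\theta)}[\nabla_\theta\ell(\theta)] - \mathbb{E}_{q(\theta)}[\nabla_\theta^2\ell(\theta)]\,m \tag{6.71} \\
\nabla_{\mu^{(2)}}\mathbb{E}_{q(\theta)}[\ell(\theta)] &= \nabla_V\mathbb{E}_{q(\theta)}[\ell(\theta)] \tag{6.72} \\
&= \frac{1}{2}\mathbb{E}_{q(\theta)}[\nabla_\theta^2\ell(\theta)] \tag{6.73}
\end{align}

Thus we see that the natural gradients rely on both the gradient and Hessian of the loss function $\ell(\theta)$. We will see applications of this result in Section 6.7.2.2.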
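Here is our own scalar check of Equations (6.66), (6.68) and (6.69), independent of the notebook mentioned above. It computes Gaussian expectations with NumPy's probabilists' Gauss-Hermite rule (`numpy.polynomial.hermite_e.hermegauss`, whose weight function is $e^{-z^2/2}$), uses a hypothetical smooth loss, and compares finite-difference derivatives of the expected loss with the right-hand sides.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

# Weight exp(-z^2 / 2); the raw weights sum to sqrt(2 pi).
nodes, weights = hermegauss(60)
weights = weights / np.sqrt(2.0 * np.pi)   # now sum(weights * g(nodes)) ≈ E_{z ~ N(0,1)}[g(z)]


def expect(fn, m, v):
    """E_{θ ~ N(m, v)}[fn(θ)] for scalar m and variance v > 0."""
    return np.sum(weights * fn(m + np.sqrt(v) * nodes))


# Hypothetical smooth loss and its first two derivatives.
def ell(t):
    return t ** 4 + np.sin(t)


def d_ell(t):
    return 4 * t ** 3 + np.cos(t)


def d2_ell(t):
    return 12 * t ** 2 - np.sin(t)


m, v, h = 0.3, 0.5, 1e-5

# Bonnet, Eq. (6.68): d/dm E[ℓ] = E[ℓ'].
dm_fd = (expect(ell, m + h, v) - expect(ell, m - h, v)) / (2 * h)
# Price, Eq. (6.69) in the scalar case (c = 1/2): d/dv E[ℓ] = E[ℓ''] / 2.
dv_fd = (expect(ell, m, v + h) - expect(ell, m, v - h)) / (2 * h)

# Eq. (6.66): gradient wrt μ1 when the loss is written in moment parameters (μ1, μ2).
mu1, mu2 = m, v + m ** 2


def loss_moments(a, b):
    return expect(ell, a, b - a ** 2)


dmu1_fd = (loss_moments(mu1 + h, mu2) - loss_moments(mu1 - h, mu2)) / (2 * h)
dmu1_formula = expect(d_ell, m, v) - expect(d2_ell, m, v) * m
```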
6.4.5.2 Stochastic approximation for the general case 


In general, it can be hard to analytically compute the natural gradient. However, we can compute a Monte Carlo approximation. To see this, let us assume $\mathcal{L}$ is an expected loss of the following form:

$$\mathcal{L}(\mu) = \mathbb{E}_{q_\mu(\theta)}[\ell(\theta)]. \tag{6.74}$$

From Equation (6.63) the natural gradient is given by

$$\nabla_\mu\mathcal{L}(\mu) = F(\lambda)^{-1}\nabla_\lambda\mathcal{L}(\lambda). \tag{6.75}$$

For exponential family distributions, both of these terms on the RHS can be written as expectations, and hence can be approximated by Monte Carlo, as noted by [KL17a]. To see this, note that

\begin{align}
F(\lambda) &= \nabla_\lambda\mu(\lambda) = \nabla_\lambda\mathbb{E}_{q_\lambda(\theta)}[T(\theta)] \tag{6.76} \\
\nabla_\lambda\mathcal{L}(\lambda) &= \nabla_\lambda\mathbb{E}_{q_\lambda(\theta)}[\ell(\theta)] \tag{6.77}
\end{align}

If $q$ is reparameterizable, we can apply the reparameterization trick (Section 6.5.4) to push the gradient inside the expectation operator. This lets us sample $\theta$ from $q$, compute the gradients, and average; we can then pass the resulting stochastic gradients to SGD.
6.4.5.3 Natural gradient of the entropy function 
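In this section, we discuss how to compute the natural gradient of the entropy of an exponential family distribution, which is useful when performing variational inference (Chapter 10). The natural gradient is given by

$$\tilde{\nabla}_\lambda H(\lambda) = -\nabla_\mu\mathbb{E}_{q_\mu(\theta)}[\log q(\theta)], \tag{6.78}$$

where, from Equation (2.143), we have

$$\log q(\theta) = \log h(\theta) + T(\theta)^\top\lambda - A(\lambda). \tag{6.79}$$

Since $\mathbb{E}[T(\theta)] = \mu$, we have

$$\nabla_\mu\mathbb{E}_{q_\mu(\theta)}[\log q(\theta)] = \nabla_\mu\mathbb{E}_{q(\theta)}[\log h(\theta)] + \nabla_\mu\mu^\top\lambda(\mu) - \nabla_\mu A(\lambda), \tag{6.80}$$

where $h(\theta)$ is the base measure. Since $\lambda$ is a function of $\mu$, we have

$$\nabla_\mu\mu^\top\lambda = \lambda + (\nabla_\mu\lambda)^\top\mu = \lambda + (F_\lambda^{-1}\nabla_\lambda\lambda)^\top\mu = \lambda + F_\lambda^{-1}\mu, \tag{6.81}$$

and since $\mu = \nabla_\lambda A(\lambda)$ we have

$$\nabla_\mu A(\lambda) = F_\lambda^{-1}\nabla_\lambda A(\lambda) = F_\lambda^{-1}\mu. \tag{6.82}$$

Hence

$$-\nabla_\mu\mathbb{E}_{q_\mu(\theta)}[\log q(\theta)] = -\nabla_\mu\mathbb{E}_{q(\theta)}[\log h(\theta)] - \lambda. \tag{6.83}$$

If we assume that $h(\theta) = \text{const}$, as is often the case, we get

$$\tilde{\nabla}_\lambda H(\lambda) = -\lambda. \tag{6.84}$$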
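For a univariate Gaussian, $h(\theta) = 1/\sqrt{2\pi}$ is constant, so Equation (6.84) predicts that the gradient of the entropy wrt the moment parameters (which, by Equation (6.63), is the natural gradient wrt $\lambda$) equals $-\lambda$. This finite-difference sketch uses the natural and moment parameters of Section 6.4.5.1 specialized to one dimension, with arbitrary values of $m$ and $v$.

```python
import math

# Univariate Gaussian: λ1 = m / v, λ2 = -1 / (2 v); μ1 = m, μ2 = v + m^2.


def entropy_from_moments(mu1, mu2):
    v = mu2 - mu1 ** 2
    return 0.5 * math.log(2 * math.pi * math.e * v)


m, v, h = 0.7, 1.3, 1e-6
mu1, mu2 = m, v + m ** 2
lam = (m / v, -1.0 / (2 * v))

# By Eq. (6.63), the natural gradient wrt λ is the ordinary gradient wrt μ.
nat_grad = (
    (entropy_from_moments(mu1 + h, mu2) - entropy_from_moments(mu1 - h, mu2)) / (2 * h),
    (entropy_from_moments(mu1, mu2 + h) - entropy_from_moments(mu1, mu2 - h)) / (2 * h),
)
```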


6.5 Gradients of stochastic functions

In this section, we discuss how to compute the gradient of stochastic functions of the form

$$\mathcal{L}(\psi) = \mathbb{E}_{q_\psi(z)}[\ell(\psi, z)]. \tag{6.85}$$
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6.5. GRADIENTS OF STOCHASTIC FUNCTIONS 


6.5.1 Minibatch approximation to finite-sum objectives 


In the simplest case, $q_\psi(z)$ does not depend on $\psi$. In this case, we can push gradients inside the expectation operator, $\nabla \mathcal{L}(\psi) = \mathbb{E}[\nabla \ell(\psi, z)]$, and then use Monte Carlo sampling for $z$ to approximate the gradient.

For example, consider the empirical risk minimization (ERM) problem of minimizing 


$$\mathcal{L}(\psi) = \frac{1}{N_{\mathcal{D}}}\sum_{n=1}^{N_{\mathcal{D}}} \ell(\psi, z_n) \tag{6.86}$$


where $z_n = (x_n, y_n)$ and


$$\ell(\psi, z_n) = \ell(h(x_n; \psi), y_n) \tag{6.87}$$


is the per-example loss, where h is a prediction function. This kind of objective is called a finite 
sum objective. 

Now consider trying to minimize this objective. If, at each iteration, we evaluate the objective (and 
its gradient) using all $N_{\mathcal{D}}$ datapoints, the method is called batch optimization. However, this can
be very slow if the dataset is large. Fortunately, we can reformulate it as a stochastic optimization 
problem, which will be faster to solve. To do this, note that Equation (6.86) can be written as an 
expectation wrt the empirical distribution: 


$$\mathcal{L}(\psi) = \mathbb{E}_{z \sim p_{\mathcal{D}}}\left[\ell(\psi, z)\right] \tag{6.88}$$


Since the distribution is independent of the parameters, we can easily use Monte Carlo sampling to 
approximate the objective and its gradient. In particular, we will sample a minibatch of $B = |\mathcal{B}|$
datapoints from the full set $\mathcal{D}$ at each iteration. More precisely, we have


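$$\mathcal{L}(\psi) \approx \frac{1}{B}\sum_{n \in \mathcal{B}} \ell(h(x_n; \psi), y_n) \tag{6.89}$$
$$\nabla \mathcal{L}(\psi) \approx \frac{1}{B}\sum_{n \in \mathcal{B}} \nabla_\psi \ell(h(x_n; \psi), y_n) \tag{6.90}$$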


These noisy gradients can then be passed to SGD, which is robust to noisy gradients (see Section 6.3). 
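To make Equations (6.86) to (6.90) concrete, here is a minimal sketch for linear regression with squared loss; the model, the synthetic data, the learning rate, and all function names are our own assumptions, and the gradients are written out by hand in NumPy rather than computed by autodiff. It first checks that averaged minibatch gradients match the full-batch gradient, then runs plain SGD on minibatch gradients.

```python
import numpy as np

def per_example_grads(w, X, y):
    # squared loss l(h(x; w), y) = 0.5 * (x^T w - y)^2 has gradient (x^T w - y) x wrt w
    resid = X @ w - y
    return resid[:, None] * X

def full_gradient(w, X, y):
    # gradient of the finite-sum objective (6.86): average over all N_D examples
    return per_example_grads(w, X, y).mean(axis=0)

def minibatch_gradient(w, X, y, batch_size, rng):
    # Monte Carlo estimate (6.90): average the gradients of a random minibatch
    idx = rng.choice(len(y), size=batch_size, replace=False)
    return per_example_grads(w, X[idx], y[idx]).mean(axis=0)

def sgd(X, y, lr=0.05, steps=500, batch_size=32, seed=1):
    rng = np.random.default_rng(seed)
    w = np.zeros(X.shape[1])
    for _ in range(steps):
        w = w - lr * minibatch_gradient(w, X, y, batch_size, rng)
    return w

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.normal(size=1000)
w = np.zeros(3)
g_full = full_gradient(w, X, y)
# averaging many minibatch gradients recovers the full-batch gradient (unbiasedness)
g_avg = np.mean([minibatch_gradient(w, X, y, 32, rng) for _ in range(2000)], axis=0)
w_hat = sgd(X, y)
```

Each SGD step touches only 32 of the 1000 examples, which is the source of the speedup over batch optimization.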


6.5.2 Optimizing parameters of a distribution 


Now suppose the stochasticity depends on the parameters we are optimizing. For example, z could 
be an action sampled from a stochastic policy $q_\psi$, as in RL (Section 35.3.2). In this case, the gradient
is given by 


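$$\nabla_\psi \mathbb{E}_{q_\psi(z)}\left[\ell(\psi, z)\right] = \nabla_\psi \int \ell(\psi, z)\, q_\psi(z)\, dz \tag{6.91}$$
$$= \int \left[\nabla_\psi \ell(\psi, z)\right] q_\psi(z)\, dz + \int \ell(\psi, z)\left[\nabla_\psi q_\psi(z)\right] dz \tag{6.92}$$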
The first term can be approximated by Monte Carlo sampling:


$$\int \left[\nabla_\psi \ell(\psi, z)\right] q_\psi(z)\, dz \approx \frac{1}{S}\sum_{s=1}^{S} \nabla_\psi \ell(\psi, z_s) \tag{6.93}$$


where $z_s \sim q_\psi$. Note that if $\ell()$ is independent of $\psi$, this term vanishes.
Now consider the second term, that takes the gradients of the distribution itself: 


$$I \triangleq \int \ell(\psi, z)\left[\nabla_\psi q_\psi(z)\right] dz \tag{6.94}$$


We can no longer use vanilla Monte Carlo sampling to approximate this integral. However, there are 
various other ways to approximate this (see [Moh+19] for an extensive review). We briefly describe 
the two main methods in Section 6.5.3 and Section 6.5.4. 


6.5.3 Score function estimator (likelihood ratio trick) 


The simplest way to approximate Equation (6.94) is to exploit the log derivative trick, which is 
the following identity: 


$$\nabla_\psi q_\psi(z) = q_\psi(z)\nabla_\psi \log q_\psi(z) \tag{6.95}$$
With this, we can rewrite Equation (6.94) as follows: 


$$I = \int \ell(\psi, z)\left[q_\psi(z)\nabla_\psi \log q_\psi(z)\right] dz = \mathbb{E}_{q_\psi(z)}\left[\ell(\psi, z)\nabla_\psi \log q_\psi(z)\right] \tag{6.96}$$


This is called the score function estimator or SFE [Fu15]. (The term “score function” refers to the gradient of a log probability distribution, as explained in Section 2.4.1.) It is also called the likelihood ratio gradient estimator, or the REINFORCE estimator (the reason for this latter name is explained in Section 35.3.2). We can now easily approximate this with Monte Carlo:


$$I \approx \frac{1}{S}\sum_{s=1}^{S} \ell(\psi, z_s)\nabla_\psi \log q_\psi(z_s) \tag{6.97}$$


We only require that the sampling distribution is differentiable, not the objective $\ell(\psi, z)$ itself. This allows the method to be used for blackbox stochastic optimization problems, such as variational optimization (Section 6.7.3), black-box variational inference (Section 10.3.2), reinforcement learning (Section 35.3.2), etc.
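Here is a minimal sketch of the score function estimator in Equation (6.97), under assumptions of our own: $q_\psi(z) = \mathcal{N}(z|\mu, \sigma^2)$ with $\psi = \mu$, and $\ell(z) = z^2$, for which $\mathbb{E}[\ell(z)] = \mu^2 + \sigma^2$ and the exact gradient is $2\mu$. Note that $\ell$ is only evaluated, never differentiated, which is what makes the method applicable to blackbox objectives.

```python
import numpy as np

def sfe_grad_mu(loss_fn, mu, sigma, num_samples, rng):
    # Score function estimator (6.97) for the mean of q_psi(z) = N(z | mu, sigma^2)
    z = rng.normal(mu, sigma, size=num_samples)
    score = (z - mu) / sigma**2          # d/dmu log N(z | mu, sigma^2)
    return np.mean(loss_fn(z) * score)

def loss(z):
    # treated as a black box; E[z^2] = mu^2 + sigma^2, so the exact gradient wrt mu is 2 mu
    return z**2

rng = np.random.default_rng(0)
g = sfe_grad_mu(loss, mu=1.5, sigma=1.0, num_samples=200_000, rng=rng)
print(g)   # close to the exact value 2 * 1.5 = 3.0
```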


6.5.3.1 Control variates 


The score function estimate can have high variance. One way to reduce this is to use control variates, in which we replace $\ell(\psi, z)$ with


$$\tilde{\ell}(\psi, z) = \ell(\psi, z) - c\left(b(\psi, z) - \mathbb{E}\left[b(\psi, z)\right]\right) \tag{6.98}$$


where $b(\psi, z)$ is a baseline function that is correlated with $\ell(\psi, z)$, and $c > 0$ is a coefficient. Since $\mathbb{E}[\tilde{\ell}(\psi, z)] = \mathbb{E}[\ell(\psi, z)]$, we can use $\tilde{\ell}$ to compute unbiased gradient estimates of $\ell$. The advantage is that this new estimate can result in lower variance, as we show in Section 11.6.3.
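A minimal sketch of one common instance of this idea in the SFE setting, under our own assumptions ($q_\psi = \mathcal{N}(\mu, \sigma^2)$, $\ell(z) = z^2$): the correlated function is the score $\nabla_\mu \log q_\psi(z)$ itself, whose expectation is zero, applied to the SFE integrand $\ell(z)\nabla_\mu \log q_\psi(z)$. This amounts to subtracting a constant baseline $c$ from $\ell$. Here $c$ is estimated from an independent pilot sample, so the estimate stays unbiased; this choice of $c$ is a heuristic, not the variance-optimal coefficient.

```python
import numpy as np

def sfe_terms(z, mu, sigma, baseline=0.0):
    # per-sample SFE terms (l(z) - c) * d/dmu log N(z | mu, sigma^2) with l(z) = z^2;
    # subtracting a constant c is a control variate, since E[d/dmu log q(z)] = 0
    score = (z - mu) / sigma**2
    return (z**2 - baseline) * score

rng = np.random.default_rng(0)
mu, sigma = 1.5, 1.0
pilot = rng.normal(mu, sigma, size=1_000)
c = np.mean(pilot**2)        # baseline estimated from independent pilot samples
z = rng.normal(mu, sigma, size=100_000)
plain = sfe_terms(z, mu, sigma)
with_cv = sfe_terms(z, mu, sigma, baseline=c)
print(plain.mean(), with_cv.mean())   # both estimate 2 * mu = 3
print(plain.var(), with_cv.var())     # per-sample variance is lower with the baseline
```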


6.5.3.2 Rao-Blackwellisation 


Suppose $q_\psi(z)$ is a discrete distribution. In this case, our objective becomes $\mathcal{L}(\psi) = \sum_z \ell(\psi, z)\, q_\psi(z)$.
For simplicity, let us assume $\ell(\psi, z) = \ell(z)$.

We can now easily compute gradients using $\nabla_\psi \mathcal{L}(\psi) = \sum_z \ell(z)\nabla_\psi q_\psi(z)$. Of course, if $z$ can take on exponentially many values (e.g., we are optimizing over the space of strings), this expression is intractable. However, suppose we can partition this sum into two sets, a small set $S_1$ of high probability values and a large set $S_2$ of all other values. Then we can enumerate over $S_1$ and use the score function estimator for $S_2$:


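$$\nabla_\psi \mathcal{L}(\psi) = \sum_{z \in S_1} \ell(z)\nabla_\psi q_\psi(z) + q_\psi(S_2)\,\mathbb{E}_{q_\psi(z \mid z \in S_2)}\left[\ell(z)\nabla_\psi \log q_\psi(z)\right] \tag{6.99}$$
where $q_\psi(S_2) = \sum_{z \in S_2} q_\psi(z)$ is the total probability of $S_2$.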


To compute the second expectation, we can use rejection sampling applied to samples from $q_\psi(z)$.
This procedure is a form of Rao-Blackwellisation as shown in [Liu+19b], and reduces the variance
compared to standard SFE (see Section 11.6.2 for details on Rao-Blackwellisation). 
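The following minimal sketch illustrates Equation (6.99) under assumptions of our own: $q_\psi = \text{softmax}(\psi)$ over $K = 10$ values, a fixed random table for $\ell(z)$, and $S_1$ taken to be the $k$ most probable values. For a softmax, $\nabla_\psi \log q_\psi(z) = e_z - q_\psi$, so the exact gradient is available for comparison. The sampled term is weighted by $q_\psi(S_2)$, which turns the conditional expectation back into the sum over $S_2$.

```python
import numpy as np

def softmax(psi):
    e = np.exp(psi - psi.max())
    return e / e.sum()

def exact_grad(psi, loss):
    # grad_psi sum_z l(z) q_psi(z) with q_psi = softmax(psi) equals q * (l - E_q[l])
    q = softmax(psi)
    return q * (loss - q @ loss)

def rao_blackwell_grad(psi, loss, k, num_samples, rng):
    q = softmax(psi)
    K = len(q)
    eye = np.eye(K)
    S1 = np.argsort(q)[-k:]                  # k most probable values: enumerated exactly
    in_S2 = np.ones(K, dtype=bool)
    in_S2[S1] = False
    # exact part: sum_{z in S1} l(z) grad q(z), using grad q(z) = q(z) (e_z - q)
    g = np.zeros(K)
    for z in S1:
        g += loss[z] * q[z] * (eye[z] - q)
    # sampled part: rejection sampling, i.e. draw from q and keep only the draws in S2
    draws = rng.choice(K, size=num_samples, p=q)
    draws = draws[in_S2[draws]]
    if draws.size > 0:
        sfe_terms = loss[draws][:, None] * (eye[draws] - q)   # l(z) grad log q(z)
        g += q[in_S2].sum() * sfe_terms.mean(axis=0)           # weight by q(S2)
    return g

rng = np.random.default_rng(0)
psi = rng.normal(size=10)
loss = rng.uniform(-1.0, 1.0, size=10)
g_rb = rao_blackwell_grad(psi, loss, k=3, num_samples=200_000, rng=rng)
print(g_rb)
print(exact_grad(psi, loss))
```

When $k = K$ the sampled term disappears and the estimate is exact, which is a useful sanity check.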


6.5.4 Reparameterization trick 


The score function estimator can have high variance, even when using a control variate. In this 
section, we derive a lower variance estimator, which can be applied if $\ell(\psi, z)$ is differentiable wrt $z$. We additionally require that we can compute a sample from $q_\psi(z)$ by first sampling $\epsilon$ from some noise distribution $q_0$ which is independent of $\psi$, and then transforming to $z$ using a deterministic and differentiable function $z = r(\psi, \epsilon)$. For example, instead of sampling $z \sim \mathcal{N}(\mu, \sigma^2)$, we can sample $\epsilon \sim \mathcal{N}(0, 1)$ and compute


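$$z = r(\psi, \epsilon) = \mu + \sigma\epsilon \tag{6.100}$$
where $\psi = (\mu, \sigma)$. This allows us to rewrite our stochastic objective as follows:
$$\mathcal{L}(\psi) = \mathbb{E}_{q_\psi(z)}\left[\ell(\psi, z)\right] = \mathbb{E}_{q_0(\epsilon)}\left[\ell(\psi, r(\psi, \epsilon))\right] \tag{6.101}$$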


Since $q_0(\epsilon)$ is independent of $\psi$, we can push the gradient operator inside the expectation, which we
can approximate with Monte Carlo: 


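$$\nabla_\psi \mathcal{L}(\psi) = \mathbb{E}_{q_0(\epsilon)}\left[\nabla_\psi \ell(\psi, r(\psi, \epsilon))\right] \approx \frac{1}{S}\sum_{s=1}^{S}\nabla_\psi \ell(\psi, r(\psi, \epsilon_s)) \tag{6.102}$$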


where $\epsilon_s \sim q_0$. This is called the reparameterization gradient or the pathwise derivative
[Gla03; Fu15; KW14; RMW14a; TLG14; JO18; FMM18], and is widely used in variational inference
(Section 10.3.3), and when fitting VAE models (Section 21.2). 
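A minimal sketch of Equation (6.102), assuming $\ell(z) = z^2$ and $z = \mu + \sigma\epsilon$ (our own choices), so that the exact gradients of $\mathbb{E}[z^2] = \mu^2 + \sigma^2$ are $2\mu$ and $2\sigma$. The chain rule is written out by hand; in practice an autodiff system would compute $\nabla_\psi \ell(\psi, r(\psi, \epsilon))$. For comparison, the same noise is also used to form score function terms for $\mu$, to show the variance gap described above.

```python
import numpy as np

def pathwise_terms(mu, sigma, eps):
    # z = r(psi, eps) = mu + sigma * eps; l(z) = z^2, so dl/dz = 2 z
    z = mu + sigma * eps
    dl_dz = 2.0 * z
    return dl_dz * 1.0, dl_dz * eps      # chain rule: dz/dmu = 1, dz/dsigma = eps

def score_terms(mu, sigma, eps):
    # SFE terms for the same objective: l(z) * d/dmu log N(z | mu, sigma^2)
    z = mu + sigma * eps
    return z**2 * (z - mu) / sigma**2

rng = np.random.default_rng(0)
mu, sigma = 1.5, 0.8
eps = rng.standard_normal(100_000)       # eps ~ q0 = N(0, 1), independent of psi
g_mu, g_sigma = pathwise_terms(mu, sigma, eps)
sfe_mu = score_terms(mu, sigma, eps)
# exact gradients of E[z^2] = mu^2 + sigma^2 are 2 mu = 3.0 and 2 sigma = 1.6
print(g_mu.mean(), g_sigma.mean(), sfe_mu.mean())
print(g_mu.var(), sfe_mu.var())
```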


6.5.4.1 Example 


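As a simple example, suppose we define some arbitrary function, such as $\ell(z) = z^2 - 3z$, and then define its expected value as $\mathcal{L}(\psi) = \mathbb{E}_{\mathcal{N}(z|\mu, v)}[\ell(z)]$, where $\psi = (\mu, v)$ and $v = \sigma^2$. Suppose we want to compute
$$\nabla_\psi \mathcal{L}(\psi) = \left[\frac{\partial}{\partial \mu}\mathbb{E}[\ell(z)], \frac{\partial}{\partial v}\mathbb{E}[\ell(z)]\right] \tag{6.103}$$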
Since the Gaussian distribution is reparameterizable, we can sample $z \sim \mathcal{N}(z|\mu, v)$, and then use
automatic differentiation to compute each of these gradient terms, and then average. 

However, in the special case of Gaussian distributions, we can also compute the gradient vector 
directly. In particular, in Section 6.4.5.1 we present Bonnet’s theorem, which states that 


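$$\frac{\partial}{\partial \mu}\mathbb{E}[\ell(z)] = \mathbb{E}\left[\frac{\partial}{\partial z}\ell(z)\right] \tag{6.104}$$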


Similarly, Price’s theorem states that 


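$$\frac{\partial}{\partial v}\mathbb{E}[\ell(z)] = 0.5\,\mathbb{E}\left[\frac{\partial^2}{\partial z^2}\ell(z)\right] \tag{6.105}$$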


In gradient_expected_value_gaussian.ipynb we show that these two methods are numerically
equivalent, as theory suggests. 
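Here is a minimal NumPy sketch of that comparison for $\ell(z) = z^2 - 3z$; it is not the notebook's code, and the parameter values and names are our own. Bonnet's theorem gives $\mathbb{E}[2z - 3] = 2\mu - 3$ and Price's theorem gives $0.5\,\mathbb{E}[2] = 1$, matching the closed form $\mathbb{E}[\ell(z)] = \mu^2 + v - 3\mu$. The Monte Carlo version uses $z = \mu + \sqrt{v}\,\epsilon$, so $\partial z/\partial v = \epsilon/(2\sqrt{v})$.

```python
import numpy as np

def mc_grads(mu, v, num_samples, rng):
    # reparameterized estimates of d/dmu and d/dv of E_{N(z | mu, v)}[z^2 - 3 z]
    eps = rng.standard_normal(num_samples)
    z = mu + np.sqrt(v) * eps
    dl_dz = 2.0 * z - 3.0
    d_mu = np.mean(dl_dz)                            # dz/dmu = 1
    d_v = np.mean(dl_dz * eps / (2.0 * np.sqrt(v)))  # dz/dv = eps / (2 sqrt(v))
    return d_mu, d_v

def bonnet_price_grads(mu, v):
    # Bonnet (6.104): E[l'(z)] = 2 mu - 3; Price (6.105): 0.5 E[l''(z)] = 0.5 * 2 = 1
    return 2.0 * mu - 3.0, 1.0

rng = np.random.default_rng(0)
approx = mc_grads(mu=0.5, v=2.0, num_samples=200_000, rng=rng)
exact = bonnet_price_grads(mu=0.5, v=2.0)
print(approx, exact)
```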


6.5.4.2 Total derivative 


To compute the gradient term inside the expectation in Equation (6.102) we need to use the total 
derivative, since the function $\ell$ depends on $\psi$ directly and via the noise sample. Recall that, for a function of the form $\ell(\psi_1, \ldots, \psi_{d_\psi}, z_1(\psi), \ldots, z_{d_z}(\psi))$, the total derivative wrt $\psi_i$ is given by the chain rule as follows:


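$$\frac{\partial \ell^{\text{TD}}}{\partial \psi_i} = \frac{\partial \ell}{\partial \psi_i} + \sum_j \frac{\partial \ell}{\partial z_j}\frac{\partial z_j}{\partial \psi_i} \tag{6.106}$$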
and hence 
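$$\nabla_\psi \ell(\psi, z)^{\text{TD}} = \nabla_\psi \ell(\psi, z) + \mathbf{J}^{\mathsf T}\nabla_z \ell(\psi, z) \tag{6.107}$$
where $\mathbf{J} = \frac{\partial z}{\partial \psi}$ is the $d_z \times d_\psi$ Jacobian matrix of the noise transformation:
$$\mathbf{J} = \begin{pmatrix} \frac{\partial z_1}{\partial \psi_1} & \cdots & \frac{\partial z_1}{\partial \psi_{d_\psi}} \\ \vdots & \ddots & \vdots \\ \frac{\partial z_{d_z}}{\partial \psi_1} & \cdots & \frac{\partial z_{d_z}}{\partial \psi_{d_\psi}} \end{pmatrix} \tag{6.108}$$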
Hence we can compute the gradient using 
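$$\nabla_\psi \mathcal{L}(\psi) = \mathbb{E}_{q_0(\epsilon)}\left[\nabla_\psi \ell(\psi, z) + \mathbf{J}(\psi, \epsilon)^{\mathsf T}\nabla_z \ell(\psi, z)\right], \quad z = r(\psi, \epsilon) \tag{6.109}$$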


We leverage this decomposition in Section 10.3.3.1, where we derive a lower variance gradient estimator
in the special case of variational inference.
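A minimal sketch of Equation (6.107), under our own choices: $\ell(\psi, z) = \log \mathcal{N}(z|\mu, \sigma^2)$ with $\psi = (\mu, \sigma)$ and $z = \mu + \sigma\epsilon$, the kind of term that appears in variational objectives. Since $\mathbb{E}[\log q] = -\log\sigma - \frac{1}{2}\log(2\pi e)$, the true gradients are $0$ wrt $\mu$ and $-1/\sigma$ wrt $\sigma$. The sketch shows that the total derivative recovers these values for every sample, whereas the partial derivative wrt $\sigma$ alone (holding $z$ fixed) averages to zero, which would be wrong.

```python
import numpy as np

def logq_derivatives(mu, sigma, eps):
    # l(psi, z) = log N(z | mu, sigma^2) with psi = (mu, sigma) and z = r(psi, eps) = mu + sigma * eps
    z = mu + sigma * eps
    # partial derivatives wrt psi, holding z fixed
    dl_dmu = (z - mu) / sigma**2
    dl_dsigma = -1.0 / sigma + (z - mu) ** 2 / sigma**3
    # derivative wrt z, times the Jacobian entries dz/dmu = 1 and dz/dsigma = eps
    dl_dz = -(z - mu) / sigma**2
    total_mu = dl_dmu + dl_dz * 1.0
    total_sigma = dl_dsigma + dl_dz * eps
    return (dl_dmu, dl_dsigma), (total_mu, total_sigma)

rng = np.random.default_rng(0)
eps = rng.standard_normal(100_000)
(partial_mu, partial_sigma), (total_mu, total_sigma) = logq_derivatives(0.3, 0.5, eps)
print(partial_sigma.mean(), total_sigma.mean())   # the total derivative equals -1 / sigma = -2
```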
6.5.5 The delta method

The delta method [Hoe12] approximates the expectation of a function of a random variable by the expectation of the function’s Taylor expansion. For example, suppose $\theta \sim q$ with mean $\mathbb{E}_{q(\theta)}[\theta] = m$, and we use a first order expansion. Then we have

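$$\mathbb{E}_{q(\theta)}[f(\theta)] \approx \mathbb{E}_{q(\theta)}\left[f(m) + (\theta - m)^{\mathsf T}\nabla_\theta f(\theta)|_{\theta=m}\right] = f(m) \tag{6.110}$$

Now let $f(\theta) = \nabla_\theta \mathcal{L}(\theta)$. Then we have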

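$$\mathbb{E}_{q(\theta)}\left[\nabla_\theta \mathcal{L}(\theta)\right] \approx \nabla_\theta \mathcal{L}(\theta)|_{\theta=m} \tag{6.111}$$
This is called the first-order delta method.

Now let $f(\theta) = \nabla^2_\theta \mathcal{L}(\theta)$. Then we have
$$\mathbb{E}_{q(\theta)}\left[\nabla^2_\theta \mathcal{L}(\theta)\right] \approx \nabla^2_\theta \mathcal{L}(\theta)|_{\theta=m} \tag{6.112}$$
This is called the second-order delta method.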
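As an illustration with assumptions of our own, take the scalar $\mathcal{L}(\theta) = \log(1 + e^\theta)$, whose gradient is the sigmoid $\sigma(\theta)$ and whose second derivative is $\sigma(\theta)(1 - \sigma(\theta))$, and let $q = \mathcal{N}(m, s^2)$ with small $s$. The sketch compares Monte Carlo estimates of $\mathbb{E}[\nabla\mathcal{L}]$ and $\mathbb{E}[\nabla^2\mathcal{L}]$ with the first- and second-order delta approximations; the error of the approximation grows with the spread $s$.

```python
import numpy as np

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def grad_L(t):
    # L(theta) = log(1 + exp(theta)) has gradient sigmoid(theta)
    return sigmoid(t)

def hess_L(t):
    # ... and second derivative sigmoid(theta) * (1 - sigmoid(theta))
    return sigmoid(t) * (1.0 - sigmoid(t))

m, s = 0.7, 0.1
rng = np.random.default_rng(0)
theta = rng.normal(m, s, size=200_000)             # theta ~ q with mean m
mc_grad, mc_hess = grad_L(theta).mean(), hess_L(theta).mean()
delta_grad, delta_hess = grad_L(m), hess_L(m)      # (6.111) and (6.112)
print(mc_grad, delta_grad)
print(mc_hess, delta_hess)
```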

6.5.6 Gumbel softmax trick

When working with discrete variables, we cannot use the reparameterization trick. However, we can often relax the discrete variables to continuous ones in a way which allows the trick to be used, as we explain below.

Consider a one-hot vector $d$ with $K$ bits, so $d_k \in \{0, 1\}$ and $\sum_{k=1}^K d_k = 1$. This can be used to represent a $K$-ary categorical variable $d$. Let $P(d) = \text{Cat}(d|\pi)$, where $\pi_k = P(d_k = 1)$, so $0 \le \pi_k \le 1$. Alternatively we can parameterize the distribution in terms of $(\alpha_1, \ldots, \alpha_K)$, where $\pi_k = \alpha_k / (\sum_{k'=1}^K \alpha_{k'})$. We will denote this by $d \sim \text{Cat}(d|\alpha)$.

We can sample a one-hot vector $d$ from this distribution by computing
$$d = \text{onehot}\left(\arg\max_k \left[\epsilon_k + \log \alpha_k\right]\right) \tag{6.113}$$
30 

where $\epsilon_k \sim \text{Gumbel}(0, 1)$ is sampled from the Gumbel distribution [Gum54]. We can draw such samples by first sampling $u_k \sim \text{Unif}(0, 1)$ and then computing $\epsilon_k = -\log(-\log(u_k))$. This is called the Gumbel-Max trick [MTM14], and gives us a reparameterizable representation for the categorical distribution.

Unfortunately, the derivative of the argmax is 0 everywhere except at the boundary of transitions from one label to another, where the derivative is undefined. However, suppose we replace the argmax with a softmax, and replace the discrete one-hot vector $d$ with a continuous relaxation $x \in \Delta^{K-1}$, where $\Delta^{K-1} = \{x \in \mathbb{R}^K : x_k \in [0, 1], \sum_{k=1}^K x_k = 1\}$ is the $K$-dimensional simplex. Then we can write
$$x_k = \frac{\exp\left((\log \alpha_k + \epsilon_k)/\tau\right)}{\sum_{k'=1}^K \exp\left((\log \alpha_{k'} + \epsilon_{k'})/\tau\right)} \tag{6.114}$$

where $\tau > 0$ is a temperature parameter. This is called the Gumbel-Softmax distribution [JGP17] or the concrete distribution [MMT17]. This smoothly approaches the discrete distribution as $\tau \to 0$, as illustrated in Figure 6.4.

We can now replace $f(d)$ with $f(x)$, which allows us to take reparameterized gradients wrt $x$.
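A minimal NumPy sketch of Equations (6.113) and (6.114); the function names, the choice $\alpha = (1, 2, 3, 4)$, and the temperatures are our own illustrative assumptions. Gumbel-Max samples should reproduce the categorical probabilities $\pi_k$, while Gumbel-Softmax samples lie on the simplex and become nearly one-hot as $\tau \to 0$ and nearly uniform for large $\tau$.

```python
import numpy as np

def sample_gumbel(shape, rng):
    # eps = -log(-log(u)), u ~ Unif(0, 1); the tiny lower bound avoids log(0)
    u = rng.uniform(low=np.finfo(float).tiny, high=1.0, size=shape)
    return -np.log(-np.log(u))

def gumbel_max_sample(log_alpha, n, rng):
    # Equation (6.113): one-hot of argmax_k (eps_k + log alpha_k)
    K = len(log_alpha)
    eps = sample_gumbel((n, K), rng)
    return np.eye(K)[np.argmax(log_alpha + eps, axis=1)]

def gumbel_softmax_sample(log_alpha, tau, n, rng):
    # Equation (6.114): softmax((log alpha + eps) / tau), a point on the simplex
    logits = (log_alpha + sample_gumbel((n, len(log_alpha)), rng)) / tau
    logits -= logits.max(axis=1, keepdims=True)   # numerical stability
    x = np.exp(logits)
    return x / x.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
alpha = np.array([1.0, 2.0, 3.0, 4.0])
pi = alpha / alpha.sum()
hard = gumbel_max_sample(np.log(alpha), 200_000, rng)
soft_cold = gumbel_softmax_sample(np.log(alpha), 0.01, 1_000, rng)
soft_hot = gumbel_softmax_sample(np.log(alpha), 10.0, 1_000, rng)
print(hard.mean(axis=0), pi)
```

In an autodiff framework, gradients of a downstream $f(x)$ flow through the softmax in `gumbel_softmax_sample`, which is not possible through the argmax.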
[Figure 6.4: panels labeled "Categorical" and "$\tau = 10.0$" (among others), with rows "expectation" and "sample" and horizontal axis "category".]


Figure 6.4: Illustration of the Gumbel-Softmax (concrete) distribution with $K = 7$ states at different temperatures $\tau$. The top row shows $\mathbb{E}[z]$, and the bottom row shows samples $z \sim \text{GumbelSoftmax}(\alpha, \tau)$. The left column shows a discrete (categorical) distribution, which always produces one-hot samples. From Figure 1 of [JGP17]. Used with kind permission of Ben Poole.


6.5.7 Stochastic computation graphs 


We can represent an arbitrary function containing both deterministic and stochastic components 
as a stochastic computation graph. We can then generalize the AD algorithm (Section 6.2) to 
leverage score function estimation (Section 6.5.3) and reparameterization (Section 6.5.4) to compute 
Monte Carlo gradients for complex nested functions. For details, see [Sch+15a; Gaj+19]. 


6.5.8 Straight-through estimator 


In this section, we discuss how to approximate the gradient of a quantized version of a signal. For 
example, suppose we have the following thresholding function, that binarizes its output: 


$$f(x) = \begin{cases} 1 & \text{if } x > 0 \\ 0 & \text{if } x \le 0 \end{cases} \tag{6.115}$$


This does not have a well-defined gradient. However, we can use the straight-through estimator proposed in [Ben13] as an approximation. The basic idea is to treat $f$ as the identity function $g(x) = x$ when computing the backwards pass, i.e., to use $g'(x) = 1$ in place of the derivative $f'(x)$, so the incoming gradient passes straight through. See Figure 6.5 for a visualization, and [Yin+19b] for an analysis of why this is a valid approximation.


In practice, we sometimes replace g(x) = x with the hard tanh function, defined by 


$$\text{HardTanh}(x) = \begin{cases} x & \text{if } -1 \le x \le 1 \\ 1 & \text{if } x > 1 \\ -1 & \text{if } x < -1 \end{cases} \tag{6.116}$$


This ensures the gradients that are backpropagated don’t get too large. See Section 21.6 for an
application of this approach to discrete autoencoders.
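A minimal sketch of the forward and backward passes for the threshold function in Equation (6.115), written by hand in NumPy; the function names are hypothetical. The plain straight-through backward pass uses the identity's derivative (1 everywhere), while the HardTanh variant uses the derivative of Equation (6.116), which is 1 on $[-1, 1]$ and 0 outside.

```python
import numpy as np

def threshold(x):
    # forward pass, Equation (6.115): 1 if x > 0 else 0
    return (x > 0).astype(float)

def threshold_backward_ste(x, grad_out, use_hardtanh=False):
    # straight-through: behave as the identity g(x) = x in the backward pass (derivative 1);
    # with use_hardtanh=True, use the HardTanh derivative instead: 1 on [-1, 1], 0 outside
    if use_hardtanh:
        return grad_out * (np.abs(x) <= 1.0)
    return grad_out.copy()

x = np.array([-2.0, -0.5, 0.3, 1.5])
upstream = np.ones_like(x)
print(threshold(x))
print(threshold_backward_ste(x, upstream))
print(threshold_backward_ste(x, upstream, use_hardtanh=True))
```

In JAX, the same straight-through behaviour is commonly obtained with `x + jax.lax.stop_gradient(threshold(x) - x)`, whose value equals `threshold(x)` but whose derivative wrt `x` is 1.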
6.6. BOUND OPTIMIZATION (MM) ALGORITHMS


[Figure 6.5: the "Forward Pass" goes through the "Threshold Function"; the "Backward Pass" uses the "Straight-Through Estimator".]
Figure 6.5: Illustration of straight-through estimator when applied to a binary threshold function in the middle of an MLP. From https://www.hassanaskary.com/python/pytorch/deep%20learning/2020/09/19/intuitive-explanation-of-straight-through-estimators.html. Used with kind permission of Hassan Askary.


6.6 Bound optimization (MM) algorithms 


In this section, we consider a class of algorithms known as bound optimization or MM algorithms. 
In the context of minimization, MM stands for majorize-minimize. In the context of maximization, 
MM stands for minorize-maximize. There are many examples of MM algorithms, such as EM 
(Section 6.6.3), proximal gradient methods (Section 4.1), the mean shift algorithm for clustering 
[FH75; Che95; FT05], etc. For more details, see e.g., [HL04; Mai15; SBP17; Nad+19].


6.6.1 The general algorithm 


In this section, we assume our goal is to maximize some function $\ell(\theta)$ wrt its parameters $\theta$. The basic approach in MM algorithms is to construct a surrogate function $Q(\theta, \theta^t)$ which is a tight lowerbound to $\ell(\theta)$ such that $Q(\theta, \theta^t) \le \ell(\theta)$ and $Q(\theta^t, \theta^t) = \ell(\theta^t)$. If these conditions are met, we say that $Q$ minorizes $\ell$. We then perform the following update at each step:


$$\theta^{t+1} = \arg\max_\theta Q(\theta, \theta^t) \tag{6.117}$$


This guarantees us monotonic increases in the original objective: 
$$\ell(\theta^{t+1}) \ge Q(\theta^{t+1}, \theta^t) \ge Q(\theta^t, \theta^t) = \ell(\theta^t) \tag{6.118}$$


where the first inequality follows since $Q(\theta, \theta')$ is a lower bound on $\ell(\theta)$ for any $\theta'$; the second inequality follows from Equation (6.117); and the final equality follows from the tightness property. As a consequence of this result, if you do not observe monotonic increase of the objective, you must have an error in your math and/or code. This is a surprisingly powerful debugging tool.

This process is sketched in Figure 6.6. The dashed red curve is the original function (e.g., the 
log-likelihood of the observed data). The solid blue curve is the lower bound, evaluated at $\theta^t$; this
Figure 6.6: Illustration of a bound optimization algorithm. Adapted from Figure 9.14 of [Bis06]. Generated by em_log_likelihood_max.ipynb.


touches the objective function at $\theta^t$. We then set $\theta^{t+1}$ to the maximum of the lower bound (blue curve), and fit a new bound at that point (dotted green curve). The maximum of this new bound becomes $\theta^{t+2}$, etc.


6.6.2 Example: logistic regression 


If $\ell(\theta)$ is a concave function we want to maximize, then one way to obtain a valid lower bound is to use a bound on its Hessian, i.e., to find a negative definite matrix $\mathbf{B}$ such that $\mathbf{H}(\theta) \succeq \mathbf{B}$. In this case, one can show (see [BCN18, App. B]) that


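$$\ell(\theta) \ge \ell(\theta^t) + (\theta - \theta^t)^{\mathsf T} g(\theta^t) + \frac{1}{2}(\theta - \theta^t)^{\mathsf T}\mathbf{B}(\theta - \theta^t) \tag{6.119}$$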


where $g(\theta^t) = \nabla \ell(\theta^t)$. Therefore the following function is a valid lower bound:


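$$Q(\theta, \theta^t) = \theta^{\mathsf T}\left(g(\theta^t) - \mathbf{B}\theta^t\right) + \frac{1}{2}\theta^{\mathsf T}\mathbf{B}\theta \tag{6.120}$$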


The corresponding update becomes


$$\theta^{t+1} = \theta^t - \mathbf{B}^{-1} g(\theta^t) \tag{6.121}$$


This is similar to a Newton update, except we use $\mathbf{B}$, which is a fixed matrix, rather than $\mathbf{H}(\theta^t)$, which changes at each iteration. This can give us some of the advantages of second order methods at lower computational cost.


For example, let us fit a multi-class logistic regression model using MM. (We follow the presentation of [Kri+05], who also consider the more interesting case of sparse logistic regression.) The probability that example $n$ belongs to class $c \in \{1, \ldots, C\}$ is given by


$$p(y_n = c|x_n, w) = \frac{\exp(w_c^{\mathsf T}x_n)}{\sum_{c'=1}^C \exp(w_{c'}^{\mathsf T}x_n)} \tag{6.122}$$


Because of the normalization condition $\sum_{c=1}^C p(y_n = c|x_n, w) = 1$, we can set $w_C = 0$. (For example, in binary logistic regression, where $C = 2$, we only learn a single weight vector.) Therefore the parameters $\theta$ correspond to a weight matrix $w$ of size $D(C - 1)$, where $x_n \in \mathbb{R}^D$.


If we let $p_n(w) = [p(y_n = 1|x_n, w), \ldots, p(y_n = C-1|x_n, w)]$ and $y_n = [\mathbb{I}(y_n = 1), \ldots, \mathbb{I}(y_n = C-1)]$, we can write the log-likelihood as follows:


$$\ell(w) = \sum_{n=1}^N \left[\sum_{c=1}^{C-1} y_{nc}\, w_c^{\mathsf T}x_n - \log \sum_{c=1}^C \exp(w_c^{\mathsf T}x_n)\right] \tag{6.123}$$
The gradient is given by the following: 
$$g(w) = \sum_{n=1}^N \left(y_n - p_n(w)\right) \otimes x_n \tag{6.124}$$


where $\otimes$ denotes Kronecker product (which, in this case, is just outer product of the two vectors).
The Hessian is given by the following: 


$$\mathbf{H}(w) = -\sum_{n=1}^N \left(\text{diag}(p_n(w)) - p_n(w)p_n(w)^{\mathsf T}\right) \otimes \left(x_n x_n^{\mathsf T}\right) \tag{6.125}$$


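We can construct a lower bound on the Hessian, as shown in [Boh92]:
$$\mathbf{H}(w) \succeq -\frac{1}{2}\left[\mathbf{I} - \mathbf{1}\mathbf{1}^{\mathsf T}/C\right] \otimes \left(\sum_{n=1}^N x_n x_n^{\mathsf T}\right) \triangleq \mathbf{B} \tag{6.126}$$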


where I is a (C — 1)-dimensional identity matrix, and 1 is a (C — 1)-dimensional vector of all 1s. In 
the binary case, this becomes 


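$$\mathbf{H}(w) \succeq -\frac{1}{2}\left(1 - \frac{1}{2}\right)\left(\sum_{n=1}^N x_n x_n^{\mathsf T}\right) = -\frac{1}{4}\mathbf{X}^{\mathsf T}\mathbf{X} \tag{6.127}$$
This follows since $p_n(1 - p_n) \le 0.25$ for any $p_n \in [0, 1]$, so $-(p_n - p_n^2) \ge -0.25$.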
We can use this lower bound to construct an MM algorithm to find the MLE. The update becomes 
$$w^{t+1} = w^t - \mathbf{B}^{-1} g(w^t) \tag{6.128}$$


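For example, let us consider the binary case, so $g^t = \nabla \ell(w^t) = \mathbf{X}^{\mathsf T}(y - \mu^t)$, where $\mu^t = [p_n(w^t)]_{n=1}^N$. Since $-\mathbf{B}^{-1} = 4(\mathbf{X}^{\mathsf T}\mathbf{X})^{-1}$, the update becomes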


$$w^{t+1} = w^t + 4(\mathbf{X}^{\mathsf T}\mathbf{X})^{-1} g^t \tag{6.129}$$


The above is faster (per step) than the IRLS (iteratively reweighted least squares) algorithm (i.e., 
Newton’s method), which is the standard method for fitting GLMs. To see this, note that the Newton 
update has the form 


$$w^{t+1} = w^t - \mathbf{H}^{-1} g(w^t) = w^t + (\mathbf{X}^{\mathsf T}\mathbf{S}^t\mathbf{X})^{-1} g^t \tag{6.130}$$


where $\mathbf{S}^t = \text{diag}(\mu^t \odot (1 - \mu^t))$. We see that Equation (6.129) is faster to compute, since we can
precompute the constant matrix $(\mathbf{X}^{\mathsf T}\mathbf{X})^{-1}$.
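A minimal sketch of the binary MM update using the fixed bound $\mathbf{B} = -\frac{1}{4}\mathbf{X}^{\mathsf T}\mathbf{X}$, under our own assumptions (synthetic data, an intercept column, a fixed iteration count). The matrix $-\mathbf{B}^{-1} = 4(\mathbf{X}^{\mathsf T}\mathbf{X})^{-1}$ is computed once; the recorded log-likelihoods should increase monotonically, which is the debugging check recommended in Section 6.6.1.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def log_likelihood(w, X, y):
    # sum_n [y_n a_n - log(1 + exp(a_n))] with a_n = w^T x_n
    a = X @ w
    return np.sum(y * a - np.logaddexp(0.0, a))

def fit_logreg_mm(X, y, num_iters=200):
    # MM for binary logistic regression with the fixed curvature bound B = -X^T X / 4
    w = np.zeros(X.shape[1])
    P = 4.0 * np.linalg.inv(X.T @ X)     # -B^{-1}, precomputed once
    lls = [log_likelihood(w, X, y)]
    for _ in range(num_iters):
        g = X.T @ (y - sigmoid(X @ w))   # gradient of the log-likelihood
        w = w + P @ g                    # w^{t+1} = w^t - B^{-1} g^t
        lls.append(log_likelihood(w, X, y))
    return w, np.array(lls)

rng = np.random.default_rng(0)
N = 500
X = np.column_stack([np.ones(N), rng.normal(size=(N, 2))])
w_true = np.array([0.5, 1.0, -1.5])
y = (rng.uniform(size=N) < sigmoid(X @ w_true)).astype(float)
w_mm, lls = fit_logreg_mm(X, y)
print(w_mm, lls[:5])
```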
6.6.3 The EM algorithm


In this section, we discuss the expectation maximization (EM) algorithm [DLR77; MK07], which 
is an algorithm designed to compute the MLE or MAP parameter estimate for probability models 
that have missing data and/or hidden variables. It is a special case of an MM algorithm. 

The basic idea behind EM is to alternate between estimating the hidden variables (or missing 
values) during the E step (expectation step), and then using the fully observed data to compute the 
MLE during the M step (maximization step). Of course, we need to iterate this process, since the 
expected values depend on the parameters, but the parameters depend on the expected values. 

In Section 6.6.3.1, we show that EM is a bound optimization algorithm, which implies that this 
iterative procedure will converge to a local maximum of the log likelihood. The speed of convergence 
depends on the amount of missing data, which affects the tightness of the bound [XJ96; MD97; 
SRG03; KKS20]. 

We now describe the EM algorithm for a generic model. We let $y_n$ be the visible data for example $n$, and $z_n$ be the hidden data.


6.6.3.1 Lower bound 


The goal of EM is to maximize the log likelihood of the observed data: 


$$\ell(\theta) = \sum_{n=1}^{N_{\mathcal{D}}} \log p(y_n|\theta) = \sum_{n=1}^{N_{\mathcal{D}}} \log \sum_{z_n} p(y_n, z_n|\theta) \tag{6.131}$$


where $y_n$ are the visible variables and $z_n$ are the hidden variables. Unfortunately this is hard to
optimize, since the log cannot be pushed inside the sum. 

EM gets around this problem as follows. First, consider a set of arbitrary distributions $q_n(z_n)$ over each hidden variable $z_n$. The observed data log likelihood can be written as follows:


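$$\ell(\theta) = \sum_{n=1}^{N_{\mathcal{D}}} \log\left[\sum_{z_n} q_n(z_n)\frac{p(y_n, z_n|\theta)}{q_n(z_n)}\right] \tag{6.132}$$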


Using Jensen’s inequality, we can push the log (which is a concave function) inside the expectation to get the following lower bound on the log likelihood:


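$$\ell(\theta) \ge \sum_n \sum_{z_n} q_n(z_n)\log\frac{p(y_n, z_n|\theta)}{q_n(z_n)} \tag{6.133}$$
$$= \sum_n \underbrace{\mathbb{E}_{q_n}\left[\log p(y_n, z_n|\theta)\right] + \mathbb{H}(q_n)}_{\mathcal{L}(\theta, q_n|y_n)} \tag{6.134}$$
$$= \sum_n \mathcal{L}(\theta, q_n|y_n) \triangleq \mathcal{L}(\theta, \{q_n\}|\mathcal{D}) \tag{6.135}$$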


where $\mathbb{H}(q)$ is the entropy of probability distribution $q$, and $\mathcal{L}(\theta, \{q_n\}|\mathcal{D})$ is called the evidence lower bound or ELBO, since it is a lower bound on the log marginal likelihood, $\log p(y_{1:N}|\theta)$, also called the evidence. Optimizing this bound is the basis of variational inference, as we discuss in Section 10.1.
6.6.3.2 E step


We see that the lower bound is a sum of N terms, each of which has the following form: 


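$$\mathcal{L}(\theta, q_n|y_n) = \sum_{z_n} q_n(z_n)\log\frac{p(y_n, z_n|\theta)}{q_n(z_n)} \tag{6.136}$$
$$= \sum_{z_n} q_n(z_n)\log\frac{p(z_n|y_n, \theta)\,p(y_n|\theta)}{q_n(z_n)} \tag{6.137}$$
$$= \sum_{z_n} q_n(z_n)\log\frac{p(z_n|y_n, \theta)}{q_n(z_n)} + \sum_{z_n} q_n(z_n)\log p(y_n|\theta) \tag{6.138}$$
$$= -D_{\mathbb{KL}}\left(q_n(z_n) \,\|\, p(z_n|y_n, \theta)\right) + \log p(y_n|\theta) \tag{6.139}$$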


where $D_{\mathbb{KL}}(q \,\|\, p) = \sum_z q(z)\log\frac{q(z)}{p(z)}$ is the Kullback-Leibler divergence (or KL divergence for short) between probability distributions $q$ and $p$. We discuss this in more detail in Section 5.1, but the key property we need here is that $D_{\mathbb{KL}}(q \,\|\, p) \ge 0$ and $D_{\mathbb{KL}}(q \,\|\, p) = 0$ iff $q = p$. Hence we can maximize the lower bound $\mathcal{L}(\theta, \{q_n\}|\mathcal{D})$ wrt $\{q_n\}$ by setting each one to $q_n^* = p(z_n|y_n, \theta)$. This is called the E step. This ensures the ELBO is a tight lower bound:
$$\mathcal{L}(\theta, \{q_n^*\}|\mathcal{D}) = \sum_n \log p(y_n|\theta) = \ell(\theta|\mathcal{D}) \tag{6.140}$$


To see how this connects to bound optimization, let us define 


$$Q(\theta, \theta^t) = \mathcal{L}(\theta, \{p(z_n|y_n; \theta^t)\}) \tag{6.141}$$


Then we have $Q(\theta, \theta^t) \le \ell(\theta)$ and $Q(\theta^t, \theta^t) = \ell(\theta^t)$, as required.

However, if we cannot compute the posteriors $p(z_n|y_n; \theta^t)$ exactly, we can still use an approximate
distribution $q(z_n|y_n; \theta^t)$; this will yield a non-tight lower-bound on the log-likelihood. This generalized
version of EM is known as variational EM [NH98b]. See Section 6.6.6.1 for details. 
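The following minimal sketch checks Equations (6.136) to (6.140) numerically for a single observation with a discrete latent $z \in \{0, 1, 2\}$; the joint probabilities are hypothetical numbers chosen for illustration. For an arbitrary $q$ the ELBO lies below $\log p(y|\theta)$ by exactly $D_{\mathbb{KL}}(q \,\|\, p(z|y, \theta))$, and choosing $q$ to be the posterior (the E step) closes the gap.

```python
import numpy as np

def elbo(q, log_joint):
    # L(theta, q | y) = sum_z q(z) [log p(y, z | theta) - log q(z)], Equation (6.136)
    return np.sum(q * (log_joint - np.log(q)))

def kl(q, p):
    # D_KL(q || p) = sum_z q(z) log(q(z) / p(z))
    return np.sum(q * np.log(q / p))

# hypothetical joint probabilities p(y, z | theta) for one observation y and z in {0, 1, 2}
log_joint = np.log(np.array([0.10, 0.25, 0.05]))
log_evidence = np.log(np.sum(np.exp(log_joint)))      # log p(y | theta)
posterior = np.exp(log_joint - log_evidence)          # p(z | y, theta), the E step choice
q_uniform = np.full(3, 1.0 / 3.0)                     # an arbitrary q(z)

print(elbo(q_uniform, log_joint), elbo(posterior, log_joint), log_evidence)
```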


6.6.3.3 M step 


In the M step, we need to maximize $\mathcal{L}(\theta, \{q_n^t\})$ wrt $\theta$, where the $q_n^t$ are the distributions computed in the E step at iteration $t$. Since the entropy terms $\mathbb{H}(q_n)$ are constant wrt $\theta$, we can drop them in the M step. We are left with


$$\ell^t(\theta) = \sum_n \mathbb{E}_{q_n^t(z_n)}\left[\log p(y_n, z_n|\theta)\right] \tag{6.142}$$


This is called the expected complete data log likelihood. If the joint probability is in the 
exponential family (Section 2.3), we can rewrite this as 


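$$\ell^t(\theta) = \sum_n \mathbb{E}\left[T(y_n, z_n)^{\mathsf T}\theta - A(\theta)\right] = \sum_n \left(\mathbb{E}\left[T(y_n, z_n)\right]^{\mathsf T}\theta - A(\theta)\right) \tag{6.143}$$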


where $\mathbb{E}[T(y_n, z_n)]$ are called the expected sufficient statistics.
In the M step, we maximize the expected complete data log likelihood to get


$$\theta^{t+1} = \arg\max_\theta \sum_n \mathbb{E}_{q_n^t}\left[\log p(y_n, z_n|\theta)\right] \tag{6.144}$$


In the case of the exponential family, the maximization can be solved in closed-form by matching the 
moments of the expected sufficient statistics (Section 2.3.5). 

We see from the above that the E step does not in fact need to return the full set of posterior 
distributions $\{q(z_n)\}$, but can instead just return the sum of the expected sufficient statistics, $\sum_n \mathbb{E}_{q(z_n)}[T(y_n, z_n)]$.

A common application of EM is for fitting mixture models; we discuss this in the prequel to this 
book, [Mur22]. Below we give a different example. 


6.6.4 Example: EM for an MVN with missing data 


It is easy to compute the MLE for a multivariate normal when we have a fully observed data matrix: 
we just compute the sample mean and covariance. In this section, we consider the case where we have 
missing data or partially observed data. For example, we can think of the entries of Y as being 
answers to a survey; some of these answers may be unknown. There are many kinds of missing data, 
as we discuss in Section 21.3.4. In this section, we make the missing at random (MAR) assumption, 
for simplicity. Under the MAR assumption, the log likelihood of the visible data has the form 


$$\log p(\mathbf{X}|\theta) = \sum_n \log p(x_n|\theta) = \sum_n \log\left[\int p(x_n, z_n|\theta)\, dz_n\right] \tag{6.145}$$


where $x_n$ are the visible variables in case $n$, $z_n$ are the hidden variables, and $y_n = (z_n, x_n)$ are all
the variables. Unfortunately, this objective is hard to maximize, since we cannot push the log inside
the expectation. Fortunately, we can easily apply EM, as we explain below. 


6.6.4.1 E step


Suppose we have the parameters $\theta^{t-1}$ from the previous iteration. Then we can compute the expected complete data log likelihood at iteration $t$ as follows:


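$$Q(\theta, \theta^{t-1}) = \mathbb{E}\left[\sum_{n=1}^{N_{\mathcal{D}}} \log \mathcal{N}(y_n|\mu, \Sigma) \,\Big|\, \mathcal{D}, \theta^{t-1}\right] \tag{6.146}$$
$$= -\frac{N_{\mathcal{D}}}{2}\log|2\pi\Sigma| - \frac{1}{2}\sum_n \mathbb{E}\left[(y_n - \mu)^{\mathsf T}\Sigma^{-1}(y_n - \mu)\right] \tag{6.147}$$
$$= -\frac{N_{\mathcal{D}}}{2}\log|2\pi\Sigma| - \frac{1}{2}\text{tr}\left(\Sigma^{-1}\sum_n \mathbb{E}\left[(y_n - \mu)(y_n - \mu)^{\mathsf T}\right]\right) \tag{6.148}$$
$$= -\frac{N_{\mathcal{D}}}{2}\log|\Sigma| - \frac{N_{\mathcal{D}} D}{2}\log(2\pi) - \frac{1}{2}\text{tr}\left(\Sigma^{-1}\mathbb{E}[S(\mu)]\right) \tag{6.149}$$
where
$$\mathbb{E}[S(\mu)] \triangleq \sum_n \left(\mathbb{E}[y_n y_n^{\mathsf T}] + \mu\mu^{\mathsf T} - 2\mu\,\mathbb{E}[y_n]^{\mathsf T}\right) \tag{6.150}$$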
(We drop the conditioning of the expectation on $\mathcal{D}$ and $\theta^{t-1}$ for brevity.) We see that we need to compute $\sum_n \mathbb{E}[y_n]$ and $\sum_n \mathbb{E}[y_n y_n^{\mathsf T}]$; these are the expected sufficient statistics.
To compute these quantities, we use the results from Section 2.2.5.3. We have
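$$p(z_n|x_n, \theta) = \mathcal{N}(z_n|m_n, V_n) \tag{6.151}$$
$$m_n \triangleq \mu_h + \Sigma_{hv}\Sigma_{vv}^{-1}(x_n - \mu_v) \tag{6.152}$$
$$V_n \triangleq \Sigma_{hh} - \Sigma_{hv}\Sigma_{vv}^{-1}\Sigma_{vh} \tag{6.153}$$
where we partition $\mu$ and $\Sigma$ into blocks based on the hidden and visible indices $h$ and $v$. Hence the expected sufficient statistics are
$$\mathbb{E}[y_n] = \left(\mathbb{E}[z_n]; x_n\right) = \left(m_n; x_n\right) \tag{6.154}$$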
14 
15 To compute E [Yny], we use the result that Cov [y] = E [yy | — E [y] E ae Hence 
16 
17 z i [znz] E [zn] æ, 

x T x n T T nen nm n 

y = = 6.155 
18 [unun] (g (2a =) a a [zn] TnT, ( ) 
19 
20 ; [znZn] = E [zn] E [zn]" + Vn (6.156) 
21 
6.6.4.2 M step
By solving $\nabla Q(\theta, \theta^{(t-1)}) = 0$, we can show that the M step is equivalent to plugging these ESS into the usual MLE equations to get
$$\mu^t = \frac{1}{N_{\mathcal{D}}}\sum_n \mathbb{E}[y_n] \tag{6.157}$$
$$\Sigma^t = \frac{1}{N_{\mathcal{D}}}\sum_n \mathbb{E}[y_n y_n^{\mathsf T}] - \mu^t(\mu^t)^{\mathsf T} \tag{6.158}$$
Thus we see that EM is not equivalent to simply replacing variables by their expectations and applying the standard MLE formula; that would ignore the posterior variance and would result in an incorrect estimate. Instead we must compute the expectation of the sufficient statistics, and plug that into the usual equation for the MLE.
6.6.4.3 Initialization
To get the algorithm started, we can compute the MLE based on those rows of the data matrix that are fully observed. If there are no such rows, we can just estimate the diagonal terms of $\Sigma$ using the observed marginal statistics. We are then ready to start EM.
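Putting Equations (6.151) to (6.158) together, here is a minimal NumPy sketch of EM for an MVN with missing entries, under assumptions of our own: missing values are marked with `np.nan`, every row has at least one observed entry, the initialization uses the observed marginal means and variances (the diagonal option mentioned above), and a fixed number of iterations is run instead of a convergence test. The data are synthetic, with entries hidden completely at random.

```python
import numpy as np

def em_mvn_missing(Y, num_iters=30):
    """EM for the MLE of a multivariate normal when Y (N x D) has np.nan for missing entries.
    Assumes every row has at least one observed entry."""
    N, D = Y.shape
    observed = ~np.isnan(Y)
    # initialization: observed column means and a diagonal covariance from observed marginals
    mu = np.nanmean(Y, axis=0)
    Sigma = np.diag(np.nanvar(Y, axis=0))
    for _ in range(num_iters):
        sum_y = np.zeros(D)
        sum_yy = np.zeros((D, D))
        for n in range(N):
            v, h = observed[n], ~observed[n]
            y_hat = Y[n].copy()
            V = np.zeros((D, D))
            if h.any():
                # K = Sigma_hv Sigma_vv^{-1}, computed with a linear solve
                K = np.linalg.solve(Sigma[np.ix_(v, v)], Sigma[np.ix_(v, h)]).T
                y_hat[h] = mu[h] + K @ (Y[n, v] - mu[v])                         # m_n, (6.152)
                V[np.ix_(h, h)] = Sigma[np.ix_(h, h)] - K @ Sigma[np.ix_(v, h)]  # V_n, (6.153)
            # E step: accumulate E[y_n] (6.154) and E[y_n y_n^T] (6.155)-(6.156)
            sum_y += y_hat
            sum_yy += np.outer(y_hat, y_hat) + V
        # M step, (6.157)-(6.158)
        mu = sum_y / N
        Sigma = sum_yy / N - np.outer(mu, mu)
    return mu, Sigma

rng = np.random.default_rng(0)
true_mu = np.array([1.0, -1.0, 0.5])
true_Sigma = np.array([[1.0, 0.8, 0.3], [0.8, 1.0, 0.5], [0.3, 0.5, 1.0]])
Y_full = rng.multivariate_normal(true_mu, true_Sigma, size=1000)
mask = rng.uniform(size=Y_full.shape) < 0.3     # hide 30% of entries completely at random
mask[mask.all(axis=1), 0] = False               # keep at least one observed entry per row
Y = np.where(mask, np.nan, Y_full)
mu_hat, Sigma_hat = em_mvn_missing(Y)
print(mu_hat)
print(Sigma_hat)
```

The posterior covariance $V_n$ is added to $\sum_n \mathbb{E}[y_n y_n^{\mathsf T}]$ explicitly, which is exactly the term that naive mean imputation would drop.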
6.6.4.4 Example
As an example of this procedure in action, let us consider an imputation problem, where we have $N_{\mathcal{D}} = 100$ 10-dimensional data cases, which we assume to come from a Gaussian. We generate synthetic data where 50% of the observations are missing at random. First we fit the parameters
[Figure 6.7: (a) "Imputation with true parameters"; (b) "Imputation with EM estimated parameters". Each subplot shows imputed values against the truth (horizontal axis "Truth") for one variable, annotated with its $R^2$.]
Figure 6.7: Illustration of data imputation using a multivariate Gaussian. (a) Scatter plot of true values vs imputed values using true parameters. (b) Same as (a), but using parameters estimated with EM. We just show the first four variables, for brevity. Generated by gauss_imputation_em_demo.ipynb.
using EM. Call the resulting parameters $\hat{\theta}$. We can now use our model for predictions by computing $\mathbb{E}[z_n|x_n, \hat{\theta}]$. Figure 6.7 indicates that the results obtained using the learned parameters are almost as good as with the true parameters. Not surprisingly, performance improves with more data, or as the fraction of missing data is reduced.
6.6.5 Example: robust linear regression using Student likelihood
In this section, we discuss how to use EM to fit a linear regression model that uses the Student distribution for its likelihood, instead of the more common Gaussian distribution, in order to achieve robustness, as first proposed in [Zel76]. More precisely, the likelihood is given by
$$p(y|x, w, \sigma^2, \nu) = \mathcal{T}(y|w^{\mathsf T}x, \sigma^2, \nu) \tag{6.159}$$
At first blush it may not be apparent how to do this, since there is no missing data, and there are no hidden variables. However, it turns out that we can introduce “artificial” hidden variables to make the problem easier to solve; this is a common trick. The key insight is that we can represent the Student distribution as a Gaussian scale mixture, as we discuss in Section 28.2.3.1.
We can apply the GSM version of the Student distribution to our problem by associating a latent scale $z_n \in \mathbb{R}_+$ with each example. The complete data log likelihood is therefore given by
$$\log p(y, z|\mathbf{X}, w, \sigma^2, \nu) = \sum_n \left[-\frac{1}{2}\log(2\pi z_n\sigma^2) - \frac{1}{2z_n\sigma^2}(y_n - w^{\mathsf T}x_n)^2\right] \tag{6.160}$$
$$+ \sum_n \left[-\left(\frac{\nu}{2} + 1\right)\log(z_n) - \frac{\nu}{2z_n}\right] + \text{const} \tag{6.161}$$
[Figure 6.8: panels (a) and (b) plot the "true log-likelihood" and the "lower bound" against "training time".]


Figure 6.8: Illustration of possible behaviors of variational EM. (a) The lower bound increases at each 
iteration, and so does the likelihood. (b) The lower bound increases but the likelihood decreases. In this case, 
the algorithm is closing the gap between the approximate and true posterior. This can have a regularizing 
effect. Adapted from Figure 6 of [SJJ96]. Generated by var_em_bound.ipynb.


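Ignoring terms not involving $w$, and taking expectations, we have
$$Q(\theta, \theta^t) = -\sum_n \frac{\lambda_n^t}{2\sigma^2}(y_n - w^{\mathsf T}x_n)^2 \tag{6.162}$$
where $\lambda_n^t \triangleq \mathbb{E}[1/z_n|y_n, x_n, w^t]$. We recognize this as a weighted least squares objective, with weight $\lambda_n^t$ per data point.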

We now discuss how to compute these weights. Using the results from Section 2.2.3.4, one can show that
$$p(z_n|y_n, x_n, \theta) = \text{IG}\left(\frac{\nu + 1}{2}, \frac{\nu + \delta_n}{2}\right) \tag{6.163}$$
where $\delta_n = \frac{(y_n - w^{\mathsf T}x_n)^2}{\sigma^2}$ is the standardized residual. Hence
$$\lambda_n^t = \mathbb{E}[1/z_n] = \frac{\nu^t + 1}{\nu^t + \delta_n^t} \tag{6.164}$$
So if the residual $\delta_n^t$ is large, the point will be given low weight $\lambda_n^t$, which makes intuitive sense, since it is probably an outlier.
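A minimal sketch of the resulting EM loop, with simplifications of our own: $\nu$ is held fixed rather than estimated, the fit starts from ordinary least squares, and a fixed number of iterations is run. The E step computes the weights in Equation (6.164); the M step solves the weighted least squares problem in Equation (6.162) and sets $\sigma^2$ to the weighted mean squared residual. On synthetic data with a few gross outliers, the outliers should receive small weights.

```python
import numpy as np

def fit_student_regression_em(X, y, nu=4.0, num_iters=100):
    # EM for linear regression with a Student likelihood; nu is held fixed here (a simplification)
    w = np.linalg.lstsq(X, y, rcond=None)[0]      # start from ordinary least squares
    sigma2 = np.mean((y - X @ w) ** 2)
    for _ in range(num_iters):
        # E step: weights lambda_n = E[1 / z_n] from Equation (6.164)
        delta = (y - X @ w) ** 2 / sigma2
        lam = (nu + 1.0) / (nu + delta)
        # M step: weighted least squares for w, then the weighted residual variance for sigma^2
        XtL = X.T * lam
        w = np.linalg.solve(XtL @ X, XtL @ y)
        sigma2 = np.sum(lam * (y - X @ w) ** 2) / len(y)
    return w, lam

rng = np.random.default_rng(0)
N = 200
x = rng.uniform(-2.0, 2.0, size=N)
X = np.column_stack([np.ones(N), x])
y = 1.0 + 2.0 * x + 0.3 * rng.standard_normal(N)
outliers = rng.uniform(size=N) < 0.1
y[outliers] += 20.0
w_ols = np.linalg.lstsq(X, y, rcond=None)[0]
w_robust, lam = fit_student_regression_em(X, y)
print(w_ols, w_robust)
```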


6.6.6 Extensions to EM 


There are many variations and extensions of the EM algorithm, as discussed in [MK97]. We summarize 
a few of these below. 


6.6.6.1 Variational EM 


Suppose in the E step we pick $q_n^* = \arg\min_{q_n \in \mathcal{Q}} D_{\mathbb{KL}}(q_n \,\|\, p(z_n|x_n, \theta))$. Because we are optimizing over the space of functions, this is called variational inference (see Section 10.1 for details). If the
family of distributions $\mathcal{Q}$ is rich enough to contain the true posterior, $q_n = p(z_n|x_n, \theta)$, then we can
make the KL be zero. But in general, we might choose a more restrictive class for computational 
reasons. For example, we might use $q_n(z_n) = \mathcal{N}(z_n|\mu_n, \text{diag}(\sigma_n))$ even if the true posterior is
correlated. 

The use of a restricted posterior family Q inside the E step of EM is called variational EM 
[NH98a]. Unlike regular EM, variational EM is not guaranteed to increase the actual log likelihood 
itself (see Figure 6.8), but it does monotonically increase the variational lower bound. We can control 
the tightness of this lower bound by varying the variational family Q; in the limit in which qn = pn, 
corresponding to exact inference, we recover the same behavior as regular EM. See Section 10.2.5 for 
further discussion. 


6.6.6.2 Hard EM 


Suppose we use a degenerate posterior approximation in the context of variational EM, corresponding 
to a point estimate, $q(z|x_n) = \delta_{\hat{z}_n}(z)$, where $\hat{z}_n = \arg\max_z p(z|x_n)$. This is equivalent to hard
EM, where we ignore uncertainty about $z_n$ in the E step.

The problem with this degenerate approach is that it is very prone to overfitting, since the number 
of latent variables is proportional to the number of datacases [WCS08]. 
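To make the degenerate E step concrete, here is a minimal sketch of hard EM for a 1D mixture of unit-variance Gaussians in which only the means and mixing weights are learned (a simplification chosen for brevity; `hard_em_gmm_1d` is a hypothetical name). Each point is assigned to its single most probable component, and the M step then fits each component only to its own points, which is essentially K-means with a mixing-weight bonus.

```python
import math


def hard_em_gmm_1d(xs, mus, n_iters=20):
    """Hard EM for a 1D mixture of unit-variance Gaussians."""
    mus = list(mus)
    k = len(mus)
    pis = [1.0 / k] * k
    z = []
    for _ in range(n_iters):
        # Hard E step: replace p(z_n | x_n) by a delta function at its mode.
        z = [max(range(k), key=lambda j: math.log(pis[j]) - 0.5 * (x - mus[j]) ** 2)
             for x in xs]
        # M step: refit each component using only the points assigned to it.
        for j in range(k):
            members = [x for x, zn in zip(xs, z) if zn == j]
            if members:  # keep the old parameters if a component loses all its points
                mus[j] = sum(members) / len(members)
                pis[j] = len(members) / len(xs)
    return mus, pis, z


mus, pis, z = hard_em_gmm_1d([-2.1, -1.9, -2.0, 3.0, 3.2, 2.8], mus=[0.0, 1.0])
```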


6.6.6.3 Monte Carlo EM 


Another approach to handling an intractable E step is to use a Monte Carlo approximation to the expected sufficient statistics. That is, we draw samples from the posterior, $z_n^s \sim p(z_n \mid x_n, \theta^t)$, and then compute the sufficient statistics for each completed vector, $(x_n, z_n^s)$, and then average the results. This is called Monte Carlo EM or MCEM [WT90; Nea12].


One way to draw samples is to use MCMC (see Chapter 12). However, if we have to wait for MCMC to converge inside each E step, the method becomes very slow. An alternative is to use stochastic approximation, and only perform “brief” sampling in the E step, followed by a partial parameter update. This is called stochastic approximation EM [DLM99] and tends to work better than MCEM.
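As an illustration of the Monte Carlo E step, the sketch below runs MCEM for the component means of a 1D two-component mixture with equal weights and unit variances (assumptions made to keep it short; `mcem_gmm_means` is a hypothetical helper). In this toy model the posterior over $z_n$ is available exactly, so sampling is not actually needed; the point is only to show how sampled completions $(x_n, z_n^s)$ replace the expected sufficient statistics.

```python
import math
import random


def mcem_gmm_means(xs, mus, n_samples=20, n_iters=30, seed=0):
    """MCEM for the means of an equal-weight, unit-variance 1D Gaussian mixture."""
    rng = random.Random(seed)
    mus = list(mus)
    k = len(mus)
    for _ in range(n_iters):
        counts = [0.0] * k
        sums = [0.0] * k
        for x in xs:
            # Unnormalized posterior p(z_n = j | x_n, theta^t) under the assumed model
            logits = [-0.5 * (x - m) ** 2 for m in mus]
            top = max(logits)
            probs = [math.exp(v - top) for v in logits]
            # Monte Carlo E step: complete (x_n, z_n^s) for s = 1..S and accumulate
            for z in rng.choices(range(k), weights=probs, k=n_samples):
                counts[z] += 1.0
                sums[z] += x
        # M step from the sample-averaged sufficient statistics
        # (the common 1/S factor cancels in the ratio)
        mus = [sums[j] / counts[j] if counts[j] > 0 else mus[j] for j in range(k)]
    return mus


xs = [-2.0 + 0.1 * i for i in range(-5, 6)] + [3.0 + 0.1 * i for i in range(-5, 6)]
mus = mcem_gmm_means(xs, mus=[-1.0, 1.0])
```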


6.6.6.4 Generalized EM 


Sometimes we can perform the E step exactly, but we cannot perform the M step exactly. However, we can still monotonically increase the log likelihood by performing a “partial” M step, in which we merely increase the expected complete data log likelihood, rather than maximizing it. For example, we might follow a few gradient steps. This is called the generalized EM or GEM algorithm [MK07]. (This is an unfortunate term, since there are many ways to generalize EM, but it is the standard terminology.) For example, [Lan95a] proposes to perform one Newton-Raphson step:


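$$\theta_{t+1} = \theta_t - \eta_t H_t^{-1} g_t \tag{6.165}$$

where $0 < \eta_t \le 1$ is the step size, and

$$g_t = \frac{\partial}{\partial \theta} Q(\theta, \theta_t)\Big|_{\theta=\theta_t} \tag{6.166}$$

$$H_t = \frac{\partial^2}{\partial \theta \, \partial \theta^\top} Q(\theta, \theta_t)\Big|_{\theta=\theta_t} \tag{6.167}$$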


If $\eta_t = 1$, [Lan95a] calls this the gradient EM algorithm. However, it is possible to use a larger step size to speed up the algorithm, as in the quasi-Newton EM algorithm of [Lan95b]. This method also replaces the Hessian in Equation (6.167), which may not be negative definite (for non-exponential family models), with a BFGS approximation. This ensures the overall algorithm is an ascent algorithm. Note, however, that when the M step cannot be computed in closed form, EM loses some of its appeal over directly optimizing the marginal likelihood with a gradient-based solver.
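The damped Newton step of Equation (6.165) is easy to illustrate on a model whose Q function is quadratic in the parameters. The sketch below assumes, for simplicity, a 1D equal-weight mixture of unit-variance Gaussians in which only the means are learned (helper names are hypothetical). For each mean, $g_t = \sum_n r_{nk}(x_n - \mu_k)$ and $H_t = -\sum_n r_{nk}$, so $\eta_t = 1$ recovers the exact M step, while $0 < \eta_t < 1$ gives a partial M step that still increases Q, and hence the log likelihood, at every iteration.

```python
import math


def log_lik(xs, mus):
    """Observed-data log likelihood of an equal-weight, unit-variance 1D Gaussian mixture."""
    k = len(mus)
    total = 0.0
    for x in xs:
        comps = [-0.5 * (x - m) ** 2 for m in mus]
        top = max(comps)
        total += (top + math.log(sum(math.exp(c - top) for c in comps) / k)
                  - 0.5 * math.log(2.0 * math.pi))
    return total


def gem_newton(xs, mus, eta=0.5, n_iters=25):
    """GEM where the M step is one damped Newton step on Q, Eq. (6.165)."""
    mus = list(mus)
    history = [log_lik(xs, mus)]
    for _ in range(n_iters):
        grads = [0.0] * len(mus)
        hess = [0.0] * len(mus)
        for x in xs:
            # Exact E step: responsibilities r_nk under the current means
            comps = [-0.5 * (x - m) ** 2 for m in mus]
            top = max(comps)
            w = [math.exp(c - top) for c in comps]
            norm = sum(w)
            for j, wj in enumerate(w):
                r = wj / norm
                grads[j] += r * (x - mus[j])  # dQ/dmu_j at theta_t, Eq. (6.166)
                hess[j] -= r                  # d2Q/dmu_j^2, Eq. (6.167); H is diagonal here
        # Partial M step: theta_{t+1} = theta_t - eta * H^{-1} g
        mus = [m - eta * g / h if h != 0.0 else m for m, g, h in zip(mus, grads, hess)]
        history.append(log_lik(xs, mus))
    return mus, history


xs = [-2.0 + 0.1 * i for i in range(-5, 6)] + [3.0 + 0.1 * i for i in range(-5, 6)]
mus, history = gem_newton(xs, [-1.0, 1.0], eta=0.5)
```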


6.6.6.5 ECM algorithm 


The ECM algorithm stands for “expectation conditional maximization”, and refers to optimizing the parameters in the M step sequentially, if they turn out to be dependent. The ECME algorithm, which stands for “ECM either” [LR95], is a variant of ECM in which we maximize either the expected complete data log likelihood (the Q function) as usual, or the observed data log likelihood, during one or more of the conditional maximization steps. The latter can be much faster, since it ignores the results of the E step, and directly optimizes the objective of interest. A standard example of this is when fitting the Student distribution. For fixed $\nu$, we can update $\Sigma$ as usual, but then to update $\nu$, we replace the standard update of the form $\nu^{t+1} = \arg\max_\nu Q((\mu^{t+1}, \Sigma^{t+1}, \nu), \theta^t)$ with $\nu^{t+1} = \arg\max_\nu \log p(\mathcal{D} \mid \mu^{t+1}, \Sigma^{t+1}, \nu)$. See [MK97] for more information.
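The ECME update for $\nu$ can be sketched for a 1D Student distribution with location $\mu$ and squared scale $\sigma^2$. This is a minimal sketch under our own simplifying assumptions: $\nu$ is maximized by brute force over a fixed grid (a hypothetical choice; any 1D optimizer would do), and the initial $\nu$ is itself a grid point, so each iteration can only increase the observed data log likelihood.

```python
import math


def student_t_loglik(xs, mu, sigma2, nu):
    """Observed-data log likelihood of a 1D Student t with location mu, scale^2 sigma2, dof nu."""
    const = (math.lgamma((nu + 1.0) / 2.0) - math.lgamma(nu / 2.0)
             - 0.5 * math.log(nu * math.pi * sigma2))
    return sum(const - (nu + 1.0) / 2.0 * math.log1p((x - mu) ** 2 / (nu * sigma2))
               for x in xs)


def ecme_student_t(xs, nu=10.0, n_iters=30):
    """ECME: (mu, sigma2) from the Q function, nu from the observed-data likelihood."""
    nu_grid = [0.5 * i for i in range(1, 61)]  # hypothetical search grid 0.5, 1.0, ..., 30.0
    n = len(xs)
    mu = sum(xs) / n
    sigma2 = sum((x - mu) ** 2 for x in xs) / n
    history = [student_t_loglik(xs, mu, sigma2, nu)]
    for _ in range(n_iters):
        # E step: lambda_n = E[1/z_n] for the scale-mixture representation
        lam = [(nu + 1.0) / (nu + (x - mu) ** 2 / sigma2) for x in xs]
        # CM steps: maximize Q over mu, then over sigma2, with nu fixed
        mu = sum(w * x for w, x in zip(lam, xs)) / sum(lam)
        sigma2 = sum(w * (x - mu) ** 2 for w, x in zip(lam, xs)) / n
        # "Either" step: maximize the observed-data log likelihood over nu
        nu = max(nu_grid, key=lambda v: student_t_loglik(xs, mu, sigma2, v))
        history.append(student_t_loglik(xs, mu, sigma2, nu))
    return mu, sigma2, nu, history


xs = [0.1 * ((i * 37) % 21 - 10) for i in range(40)] + [8.0, -9.0, 12.0]
mu, sigma2, nu, history = ecme_student_t(xs)
```

With three gross outliers in the toy data, the selected $\nu$ moves below its initial value of 10, i.e., towards heavier tails.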


6.6.6.6 Online EM 


When dealing with large or streaming datasets, it is important to be able to learn online, as we discussed in Section 19.7.5. There are two main approaches to online EM in the literature. The first approach, known as incremental EM [NH98a], optimizes the lower bound $Q(\theta, q_1, \ldots, q_N)$ one $q_n$ at a time; however, this requires storing the expected sufficient statistics for each data case.

The second approach, known as stepwise EM [SI00; LK09; CM09], is based on stochastic gradient descent. This optimizes a local upper bound on $\ell_n(\theta) = \log p(x_n \mid \theta)$ at each step. (See [Mai13; Mai15] for a more general discussion of stochastic and incremental bound optimization algorithms.)
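Stepwise EM replaces the batch expected sufficient statistics by a running average that is updated after every data point. Below is a minimal sketch for the means of a 1D equal-weight, unit-variance mixture. The step-size schedule $\rho_t = (t + t_0)^{-\kappa}$ with $t_0 = 10$ and $\kappa = 0.7$, the initialization of the running statistics, and the helper name are our own choices for illustration.

```python
import math
import random


def stepwise_em_gmm_means(xs, mus, kappa=0.7, t0=10, n_epochs=5, seed=0):
    """Stepwise (online) EM for the means of an equal-weight, unit-variance 1D mixture."""
    rng = random.Random(seed)
    mus = list(mus)
    k = len(mus)
    # Running averages of the expected sufficient statistics E[I(z=j)] and E[I(z=j) x]
    counts = [1.0 / k] * k
    sums = [c * m for c, m in zip(counts, mus)]
    t = 0
    for _ in range(n_epochs):
        stream = list(xs)
        rng.shuffle(stream)
        for x in stream:
            comps = [-0.5 * (x - m) ** 2 for m in mus]
            top = max(comps)
            w = [math.exp(c - top) for c in comps]
            norm = sum(w)
            rho = (t + t0) ** (-kappa)  # decaying step size (assumed schedule)
            for j in range(k):
                r = w[j] / norm  # E step for this single data point
                counts[j] = (1.0 - rho) * counts[j] + rho * r
                sums[j] = (1.0 - rho) * sums[j] + rho * r * x
            mus = [sums[j] / counts[j] for j in range(k)]  # M step from running stats
            t += 1
    return mus


xs = [-2.0 + 0.1 * i for i in range(-5, 6)] + [3.0 + 0.1 * i for i in range(-5, 6)]
mus = stepwise_em_gmm_means(xs, mus=[-1.0, 1.0])
```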


6.7 The Bayesian learning rule 


In this section, we discuss the “Bayesian learning rule” [KR21a], which provides a unified framework for deriving many standard (and non-standard) optimization and inference algorithms used in the ML community.

To motivate the BLR, recall the standard empirical risk minimization or ERM problem, which has the form $\theta_* = \arg\min_\theta \ell(\theta)$, where

$$\ell(\theta) = \sum_{n=1}^N \ell(y_n, f_\theta(x_n)) + R(\theta) \tag{6.168}$$
[Figure 6.9 panels: (a) initialization near a region with a large loss, annotated with large negative, zero, and small positive gradients; (b) flat minima vs. sharp minima.]


Figure 6.9: Illustration of the robustness obtained by using a Bayesian approach to parameter estimation. (a) When the minimum $\theta_*$ lies next to a “wall”, the Bayesian solution shifts away from the boundary to avoid large losses due to perturbations of the parameters. (b) The Bayesian solution prefers flat minima over sharp minima, to avoid large losses due to perturbations of the parameters. From Figure 1 of [KR21a]. Used with kind permission of Emtiyaz Khan.


where $f_\theta(x)$ is a prediction function, $\ell(y, \hat{y})$ is a loss function, and $R(\theta)$ is some kind of regularizer.

Although the regularizer can prevent overfitting, the ERM method can still result in parameter estimates that are not robust. A better approach is to fit a distribution over possible parameter values, $q(\theta)$. If we minimize the expected loss, we will find parameter settings that will work well even if they are slightly perturbed, as illustrated in Figure 6.9, which helps with robustness and generalization. Of course, if the distribution q collapses to a single delta function, we will end up with the ERM solution. To prevent this, we add a penalty term that measures the KL divergence from $q(\theta)$ to some prior $\pi_0(\theta) \propto \exp(-R(\theta))$. This gives rise to the following BLR objective:


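$$\mathcal{L}(q) = \mathbb{E}_{q(\theta)}\left[\sum_{n=1}^N \ell(y_n, f_\theta(x_n))\right] + D_{\mathbb{KL}}(q(\theta) \,\|\, \pi_0(\theta)) \tag{6.169}$$

We can rewrite the KL term, up to an additive constant that does not depend on q, as

$$D_{\mathbb{KL}}(q(\theta) \,\|\, \pi_0(\theta)) = \mathbb{E}_{q(\theta)}[R(\theta)] - \mathbb{H}(q(\theta)) \tag{6.170}$$

and hence can rewrite the BLR objective as follows:

$$\mathcal{L}(q) = \mathbb{E}_{q(\theta)}[\ell(\theta)] - \mathbb{H}(q(\theta)) \tag{6.171}$$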


Below we show that different approximations to this objective recover a variety of different methods in the literature.
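For intuition about Figure 6.9(b), the objective in Equation (6.171) can be evaluated in closed form when $q(\theta) = \mathcal{N}(\theta \mid m, v)$ is a 1D Gaussian and the loss is a quadratic bowl $\ell(\theta) = \frac{a}{2}(\theta - c)^2$ (a toy assumption). Then $\mathbb{E}_q[\ell] = \frac{a}{2}((m - c)^2 + v)$ and $\mathbb{H}(q) = \frac{1}{2}\log(2\pi e v)$, so the optimum is $m = c$, $v = 1/a$, with optimal value $\frac{1}{2}(1 + \log a - \log(2\pi e))$. A sharper minimum (larger curvature a) therefore gets a higher optimal objective, so of two equally deep minima the flatter one is preferred.

```python
import math


def blr_objective(m, v, a, c):
    """E_q[l] - H(q) for q = N(m, v) and l(theta) = a/2 * (theta - c)^2."""
    expected_loss = 0.5 * a * ((m - c) ** 2 + v)
    entropy = 0.5 * math.log(2.0 * math.pi * math.e * v)
    return expected_loss - entropy


def optimal_gaussian(a, c):
    """Stationary point of blr_objective: m = c and v = 1 / a."""
    return c, 1.0 / a


flat = optimal_gaussian(a=0.1, c=0.0)    # wide bowl
sharp = optimal_gaussian(a=10.0, c=0.0)  # narrow bowl
```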


6.7.1 Deriving inference algorithms from BLR


In this section we show how to derive several different inference algorithms from BLR. (We discuss such algorithms in more detail in Chapter 10.)

6.7.1.1 Bayesian inference as optimization


The BLR objective includes standard exact Bayesian inference as a special case, as first shown in [Opt88]. To see this, let us assume the loss function is derived from a log-likelihood:

$$\ell(y, f_\theta(x)) = -\log p(y \mid f_\theta(x)) \tag{6.172}$$

Let $\mathcal{D} = \{(x_n, y_n) : n = 1:N\}$ be the data we condition on. The Bayesian posterior can be written as

$$p(\theta \mid \mathcal{D}) = \frac{\pi_0(\theta)}{Z(\mathcal{D})} \prod_{n=1}^N p(y_n \mid f_\theta(x_n)) \tag{6.173}$$

This can be derived by minimizing the BLR, since

$$\mathcal{L}(q) = -\mathbb{E}_{q(\theta)}\left[\sum_{n=1}^N \log p(y_n \mid f_\theta(x_n))\right] + D_{\mathbb{KL}}(q(\theta) \,\|\, \pi_0(\theta)) \tag{6.174}$$

$$= \mathbb{E}_{q(\theta)}\left[\log \frac{q(\theta)}{\frac{1}{Z(\mathcal{D})}\pi_0(\theta) \prod_{n=1}^N p(y_n \mid f_\theta(x_n))}\right] - \log Z(\mathcal{D}) \tag{6.175}$$

$$= D_{\mathbb{KL}}(q(\theta) \,\|\, p(\theta \mid \mathcal{D})) - \log Z(\mathcal{D}) \tag{6.176}$$

Since $Z(\mathcal{D})$ is a constant, we can minimize the loss by setting $q(\theta) = p(\theta \mid \mathcal{D})$.

Of course, we can use other kinds of loss, not just log likelihoods. This results in a framework known as generalized Bayesian inference [BHW16; KJD19; KJD21]. See Section 14.1.3 for more discussion.
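Equation (6.176) can be checked numerically when θ takes values on a finite grid, since then every q is just a probability vector. The sketch below uses a hypothetical coin-bias example (a uniform prior on a 50-point grid and made-up counts of 7 heads and 3 tails). It evaluates $\mathcal{L}(q)$ from Equation (6.174) and confirms that the exact posterior attains the value $-\log Z(\mathcal{D})$, while other choices of q do worse.

```python
import math


def blr_loss(q, prior, loglik):
    """L(q) = -E_q[sum_n log p(y_n | theta)] + KL(q || prior) on a discrete grid."""
    exp_nll = -sum(qi * li for qi, li in zip(q, loglik))
    kl = sum(qi * math.log(qi / pi) for qi, pi in zip(q, prior) if qi > 0)
    return exp_nll + kl


thetas = [(i + 0.5) / 50 for i in range(50)]   # grid over the coin bias
prior = [1.0 / len(thetas)] * len(thetas)      # uniform pi_0
heads, tails = 7, 3                            # toy data (hypothetical)
loglik = [heads * math.log(t) + tails * math.log(1.0 - t) for t in thetas]
joint = [p * math.exp(l) for p, l in zip(prior, loglik)]
Z = sum(joint)                                 # evidence Z(D)
posterior = [j / Z for j in joint]
```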


6.7.1.2 Optimization of BLR using natural gradient descent


In general, we cannot compute the exact posterior $q(\theta) = p(\theta \mid \mathcal{D})$, so we seek an approximation. We will assume that $q(\theta)$ is an exponential family distribution, such as a multivariate Gaussian, where the mean represents the standard point estimate of θ (as in ERM), and the covariance represents our uncertainty (as in Bayes). Hence q can be written as follows:

$$q(\theta) = h(\theta) \exp\left[\lambda^\top T(\theta) - A(\lambda)\right] \tag{6.177}$$

where λ are the natural parameters, $T(\theta)$ are the sufficient statistics, $A(\lambda)$ is the log partition function, and $h(\theta)$ is the base measure, which is usually a constant. The BLR loss becomes

$$\mathcal{L}(\lambda) = \mathbb{E}_{q_\lambda(\theta)}[\ell(\theta)] - \mathbb{H}(q_\lambda(\theta)) \tag{6.178}$$

We can optimize this using natural gradient descent (Section 6.4). The update becomes

$$\lambda_{t+1} = \lambda_t - \eta_t \tilde{\nabla}_\lambda \left[\mathbb{E}_{q_{\lambda_t}(\theta)}[\ell(\theta)] - \mathbb{H}(q_{\lambda_t})\right] \tag{6.179}$$

where $\tilde{\nabla}_\lambda$ denotes the natural gradient. We discuss how to compute these natural gradients in Section 6.4.5. In particular, we can convert it to regular gradients wrt the moment parameters $\mu_t = \mu(\lambda_t)$. This gives

$$\lambda_{t+1} = \lambda_t - \eta_t \nabla_\mu \mathbb{E}_{q_{\mu_t}(\theta)}[\ell(\theta)] + \eta_t \nabla_\mu \mathbb{H}(q_{\mu_t}) \tag{6.180}$$
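From Equation (6.84) we have

$$\nabla_\mu \mathbb{H}(q) = -\lambda - \nabla_\mu \mathbb{E}_{q_\mu(\theta)}[\log h(\theta)] \tag{6.181}$$

Hence the update becomes

$$\lambda_{t+1} = \lambda_t - \eta_t \nabla_\mu \mathbb{E}_{q_{\mu_t}(\theta)}[\ell(\theta)] - \eta_t \lambda_t - \eta_t \nabla_\mu \mathbb{E}_{q_{\mu_t}(\theta)}[\log h(\theta)] \tag{6.182}$$

$$= (1 - \eta_t)\lambda_t - \eta_t \nabla_\mu \mathbb{E}_{q_{\mu_t}(\theta)}[\ell(\theta) + \log h(\theta)] \tag{6.183}$$

For distributions q with constant base measure $h(\theta)$, this simplifies to

$$\lambda_{t+1} = (1 - \eta_t)\lambda_t - \eta_t \nabla_\mu \mathbb{E}_{q_{\mu_t}(\theta)}[\ell(\theta)] \tag{6.184}$$

Hence at the fixed point we have

$$\lambda_* = (1 - \eta)\lambda_* - \eta \nabla_\mu \mathbb{E}_{q_{\mu_*}(\theta)}[\ell(\theta)] \tag{6.185}$$

$$\lambda_* = \nabla_\mu \mathbb{E}_{q_{\mu_*}(\theta)}[-\ell(\theta)] = \tilde{\nabla}_\lambda \mathbb{E}_{q_{\lambda_*}(\theta)}[-\ell(\theta)] \tag{6.186}$$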
6.7.1.3 Conjugate variational inference

In Section 7.3 we show how to do exact inference in conjugate models. We can derive Equation (7.11) from the BLR by using the fixed point condition in Equation (6.186) to write

$$\lambda_* = \nabla_\mu \mathbb{E}_{q_*}[-\ell(\theta)] = \lambda_0 + \sum_{i=1}^N \underbrace{\nabla_\mu \mathbb{E}_{q_*}[\log p(y_i \mid \theta)]}_{\tilde{\lambda}_i(y_i)} \tag{6.187}$$

where $\tilde{\lambda}_i(y_i)$ are the sufficient statistics for the i'th likelihood term.

For models where the joint distribution over the latents factorizes (using a graphical model), we 
can further decompose this update into a series of local terms. This gives rise to the variational 
message passing scheme discussed in Section 10.2.7. 
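For a concrete instance of Equation (6.187), take a Bernoulli likelihood with a conjugate Beta prior (our choice of example). With sufficient statistics $T(\theta) = (\log\theta, \log(1-\theta))$, a Beta(α, β) distribution has natural parameters $\lambda = (\alpha - 1, \beta - 1)$, and $\log p(y_i \mid \theta) = y_i \log\theta + (1 - y_i)\log(1-\theta)$ is linear in $T(\theta)$. Since $\mathbb{E}_q[T(\theta)] = \mu$, its gradient with respect to μ is simply $\tilde{\lambda}_i(y_i) = (y_i, 1 - y_i)$, so the fixed point is obtained by adding up these local terms, which reproduces the usual Beta posterior. The helper names below are hypothetical.

```python
def local_natural_params(y):
    """lambda_tilde_i(y_i) for a Bernoulli term, with T(theta) = (log theta, log(1 - theta))."""
    return (float(y), 1.0 - y)


def conjugate_fixed_point(lam0, ys):
    """lambda_* = lambda_0 + sum_i lambda_tilde_i(y_i), Eq. (6.187)."""
    lam1, lam2 = lam0
    for y in ys:
        t1, t2 = local_natural_params(y)
        lam1 += t1
        lam2 += t2
    return lam1, lam2


def beta_from_natural(lam):
    """Map natural parameters (alpha - 1, beta - 1) back to (alpha, beta)."""
    return lam[0] + 1.0, lam[1] + 1.0


# Beta(2, 2) prior, i.e. lambda_0 = (1, 1), and four coin flips.
alpha, beta = beta_from_natural(conjugate_fixed_point((1.0, 1.0), [1, 1, 0, 1]))
```

This gives Beta(5, 3), the familiar conjugate update of adding the head and tail counts to the prior pseudo-counts.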


6.7.1.4 Partially conjugate variational inference 


In Supplementary Section 10.3.1, we discuss CVI, which performs variational inference for partially conjugate models, using gradient updates for the non-conjugate parts, and exact Bayesian inference for the conjugate parts.


6.7.2 Deriving optimization algorithms from BLR


In this section we show how to derive several different optimization algorithms from BLR. Recall that in BLR, instead of directly minimizing the loss

$$\ell(\theta) = \sum_{n=1}^N \ell(y_n, f_\theta(x_n)) + R(\theta) \tag{6.188}$$

we will instead minimize

$$\mathcal{L}(\lambda) = \mathbb{E}_{q(\theta \mid \lambda)}[\ell(\theta)] - \mathbb{H}(q(\theta \mid \lambda)) \tag{6.189}$$

Below we show that different approximations to this objective recover a variety of different optimization methods that are used in the literature.
6.7. THE BAYESIAN LEARNING RULE


6.7.2.1 Gradient descent 


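In this section, we show how to derive gradient descent as a special case of BLR. We use as our approximate posterior $q(\theta) = \mathcal{N}(\theta \mid m, I)$. In this case the natural and moment parameters are equal, $\mu = \lambda = m$. The base measure satisfies the following (from Equation (2.180)):

$$2 \log h(\theta) = -D \log(2\pi) - \theta^\top \theta \tag{6.190}$$

Hence, using $\mathbb{E}_q[\theta^\top \theta] = \mu^\top \mu + D$,

$$\nabla_\mu \mathbb{E}_q[\log h(\theta)] = \nabla_\mu \tfrac{1}{2}\left(-D \log(2\pi) - \mu^\top \mu - D\right) = -\mu = -\lambda = -m \tag{6.191}$$

Thus from Equation (6.183) the BLR update becomes

$$m_{t+1} = (1 - \eta_t) m_t + \eta_t m_t - \eta_t \nabla_m \mathbb{E}_{q_{m_t}(\theta)}[\ell(\theta)] \tag{6.192}$$

We can remove the expectation using the first order delta method (Section 6.5.5):

$$\nabla_m \mathbb{E}_{q_m(\theta)}[\ell(\theta)] \approx \nabla_\theta \ell(\theta)\big|_{\theta=m} \tag{6.193}$$

Putting these together gives the gradient descent update:

$$m_{t+1} = m_t - \eta_t \nabla_\theta \ell(\theta)\big|_{\theta=m_t} \tag{6.194}$$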
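The sketch below contrasts two versions of update (6.192) on a toy non-quadratic loss $\ell(\theta) = \theta^4/4 - \theta$ (our choice). With the delta method it is ordinary gradient descent and converges to the minimizer of ℓ, θ = 1. If instead the expectation is kept and estimated by Monte Carlo, using $\nabla_m \mathbb{E}_{q_m}[\ell(\theta)] = \mathbb{E}_{q_m}[\nabla_\theta \ell(\theta)]$ for $q_m = \mathcal{N}(m, 1)$, it converges to the minimizer of the smoothed loss $\mathbb{E}_{q_m}[\ell(\theta)]$, which is the root of $m^3 + 3m - 1 = 0$ (about 0.322).

```python
import random


def grad_loss(theta):
    """Gradient of the toy loss l(theta) = theta**4 / 4 - theta."""
    return theta ** 3 - 1.0


def blr_gd(m0, lr=0.05, n_steps=500, n_mc=0, seed=0):
    """BLR update (6.192) for q_m = N(m, 1).

    n_mc == 0: delta method (6.193), i.e. plain gradient descent (6.194).
    n_mc > 0:  Monte Carlo estimate of grad_m E_q[l] = E_q[grad l(theta)].
    """
    rng = random.Random(seed)
    m = m0
    for _ in range(n_steps):
        if n_mc == 0:
            g = grad_loss(m)
        else:
            g = sum(grad_loss(m + rng.gauss(0.0, 1.0)) for _ in range(n_mc)) / n_mc
        m = (1.0 - lr) * m + lr * m - lr * g  # Eq. (6.192)
    return m
```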


6.7.2.2 Newton’s method 


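In this section, we show how to derive Newton’s second order optimization method as a special case of BLR, as first shown in [Kha+18].
Suppose we assume $q(\theta) = \mathcal{N}(\theta \mid m, S^{-1})$. The natural parameters are

$$\lambda^{(1)} = S m, \quad \lambda^{(2)} = -\frac{1}{2} S \tag{6.195}$$

The mean (moment) parameters are

$$\mu^{(1)} = m, \quad \mu^{(2)} = S^{-1} + m m^\top \tag{6.196}$$

Since the base measure is constant (see Equation (2.193)), from Equation (6.184) we have

$$S_{t+1} m_{t+1} = (1 - \eta_t) S_t m_t - \eta_t \nabla_{\mu^{(1)}} \mathbb{E}_{q_t(\theta)}[\ell(\theta)] \tag{6.197}$$

$$S_{t+1} = (1 - \eta_t) S_t + 2 \eta_t \nabla_{\mu^{(2)}} \mathbb{E}_{q_t(\theta)}[\ell(\theta)] \tag{6.198}$$

In Section 6.4.5.1 we show that

$$\nabla_{\mu^{(1)}} \mathbb{E}_{q(\theta)}[\ell(\theta)] = \mathbb{E}_{q(\theta)}[\nabla_\theta \ell(\theta)] - \mathbb{E}_{q(\theta)}[\nabla^2_\theta \ell(\theta)]\, m \tag{6.199}$$

$$\nabla_{\mu^{(2)}} \mathbb{E}_{q(\theta)}[\ell(\theta)] = \frac{1}{2} \mathbb{E}_{q(\theta)}[\nabla^2_\theta \ell(\theta)] \tag{6.200}$$

Hence the update for the precision matrix becomes

$$S_{t+1} = (1 - \eta_t) S_t + \eta_t \mathbb{E}_{q_t}[\nabla^2_\theta \ell(\theta)] \tag{6.201}$$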


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 


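For the precision weighted mean, we have

$$S_{t+1} m_{t+1} = (1 - \eta_t) S_t m_t - \eta_t \mathbb{E}_{q_t}[\nabla_\theta \ell(\theta)] + \eta_t \mathbb{E}_{q_t}[\nabla^2_\theta \ell(\theta)]\, m_t \tag{6.202}$$

$$= S_{t+1} m_t - \eta_t \mathbb{E}_{q_t}[\nabla_\theta \ell(\theta)] \tag{6.203}$$

Hence

$$m_{t+1} = m_t - \eta_t S_{t+1}^{-1} \mathbb{E}_{q_t}[\nabla_\theta \ell(\theta)] \tag{6.204}$$

We can recover Newton’s method in three steps. First, set the learning rate to $\eta_t = 1$, based on an assumption that the objective is convex. Second, treat the iterate as $m_t = \theta_t$. Third, apply the delta method to get

$$S_{t+1} = \mathbb{E}_{q_t}[\nabla^2_\theta \ell(\theta)] \approx \nabla^2_\theta \ell(\theta)\big|_{\theta=m_t} \tag{6.205}$$

and

$$\mathbb{E}_{q_t}[\nabla_\theta \ell(\theta)] \approx \nabla_\theta \ell(\theta)\big|_{\theta=m_t} \tag{6.206}$$

This gives Newton’s update:

$$m_{t+1} = m_t - \left[\nabla^2_\theta \ell(m_t)\right]^{-1} \nabla_\theta \ell(m_t) \tag{6.207}$$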
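The two BLR updates, Equations (6.201) and (6.204) combined with the delta approximations (6.205) and (6.206), can be run directly in 1D. The sketch below uses a toy convex loss $\ell(\theta) = \log(1 + e^\theta) - 0.3\,\theta$ (our choice), whose minimizer is $\theta = \log(0.3/0.7)$. With $\eta_t = 1$ it reproduces Newton's method and converges in a handful of steps; with $\eta_t < 1$ it becomes a damped Newton method in which the precision S is an exponential moving average of the Hessian.

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def grad(theta):
    """Gradient of the toy convex loss l(theta) = log(1 + e^theta) - 0.3 * theta."""
    return sigmoid(theta) - 0.3


def hess(theta):
    s = sigmoid(theta)
    return s * (1.0 - s)


def blr_newton(m, S=1.0, eta=1.0, n_steps=20):
    """Eqs. (6.201) and (6.204) in 1D, with the delta approximations (6.205)-(6.206)."""
    for _ in range(n_steps):
        S = (1.0 - eta) * S + eta * hess(m)  # precision update
        m = m - eta * grad(m) / S            # precision-weighted mean update
    return m, S
```

At convergence S equals the Hessian at the minimizer, here $0.3 \times 0.7 = 0.21$, which is the precision a Laplace approximation would use.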


6.7.2.3 Variational online Gauss-Newton


In this section, we describe the Variational Online Gauss-Newton or VOGN method of [Kha+18]. This is an approximate second order optimization method that can be derived from the BLR in several steps, as we show below.

First, we use a diagonal Gaussian approximation to the posterior, $q_t(\theta) = \mathcal{N}(\theta \mid \theta_t, S_t^{-1})$, where $S_t = \mathrm{diag}(s_t)$ and $s_t$ is a vector of precisions. Following Section 6.7.2.2, we get the following updates:

$$\theta_{t+1} = \theta_t - \eta_t \frac{1}{s_{t+1}} \odot \mathbb{E}_{q_t}[\nabla_\theta \ell(\theta)] \tag{6.208}$$

$$s_{t+1} = (1 - \eta_t) s_t + \eta_t \mathbb{E}_{q_t}[\mathrm{diag}(\nabla^2_\theta \ell(\theta))] \tag{6.209}$$

where ⊙ is elementwise multiplication, and the division by $s_{t+1}$ is also elementwise.

Second, we use the delta approximation to replace expectations by plugging in the mean. Third, we use a minibatch approximation to the gradient and diagonal Hessian:

$$\hat{\nabla}_\theta \ell(\theta) = \frac{N}{M} \sum_{i \in \mathcal{M}} \nabla_\theta \ell(y_i, f_\theta(x_i)) + \nabla_\theta R(\theta) \tag{6.210}$$

$$\hat{\nabla}^2_{\theta_j} \ell(\theta) = \frac{N}{M} \sum_{i \in \mathcal{M}} \nabla^2_{\theta_j} \ell(y_i, f_\theta(x_i)) + \nabla^2_{\theta_j} R(\theta) \tag{6.211}$$

where M is the minibatch size.


For some non-convex problems, such as DNNs, the Hessian may not be positive definite, so we can get better results using a Gauss-Newton approximation, based on the squared gradients instead of the Hessian:

$$\hat{\nabla}^2_{\theta_j} \ell(\theta) \approx \frac{N}{M} \sum_{i \in \mathcal{M}} \left[\nabla_{\theta_j} \ell(y_i, f_\theta(x_i))\right]^2 + \nabla^2_{\theta_j} R(\theta) \tag{6.212}$$

This is also faster to compute.

Putting all this together gives rise to the Online Gauss-Newton or OGN method of [Osa+19a].

If we drop the delta approximation, and work with expectations, we get the Variational Online Gauss-Newton or VOGN method of [Kha+18]. We can approximate the expectations by sampling. In particular, VOGN uses the following weight perturbation method

$$\mathbb{E}_{q_t}[\hat{\nabla}_\theta \ell(\theta)] \approx \hat{\nabla}_\theta \ell(\theta_t + \epsilon_t) \tag{6.213}$$

where $\epsilon_t \sim \mathcal{N}(0, \mathrm{diag}(s_t)^{-1})$. It is also possible to approximate the Fisher information matrix directly; this results in the Variational Online Generalized Gauss-Newton or VOGGN method of [Osa+19a].
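A VOGN-style loop can be sketched for a scalar linear regression, $\ell(\theta) = \sum_n \frac{1}{2}(y_n - \theta x_n)^2 + \frac{\delta}{2}\theta^2$, where the maximum a posteriori estimate is available in closed form for comparison. This is a toy illustration under our own assumptions (a scalar parameter, fixed η, δ and minibatch size, squared per-sample gradients as in Equation (6.212), and a single weight-perturbation sample per step as in Equation (6.213)), not the reference implementation of [Kha+18].

```python
import math
import random


def vogn_linreg(xs, ys, delta=1.0, eta=0.05, batch=20, n_steps=2000, seed=0):
    """VOGN-style updates for l(theta) = sum_n 0.5*(y_n - theta*x_n)^2 + 0.5*delta*theta^2."""
    rng = random.Random(seed)
    n = len(xs)
    m, s = 0.0, 1.0  # mean and precision of q_t(theta) = N(theta | m, 1/s)
    for _ in range(n_steps):
        idx = rng.sample(range(n), batch)                   # minibatch M
        theta = m + rng.gauss(0.0, 1.0 / math.sqrt(s))      # weight perturbation
        per_sample = [-(ys[i] - theta * xs[i]) * xs[i] for i in idx]
        g = n / batch * sum(per_sample) + delta * theta               # Eq. (6.210)
        h = n / batch * sum(gi * gi for gi in per_sample) + delta     # Eq. (6.212)
        s = (1.0 - eta) * s + eta * h                                 # Eq. (6.209)
        m = m - eta * g / s                                           # Eq. (6.208)
    return m, s


xs = [((i * 13) % 17 - 8) / 4.0 for i in range(100)]
ys = [2.0 * x + 0.1 * ((i * 7) % 11 - 5) for i, x in enumerate(xs)]
m, s = vogn_linreg(xs, ys)
theta_map = sum(x * y for x, y in zip(xs, ys)) / (sum(x * x for x in xs) + 1.0)
```

Because of the weight perturbation and minibatching, m keeps fluctuating around the MAP estimate rather than settling exactly on it; s plays the role of the posterior precision.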


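6.7.2.4 Adaptive learning rate SGD


In this section, we show how to derive an update rule which is very similar to the RMSprop [Hin14] method, which is widely used in deep learning. The approach we take is similar to that of VOGN in Section 6.7.2.3. We use the same diagonal Gaussian approximation, $q_t(\theta) = \mathcal{N}(\theta \mid \theta_t, S_t^{-1})$, where $S_t = \mathrm{diag}(s_t)$ and $s_t$ is a vector of precisions. We then use the delta method to eliminate expectations:

$$\theta_{t+1} = \theta_t - \eta_t \frac{1}{s_{t+1}} \odot \nabla_\theta \ell(\theta_t) \tag{6.214}$$

$$s_{t+1} = (1 - \eta_t) s_t + \eta_t\, \mathrm{diag}(\nabla^2_\theta \ell(\theta_t)) \tag{6.215}$$

where ⊙ is elementwise multiplication. If we allow for different learning rates we get

$$\theta_{t+1} = \theta_t - \alpha_t \frac{1}{s_{t+1}} \odot \nabla_\theta \ell(\theta_t) \tag{6.216}$$

$$s_{t+1} = (1 - \beta_t) s_t + \beta_t\, \mathrm{diag}(\nabla^2_\theta \ell(\theta_t)) \tag{6.217}$$

Now suppose we replace the diagonal Hessian approximation with the sum of the squares of the per-sample gradients:

$$\mathrm{diag}(\nabla^2_\theta \ell(\theta_t)) \approx \hat{\nabla} \ell(\theta_t) \odot \hat{\nabla} \ell(\theta_t) \tag{6.218}$$

If we also change some scaling factors we can get the RMSprop updates:

$$\theta_{t+1} = \theta_t - \alpha \frac{1}{\sqrt{v_{t+1}} + c\mathbf{1}} \odot \hat{\nabla} \ell(\theta_t) \tag{6.219}$$

$$v_{t+1} = (1 - \beta) v_t + \beta \left[\hat{\nabla} \ell(\theta_t) \odot \hat{\nabla} \ell(\theta_t)\right] \tag{6.220}$$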

This allows us to use standard deep learning optimizers to get a Gaussian approximation to the 
posterior for the parameters [Osa+19a]. 

It is also possible to derive the Adam optimizer [KB15] from BLR by adding a momentum term to 
RMSprop. See [KR21a; Ait18] for details. 
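Equations (6.219) and (6.220) translate directly into code. The sketch below applies them to a deliberately badly scaled quadratic, $\ell(\theta) = \frac{1}{2}(100\,\theta_1^2 + \theta_2^2)$ (our choice), where the per-coordinate scaling by $1/(\sqrt{v_{t+1}} + c)$ lets both coordinates move at a similar rate despite the 100-fold difference in curvature. The values of α, β and c are illustrative, not recommendations.

```python
import math


def rmsprop(grad_fn, theta, alpha=0.01, beta=0.1, c=1e-8, n_steps=2000):
    """RMSprop in the form of Eqs. (6.219)-(6.220); theta is a list of floats."""
    v = [0.0] * len(theta)
    for _ in range(n_steps):
        g = grad_fn(theta)
        # Eq. (6.220): exponential moving average of the squared gradients
        v = [(1.0 - beta) * vi + beta * gi * gi for vi, gi in zip(v, g)]
        # Eq. (6.219): elementwise preconditioned step using v_{t+1}
        theta = [t - alpha * gi / (math.sqrt(vi) + c) for t, gi, vi in zip(theta, g, v)]
    return theta


def grad_quadratic(theta):
    """Gradient of 0.5 * (100 * theta_1^2 + theta_2^2)."""
    return [100.0 * theta[0], theta[1]]


theta = rmsprop(grad_quadratic, [1.0, 1.0])
```

With a constant α the iterates end up oscillating in a small neighborhood (of order α) of the minimizer rather than converging exactly.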
6.7.3 Variational optimization


Consider an objective defined in terms of discrete variables. Such objectives are not differentiable and so are hard to optimize. One advantage of BLR is that it optimizes the parameters of a probability distribution, and such expected loss objectives are usually differentiable and smooth. This is called “variational optimization” [Bar17], since we are optimizing over a probability distribution.

For example, consider the case of a binary neural network where $\theta_d \in \{0, 1\}$ indicates if weight d is used or not. We can optimize over the parameters of a Bernoulli distribution, $q(\theta \mid \lambda) = \prod_{d=1}^D \mathrm{Ber}(\theta_d \mid p_d)$, where $p_d \in [0, 1]$ and $\lambda_d = \log(p_d/(1 - p_d))$ is the log odds. This is the basis of the BayesBiNN approach [MBK20].

If we ignore the entropy and regularizer term, we get the following simplified objective:

$$\mathcal{L}(\lambda) = \int \ell(\theta)\, q(\theta \mid \lambda)\, d\theta \tag{6.221}$$

This method has various names: stochastic relaxation [SB12; SB13; MMP13], stochastic approximation [HHC12; Hu+12], etc. It is closely related to evolutionary strategies, which we discuss in Section 6.9.6.

In the case of functions with continuous domains, we can use a Gaussian for $q(\theta \mid \mu, \Sigma)$. The resulting integral in Equation (6.221) can then sometimes be solved in closed form, as explained in [Mob16]. By starting with a broad variance, and gradually reducing it, we hope the method can avoid poor local optima, similar to simulated annealing (Section 12.9.1). However, we generally get better results by including the entropy term, because then we can automatically learn to adapt the variance. In addition, we can often work with natural gradients, which results in faster convergence.
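The sketch below illustrates variational optimization on a tiny discrete problem with D = 3 binary variables. The loss (Hamming distance to a target pattern plus a pairwise penalty) is a made-up example. Because there are only $2^3$ configurations, we compute $\mathcal{L}(\lambda)$ in Equation (6.221) and its gradient exactly by enumeration, using the score-function identity $\partial \log q(\theta \mid \lambda)/\partial \lambda_d = \theta_d - p_d$ for a Bernoulli with log odds $\lambda_d$; realistic problems would replace the enumeration by Monte Carlo samples.

```python
import itertools
import math


def loss(theta):
    """Toy discrete objective: Hamming distance to (1, 0, 1) plus a pairwise penalty."""
    target = (1, 0, 1)
    return sum(t != g for t, g in zip(theta, target)) + 2 * theta[0] * theta[1]


def expected_loss_and_grad(lam):
    """L(lambda) = sum_theta l(theta) q(theta | lambda) and its gradient, by enumeration."""
    p = [1.0 / (1.0 + math.exp(-lam_d)) for lam_d in lam]
    value, grad = 0.0, [0.0] * len(lam)
    for theta in itertools.product((0, 1), repeat=len(lam)):
        q = math.prod(pd if t else 1.0 - pd for pd, t in zip(p, theta))
        value += q * loss(theta)
        for d in range(len(lam)):
            # score function: d log q / d lambda_d = theta_d - p_d
            grad[d] += q * loss(theta) * (theta[d] - p[d])
    return value, grad


def variational_optimize(n_steps=200, lr=0.5):
    """Plain gradient descent on the log odds lambda."""
    lam = [0.0, 0.0, 0.0]
    history = []
    for _ in range(n_steps):
        value, grad = expected_loss_and_grad(lam)
        history.append(value)
        lam = [lam_d - lr * g for lam_d, g in zip(lam, grad)]
    return lam, history


lam, history = variational_optimize()
```

The Bernoulli probabilities drift towards the minimizing configuration (1, 0, 1), even though the loss itself is only defined on binary inputs.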


6.8 Bayesian optimization


In this section, we discuss Bayesian optimization or BayesOpt, which is a model-based approach to black-box optimization, designed for the case where the objective function $f : \mathcal{X} \rightarrow \mathbb{R}$ is expensive to evaluate (e.g., if it requires running a simulation, or training and testing a particular neural net architecture).

Since the true function f is expensive to evaluate, we want to make as few function calls (i.e., make as few queries x to the oracle f) as possible. This suggests that we should build a surrogate function (also called a response surface model) based on the data collected so far, $\mathcal{D}_n = \{(x_i, y_i) : i = 1:n\}$, which we can use to decide which point to query next. There is an inherent tradeoff between picking the point x where we think f(x) is large (we follow the convention in the literature and assume we are trying to maximize f), and picking points where we are uncertain about f(x) but where observing the function value might help us improve the surrogate model. This is another instance of the exploration-exploitation dilemma.

In the special case where the domain we are optimizing over is finite, so $\mathcal{X} = \{1, \ldots, A\}$, the BayesOpt problem becomes similar to the best arm identification problem in the bandit literature (Section 34.4). An important difference is that in bandits, we care about the cost of every action we take, whereas in optimization, we usually only care about the cost of the final solution we find. In other words, in bandits, we want to minimize cumulative regret, whereas in optimization we want to minimize simple or final regret.
6.8. BAYESIAN OPTIMIZATION


Another related topic is active learning. Here the goal is to identify the whole function f with as few queries as possible, whereas in BayesOpt, the goal is just to identify the maximum of the function.

Bayesian optimization is a large topic, and we only give a brief overview below. For more details, see e.g., [Sha+16; Fra18; Gar22]. (See also https://distill.pub/2020/bayesian-optimization/ for an interactive tutorial.)


6.8.1 Sequential model-based optimization 


BayesOpt is an instance of a strategy known as sequential model-based optimization (SMBO) [HHLB11]. In this approach, we alternate between querying the function at a point, and updating our estimate of the surrogate based on the new data. More precisely, at each iteration n, we have a labeled dataset $\mathcal{D}_n = \{(x_i, y_i) : i = 1:n\}$, which records points $x_i$ that we have queried, and the corresponding function values $y_i = f(x_i) + \epsilon_i$, where $\epsilon_i$ is an optional noise term. We use this data to estimate a probability distribution over the true function f; we will denote this by $p(f \mid \mathcal{D}_n)$. We then choose the next point to query $x_{n+1}$ using an acquisition function $a(x; \mathcal{D}_n)$, which computes the expected utility of querying x. (We discuss acquisition functions in Section 6.8.3.) After we observe $y_{n+1} = f(x_{n+1}) + \epsilon_{n+1}$, we update our beliefs about the function, and repeat. See Algorithm 5 for some pseudocode.


Algorithm 5: Bayesian optimization

1 Collect initial dataset $\mathcal{D}_0 = \{(x_i, y_i)\}$ from random queries $x_i$ or a space-filling design
2 Initialize model by computing $p(f \mid \mathcal{D}_0)$
3 for n = 1, 2, ... until convergence do
4     Choose next query point $x_{n+1} = \arg\max_{x \in \mathcal{X}} a(x; \mathcal{D}_n)$
5     Measure function value, $y_{n+1} = f(x_{n+1}) + \epsilon_{n+1}$
6     Augment dataset, $\mathcal{D}_{n+1} = \{\mathcal{D}_n, (x_{n+1}, y_{n+1})\}$
7     Update model by computing $p(f \mid \mathcal{D}_{n+1})$


This method is illustrated in Figure 6.10. The goal is to find the global optimum of the solid black curve. In the first row, we show the 2 previously queried points, $x_1$ and $x_2$, and their corresponding function values, $y_1 = f(x_1)$ and $y_2 = f(x_2)$. Our uncertainty about the value of f at those locations is 0 (if we assume no observation noise), as illustrated by the posterior credible interval (shaded blue area) becoming “pinched”. Consequently the acquisition function (shown in green at the bottom) also has value 0 at those previously queried points. The red triangle represents the maximum of the acquisition function, which becomes our next query, $x_3$. In the second row, we show the result of observing $y_3 = f(x_3)$; this further reduces our uncertainty about the shape of the function. In the third row, we show the result of observing $y_4 = f(x_4)$. This process repeats until we run out of time, or until we are confident there are no better unexplored points to query.

The two main “ingredients” that we need to provide to a BayesOpt algorithm are (1) a way to represent and update the posterior surrogate $p(f \mid \mathcal{D}_n)$, and (2) a way to define and optimize the acquisition function $a(x; \mathcal{D}_n)$. We discuss both of these topics below.
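Putting the two ingredients together, Algorithm 5 can be sketched in plain Python. This is a toy implementation under simplifying assumptions that are ours, not the book's: a 1D objective maximized over a finite grid of candidate points (so the acquisition function is maximized by enumeration), a GP surrogate with a fixed RBF kernel of length scale 0.2 plus a small jitter for numerical stability (see Section 6.8.2.1), standardized observations, and the expected improvement acquisition function discussed in Section 6.8.3.2. All function names are hypothetical.

```python
import math


def rbf(a, b, length_scale=0.2):
    """Squared-exponential kernel K(x, x'); the length scale is fixed, not learned."""
    return math.exp(-0.5 * ((a - b) / length_scale) ** 2)


def cholesky(K):
    """Lower-triangular L with K = L L^T (plain Python, fine for small matrices)."""
    n = len(K)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = K[i][j] - sum(L[i][k] * L[j][k] for k in range(j))
            L[i][j] = math.sqrt(max(s, 1e-12)) if i == j else s / L[j][j]
    return L


def forward_solve(L, b):
    """Solve L x = b for lower-triangular L."""
    x = []
    for i in range(len(b)):
        x.append((b[i] - sum(L[i][k] * x[k] for k in range(i))) / L[i][i])
    return x


def gp_fit(X, y, jitter=1e-6):
    """Factorize K + jitter * I; returns L and L^{-1} y."""
    K = [[rbf(a, b) + (jitter if i == j else 0.0) for j, b in enumerate(X)]
         for i, a in enumerate(X)]
    L = cholesky(K)
    return L, forward_solve(L, y)


def gp_predict(X, L, Linv_y, x):
    """Posterior mean mu_n(x) and standard deviation sigma_n(x) of the GP."""
    v = forward_solve(L, [rbf(a, x) for a in X])        # L^{-1} k_*
    mean = sum(vi * ui for vi, ui in zip(v, Linv_y))    # k_*^T K^{-1} y
    var = max(rbf(x, x) - sum(vi * vi for vi in v), 1e-12)
    return mean, math.sqrt(var)


def expected_improvement(mu, sigma, best):
    z = (mu - best) / sigma
    cdf = 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))
    pdf = math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)
    return (mu - best) * cdf + sigma * pdf


def bayes_opt(f, grid, x_init, n_iters=15):
    """Algorithm 5 with a GP surrogate and EI, maximizing f over a finite grid."""
    X = list(x_init)
    Y = [f(x) for x in X]                      # initial design D_0
    for _ in range(n_iters):
        mean_y = sum(Y) / len(Y)
        sd_y = (sum((v - mean_y) ** 2 for v in Y) / len(Y)) ** 0.5 or 1.0
        z = [(v - mean_y) / sd_y for v in Y]  # standardized targets
        L, Linv_y = gp_fit(X, z)               # posterior p(f | D_n)
        best = max(z)                          # incumbent V_n (standardized)
        candidates = [x for x in grid if x not in X]
        x_next = max(candidates,
                     key=lambda x: expected_improvement(*gp_predict(X, L, Linv_y, x), best))
        X.append(x_next)                       # query the oracle (noise-free here)
        Y.append(f(x_next))
    i_best = max(range(len(Y)), key=Y.__getitem__)
    return X[i_best], Y[i_best]


grid = [i / 100 for i in range(101)]
best_x, best_y = bayes_opt(lambda x: -(x - 0.3) ** 2, grid, x_init=[0.0, 0.5, 1.0])
```

For anything beyond a toy problem one would use a linear algebra library, learn the kernel hyperparameters, and optimize the acquisition function with a continuous optimizer, as discussed in Section 6.8.3.7.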
Figure 6.10: Illustration of sequential Bayesian optimization over three iterations. The rows correspond to a training set of size t = 2, 3, 4. The solid black line is the true, but unknown, function f(x). The dotted black line is the posterior mean, μ(x). The shaded blue intervals are the 95% credible interval derived from μ(x) and σ(x). The solid black dots correspond to points whose function value has already been computed, i.e., $x_n$ for which $f(x_n)$ is known. The green curve at the bottom is the acquisition function. The red dot is the proposed next point to query, which is the maximum of the acquisition function. From Figure 1 of [Sha+16]. Used with kind permission of Nando de Freitas.


6.8.2 Surrogate functions


In this section, we discuss ways to represent and update the posterior over functions, $p(f \mid \mathcal{D}_n)$.


6.8.2.1 Gaussian processes


In BayesOpt, it is very common to use a Gaussian process or GP for our surrogate. GPs are explained in detail in Chapter 18, but the basic idea is that they represent $p(f(x) \mid \mathcal{D}_n)$ as a Gaussian, $p(f(x) \mid \mathcal{D}_n) = \mathcal{N}(f \mid \mu_n(x), \sigma_n^2(x))$, where $\mu_n(x)$ and $\sigma_n(x)$ are functions that can be derived from the training data $\mathcal{D}_n = \{(x_i, y_i) : i = 1:n\}$ using a simple closed-form equation. The GP requires specifying a kernel function $\mathcal{K}(x, x')$, which measures similarities between input points x, x'. The intuition is that if two inputs are similar, so $\mathcal{K}(x, x')$ is large, then the corresponding function values are also likely to be similar, so f(x) and f(x') should be positively correlated. This allows us to interpolate the function between the labeled training points; in some cases, it also lets us extrapolate beyond them.


GPs work well when we have little training data, and they support closed form Bayesian updating. However, exact updating takes $O(N^3)$ time for N samples, which becomes too slow if we perform many function evaluations. There are various methods (Section 18.5.3) for reducing this to $O(NM^2)$ time, where M is a parameter we choose, but this sacrifices some of the accuracy.

In addition, the performance of GPs depends heavily on having a good kernel. We can estimate the kernel parameters θ by maximizing the marginal likelihood, as discussed in Section 18.6.1. However,


since the sample size is small (by assumption), we can often get better performance by marginalizing out θ using approximate Bayesian inference methods, as discussed in Section 18.6.2. See e.g., [WF16] for further details.


6.8.2.2 Bayesian neural networks 


A natural alternative to GPs is to use a parametric model. If we use linear regression, we can 
efficiently perform exact Bayesian inference, as shown in Section 15.2. If we use a nonlinear model, 
such as a DNN, we need to use approximate inference methods. We discuss Bayesian neural networks 
in detail in Chapter 17. For their application to BayesOpt, see e.g. [Spr+16]. 


6.8.2.3 Other models 


We are free to use other forms of regression model. [HHLB11] use an ensemble of random forests; 
such models can easily handle conditional parameter spaces, as we discuss in Section 6.8.4.2, although 
bootstrapping (which is needed to get uncertainty estimates) can be slow. 


6.8.3 Acquisition functions 


In BayesOpt, we use an acquisition function (also called a merit function) to evaluate the expected utility of each possible point we could query: $a(x \mid \mathcal{D}_n) = \mathbb{E}_{p(y \mid x, \mathcal{D}_n)}[U(x, y; \mathcal{D}_n)]$, where y = f(x) is the unknown value of the function at point x, and U() is a utility function. Different utility functions give rise to different acquisition functions, as we discuss below. We usually choose functions so that the utility of picking a point that has already been queried is small (or 0, in the case of noise-free observations), in order to encourage exploration.


6.8.3.1 Probability of improvement 


Let us define $V_n = \max_{i=1}^n y_i$ to be the best value observed so far (known as the incumbent). (If the observations are noisy, using the highest mean value $\max_i \mathbb{E}_{p(f \mid \mathcal{D}_n)}[f(x_i)]$ is a reasonable alternative [WF16].) Then we define the utility of some new point x using $U(x, y; \mathcal{D}_n) = \mathbb{I}(y > V_n)$. This gives reward iff the new value is better than the incumbent. The corresponding acquisition function is then given by the expected utility, $a_{PI}(x; \mathcal{D}_n) = p(f(x) > V_n \mid \mathcal{D}_n)$. This is known as the probability of improvement [Kus64]. If $p(f \mid \mathcal{D}_n)$ is a GP, then this quantity can be computed in closed form, as follows:

$$a_{PI}(x; \mathcal{D}_n) = p(f(x) > V_n \mid \mathcal{D}_n) = \Phi(\gamma_n(x, V_n)) \tag{6.222}$$

where Φ is the cdf of the $\mathcal{N}(0, 1)$ distribution and

$$\gamma_n(x, \tau) = \frac{\mu_n(x) - \tau}{\sigma_n(x)} \tag{6.223}$$


6.8.3.2 Expected improvement 


The problem with PI is that all improvements are considered equally good, so the method tends 
to exploit quite aggressively [Jon01]. A common alternative takes into account the amount of 
Figure 6.11: The first row shows the objective function (the Branin function defined on $\mathbb{R}^2$), and its posterior mean and variance using a GP estimate. White dots are the observed data points. The second row shows 3 different acquisition functions (probability of improvement, expected improvement, and upper confidence bound); the white triangles are the maxima of the corresponding acquisition functions. From Figure 6 of [BCF10]. Used with kind permission of Nando de Freitas.


improvement by defining $U(x, y; \mathcal{D}_n) = (y - V_n)\,\mathbb{I}(y > V_n)$ and

$$a_{EI}(x; \mathcal{D}_n) = \mathbb{E}_{\mathcal{D}_n}[U(x, y)] = \mathbb{E}_{\mathcal{D}_n}[(f(x) - V_n)\,\mathbb{I}(f(x) > V_n)] \tag{6.224}$$

This acquisition function is known as the expected improvement (EI) criterion [Moc+96]. In the case of a GP surrogate, this has the following closed form expression:

$$a_{EI}(x; \mathcal{D}_n) = (\mu_n(x) - V_n)\,\Phi(\gamma_n(x, V_n)) + \sigma_n(x)\,\phi(\gamma_n(x, V_n)) \tag{6.225}$$

where $\phi()$ is the pdf of the $\mathcal{N}(0, 1)$ distribution. The first term encourages exploitation (evaluating points with high mean) and the second term encourages exploration (evaluating points with high variance). This is illustrated in Figure 6.10.


6.8.3.3 Upper confidence bound (UCB)


An alternative approach is to compute an upper confidence bound or UCB on the function, at some confidence level $\beta_n$, and then to define the acquisition function as follows: $a_{UCB}(x; \mathcal{D}_n) = \mu_n(x) + \beta_n \sigma_n(x)$. This is the same as in the contextual bandit setting, discussed in Section 34.4.5, except we are optimizing over $x \in \mathcal{X}$, rather than a finite set of arms $a \in \{1, \ldots, A\}$. If we use a GP for our surrogate, the method is known as GP-UCB [Sri+10].
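The three closed-form acquisition functions for a Gaussian predictive distribution, Equations (6.222), (6.225) and the UCB rule, take only a few lines each. In the sketch below, `mu` and `sigma` stand for $\mu_n(x)$ and $\sigma_n(x)$ at a single candidate point, `best` is the incumbent $V_n$, and β is fixed to an illustrative value rather than set by the GP-UCB schedule.

```python
import math


def _std_normal_cdf(z):
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def _std_normal_pdf(z):
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)


def prob_improvement(mu, sigma, best):
    """a_PI = Phi(gamma), gamma = (mu - best) / sigma, Eqs. (6.222)-(6.223)."""
    return _std_normal_cdf((mu - best) / sigma)


def expected_improvement(mu, sigma, best):
    """a_EI = (mu - best) Phi(gamma) + sigma phi(gamma), Eq. (6.225)."""
    gamma = (mu - best) / sigma
    return (mu - best) * _std_normal_cdf(gamma) + sigma * _std_normal_pdf(gamma)


def upper_confidence_bound(mu, sigma, beta=2.0):
    """a_UCB = mu + beta * sigma."""
    return mu + beta * sigma
```

For example, with incumbent 0, a point with mean 0.1 and standard deviation 0.05 has PI ≈ 0.98 but EI ≈ 0.10, whereas a point with mean −0.1 and standard deviation 1 has PI ≈ 0.46 but EI ≈ 0.35. PI therefore picks the safe, almost-certain small improvement, while EI rewards the larger potential gain of the uncertain point.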


6.8.3.4 Thompson sampling 


We discuss Thompson sampling in Section 34.4.6 in the context of multi-armed bandits, where the state space is finite, $\mathcal{X} = \{1, \ldots, A\}$, and the acquisition function $a(a; \mathcal{D}_n)$ corresponds to the probability that arm a is the best arm. We can generalize this to real-valued input spaces $\mathcal{X}$ using

$$a(x; \mathcal{D}_n) = \mathbb{E}_{p(\theta \mid \mathcal{D}_n)}\left[\mathbb{I}\left(x = \arg\max_{x'} f_\theta(x')\right)\right] \tag{6.226}$$

We can compute a single sample approximation to this integral by sampling $\tilde{\theta} \sim p(\theta \mid \mathcal{D}_n)$. We can then pick the optimal action as follows:

$$x_{n+1} = \arg\max_x a(x; \mathcal{D}_n) = \arg\max_x \mathbb{I}\left(x = \arg\max_{x'} f_{\tilde{\theta}}(x')\right) = \arg\max_x f_{\tilde{\theta}}(x) \tag{6.227}$$

In other words, we greedily maximize the sampled surrogate.

For continuous spaces, Thompson sampling is harder to apply than in the bandit case, since we can’t directly compute the best “arm” $x_{n+1}$ from the sampled function. Furthermore, when using GPs, there are some subtle technical difficulties with sampling a function, as opposed to sampling the parameters of a parametric surrogate model (see [HLHG14] for discussion).
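As a concrete instance of Equation (6.227), the sketch below performs one Thompson-sampling query with a parametric surrogate, which sidesteps the difficulties of sampling whole functions from a GP. The surrogate is Bayesian linear regression on quadratic features $\phi(x) = (1, x, x^2)$ with prior precision α and known noise precision β; these choices, the finite candidate grid, and all function names are assumptions made for illustration.

```python
import math
import random


def features(x):
    """Hypothetical quadratic surrogate f_theta(x) = theta^T phi(x)."""
    return [1.0, x, x * x]


def cholesky(A):
    n = len(A)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = A[i][j] - sum(L[i][k] * L[j][k] for k in range(j))
            L[i][j] = math.sqrt(s) if i == j else s / L[j][j]
    return L


def forward_solve(L, b):
    """Solve L x = b."""
    x = []
    for i in range(len(b)):
        x.append((b[i] - sum(L[i][k] * x[k] for k in range(i))) / L[i][i])
    return x


def backward_solve(L, b):
    """Solve L^T x = b."""
    n = len(b)
    x = [0.0] * n
    for i in reversed(range(n)):
        x[i] = (b[i] - sum(L[k][i] * x[k] for k in range(i + 1, n))) / L[i][i]
    return x


def thompson_step(X, Y, grid, alpha=1e-2, beta=1e4, seed=0):
    """One Thompson-sampling query for Bayesian linear regression on features(x)."""
    rng = random.Random(seed)
    Phi = [features(x) for x in X]
    d = len(Phi[0])
    # Posterior precision P = alpha I + beta Phi^T Phi and mean m = beta P^{-1} Phi^T y
    P = [[alpha * (i == j) + beta * sum(r[i] * r[j] for r in Phi) for j in range(d)]
         for i in range(d)]
    L = cholesky(P)
    rhs = [beta * sum(r[i] * y for r, y in zip(Phi, Y)) for i in range(d)]
    mean = backward_solve(L, forward_solve(L, rhs))
    # theta_tilde = m + L^{-T} eps has covariance P^{-1}
    eps = [rng.gauss(0.0, 1.0) for _ in range(d)]
    theta = [m + e for m, e in zip(mean, backward_solve(L, eps))]
    # Greedily maximize the sampled surrogate, Eq. (6.227)
    return max(grid, key=lambda x: sum(t * p for t, p in zip(theta, features(x))))
```

When the posterior is broad, different seeds give quite different queries (exploration); as data accumulate the samples concentrate and the query approaches the maximizer of the posterior mean.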


6.8.3.5 Entropy search 


Since our goal in BayesOpt is to find $x^* = \arg\max_x f(x)$, it makes sense to try to directly minimize our uncertainty about the location of $x^*$, which we denote by $p_*(x \mid \mathcal{D}_n)$. We will therefore define the utility as follows:

$$U(x, y; \mathcal{D}_n) = \mathbb{H}(x^* \mid \mathcal{D}_n) - \mathbb{H}(x^* \mid \mathcal{D}_n \cup \{(x, y)\}) \tag{6.228}$$

where $\mathbb{H}(x^* \mid \mathcal{D}_n) = \mathbb{H}(p_*(x \mid \mathcal{D}_n))$ is the entropy of the posterior distribution over the location of the optimum. This is known as the information gain criterion; the difference from the objective used in active learning is that here we want to gain information about $x^*$ rather than about f for all x. The corresponding acquisition function is given by

$$a_{ES}(x; \mathcal{D}_n) = \mathbb{E}_{p(y \mid x, \mathcal{D}_n)}[U(x, y; \mathcal{D}_n)] = \mathbb{H}(x^* \mid \mathcal{D}_n) - \mathbb{E}_{p(y \mid x, \mathcal{D}_n)}\left[\mathbb{H}(x^* \mid \mathcal{D}_n \cup \{(x, y)\})\right] \tag{6.229}$$

This is known as entropy search [HS12].

Unfortunately, computing $\mathbb{H}(x^* \mid \mathcal{D}_n)$ is hard, since it requires a probability model over the input space. Fortunately, we can leverage the symmetry of mutual information to rewrite the acquisition function in Equation (6.229) as follows:

$$a_{PES}(x; \mathcal{D}_n) = \mathbb{H}(y \mid \mathcal{D}_n, x) - \mathbb{E}_{x^* \mid \mathcal{D}_n}\left[\mathbb{H}(y \mid \mathcal{D}_n, x, x^*)\right] \tag{6.230}$$

where we can approximate the expectation from $p(x^* \mid \mathcal{D}_n)$ using Thompson sampling. Now we just have to model uncertainty about the output space y. This is known as predictive entropy search [HLHG14].
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6.8.3.6 Knowledge gradient 


So far the acquisition functions we have considered are all greedy, in that they only look one step 
ahead. The knowledge gradient acquisition function, proposed in [FPD09], looks two steps ahead 
by considering the improvement we might expect to get if we query x, update our posterior, and 
then exploit our knowledge by maximizing wrt our new beliefs. More precisely, let us define the best 
value we can find if we query one more point: 


$$V_{n+1}(x, y) = \max_{x'} \mathbb{E}_{p(f|x,y,\mathcal{D}_n)}\left[f(x')\right] \tag{6.231}$$
$$V_{n+1}(x) = \mathbb{E}_{p(y|x,\mathcal{D}_n)}\left[V_{n+1}(x, y)\right] \tag{6.232}$$


We define the KG acquisition function as follows: 


$$\alpha_{KG}(x; \mathcal{D}_n) = \mathbb{E}_{\mathcal{D}_n}\left[(V_{n+1}(x) - V_n)\,\mathbb{I}(V_{n+1}(x) > V_n)\right] \tag{6.233}$$


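Here $V_n = \max_{x'} \mathbb{E}_{p(f|\mathcal{D}_n)}[f(x')]$ is the best value under the current posterior, before the extra query. (Compare this to the EI function in Equation (6.224).) Thus we pick the point $x_{n+1}$ such that observing $f(x_{n+1})$ will give us knowledge which we can then exploit, rather than directly trying to find a better point with better $f$ value.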
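The sketch below computes the KG value in a simple special case: a finite set of points with independent Gaussian beliefs $f(x_i) \sim \mathcal{N}(\mu_i, \sigma_i^2)$ and Gaussian observation noise, which is a simplification of the GP setting. Here $V_n = \max_i \mu_i$, and since $V_{n+1}(x) \geq V_n$ always holds in this case (by Jensen's inequality), the indicator in (6.233) equals 1 and KG reduces to $\mathbb{E}[V_{n+1}(x, y)] - V_n$. The closed form used is a standard result for this independent-belief case; the Monte Carlo function checks it numerically. All names are hypothetical.

```python
import math
import random

def norm_pdf(z):
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)

def norm_cdf(z):
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))

def knowledge_gradient(mu, var, noise_var):
    """KG value E[V_{n+1}(x_i, y)] - V_n for each of (at least two) candidates.

    Assumes independent Gaussian beliefs f(x_i) ~ N(mu[i], var[i]) and
    observations y = f(x_i) + eps with eps ~ N(0, noise_var).
    """
    kg = []
    for i in range(len(mu)):
        # Std. dev. of the change in the posterior mean at x_i after one observation.
        s = var[i] / math.sqrt(var[i] + noise_var) if var[i] > 0 else 0.0
        if s == 0.0:
            kg.append(0.0)  # nothing to learn from a point we already know exactly
            continue
        best_other = max(m for j, m in enumerate(mu) if j != i)
        zeta = -abs(mu[i] - best_other) / s
        kg.append(s * (zeta * norm_cdf(zeta) + norm_pdf(zeta)))
    return kg

def kg_monte_carlo(mu, var, noise_var, i, num_samples, rng):
    """Brute-force check: simulate the posterior-mean update at x_i and average V_{n+1}."""
    s = var[i] / math.sqrt(var[i] + noise_var)
    others = [m for j, m in enumerate(mu) if j != i]
    total = 0.0
    for _ in range(num_samples):
        total += max(mu[i] + s * rng.gauss(0.0, 1.0), *others)
    return total / num_samples - max(mu)  # subtract V_n

mu, var = [1.0, 0.8, 0.2], [0.5, 1.0, 0.0]
kg = knowledge_gradient(mu, var, noise_var=0.1)
mc = kg_monte_carlo(mu, var, 0.1, i=1, num_samples=100_000, rng=random.Random(0))
```

The point with zero variance gets a KG of exactly 0, while the uncertain runner-up (index 1) gets a positive value, even though its current mean is not the best.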


6.8.3.7 Optimizing the acquisition function 


The acquisition function $\alpha(x)$ is often multimodal (see e.g., Figure 6.11), since it will be 0 at all the previously queried points (assuming noise-free observations). Consequently, maximizing this function can be a hard subproblem in itself [WHD18; Rub+20].

In the continuous setting, it is common to use multi-restart BFGS or grid search. We can also use the cross-entropy method (Section 6.9.5), using mixtures of Gaussians [BK10] or VAEs [Fau+18] as the generative model over $x$. In the discrete, combinatorial setting (e.g., when optimizing biological sequences), [Bel+19] use regularized evolution (Section 6.9.3), and [Ang+20] use proximal policy optimization (Section 35.3.4). Many other combinations are possible.


6.8.4 Other issues


There are many other issues that need to be tackled when using Bayesian optimization, a few of which we briefly mention below.


6.8.4.1 Parallel (batch) queries 


In some cases, we want to query the objective function at multiple points in parallel; this is known as batched Bayesian optimization. Now we need to optimize over a set of possible queries, which is computationally even more difficult than the regular case. See [WHD18; DBB20] for some recent papers on this topic.


6.8.4.2 Conditional parameters


BayesOpt is often applied to hyper-parameter optimization. In many applications, some hyperparameters are only well-defined if other ones take on specific values. For example, suppose we are trying to automatically tune a classifier, as in the Auto-Sklearn system [Feu+15], or the Auto-Weka system [Kot+17]. If the method chooses to use a neural network, it also needs to specify the number of layers and the number of hidden units per layer; but if it chooses to use a decision tree, it instead should specify different hyperparameters, such as the maximum tree depth.

We can formalize such problems by defining the search space in terms of a tree or DAG (directed 
acyclic graph), where different subsets of the parameters are defined at each leaf. Applying GPs to 
this setting requires non-standard kernels, such as those discussed in [Swe+13; Jen+17]. Alternatively, 
we can use other forms of Bayesian regression, such as ensembles of random forests [HHLB11], which 
can easily handle conditional parameter spaces. 
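A tree-structured search space of this kind is easy to sample from, which is all that random search or a population-based method needs; the difficulty discussed above lies in building a surrogate (e.g., a GP kernel) over such configurations. The sketch below is a hypothetical example: the model names, hyper-parameter names and ranges are illustrative choices, not taken from Auto-Sklearn or Auto-Weka.

```python
import random

def sample_config(rng):
    """Sample one point from a tree-structured (conditional) search space.

    The model names, hyper-parameter names and ranges are hypothetical.
    """
    config = {"model": rng.choice(["neural_net", "decision_tree"])}
    if config["model"] == "neural_net":
        # These hyper-parameters exist only on the neural-network branch.
        config["num_layers"] = rng.randint(1, 4)
        config["hidden_units"] = [rng.choice([32, 64, 128]) for _ in range(config["num_layers"])]
    else:
        # The decision-tree branch has its own, different hyper-parameters.
        config["max_depth"] = rng.randint(2, 12)
    return config

rng = random.Random(0)
configs = [sample_config(rng) for _ in range(200)]
```

Note that even the number of entries in `hidden_units` depends on another hyper-parameter, which is exactly what makes fixed-dimensional kernels awkward here.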


6.8.4.3 Multi-fidelity surrogates 


In some cases, we can construct surrogate functions with different levels of accuracy, each of which may take variable amounts of time to compute. In particular, let $f(x, s)$ be an approximation to the true function at $x$ with fidelity $s$, where $s = 0$ corresponds to the true function. The goal is to solve $\max_x f(x, 0)$ by observing $f(x, s)$ at a sequence of $(x_i, s_i)$ values, such that the total cost $\sum_i c(s_i)$ is below some budget. For example, in the context of hyperparameter selection, $s$ may control how long we run the parameter optimizer for, or how large the validation set is.

In addition to choosing what fidelity to use for an experiment, we may choose to terminate 
expensive trials (queries) early, if the results of their cheaper proxies suggest they will not be worth 
running to completion (see e.g., [Str19; Li+17c; FKH17]). Alternatively, we may choose to resume 
an earlier aborted run, to collect more data on it, as in the freeze-thaw algorithm [SSA14]. 


6.8.4.4 Constraints 


If we want to maximize a function subject to known constraints, we can simply build the constraints into the acquisition function. But if the constraints are unknown, we need to estimate the support of the feasible set in addition to estimating the function. In [GSA14], they propose the weighted EI criterion, given by $\alpha_{wEI}(x; \mathcal{D}_n) = \alpha_{EI}(x; \mathcal{D}_n)\, h(x; \mathcal{D}_n)$, where $h(x; \mathcal{D}_n)$ is a GP with a Bernoulli observation model that specifies if $x$ is feasible or not. Of course, other methods are possible. For example, [HL+16b] propose a method based on predictive entropy search.
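The weighted EI criterion is simple to compute once the two models are available. This sketch assumes a Gaussian predictive distribution $\mathcal{N}(\mu, \sigma^2)$ for the objective at $x$, an incumbent value `best`, and a feasibility probability $h(x; \mathcal{D}_n)$ that has already been computed (e.g., as the predictive probability of a GP classifier) and is passed in as a number. The EI expression is the standard closed form for maximization; its notation may differ from Equation (6.224), and the function names are hypothetical.

```python
import math

def expected_improvement(mu, sigma, best):
    """Closed-form EI for maximization under a Gaussian predictive N(mu, sigma^2)."""
    if sigma <= 0.0:
        return max(mu - best, 0.0)  # no uncertainty: improvement is deterministic
    z = (mu - best) / sigma
    pdf = math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)
    cdf = 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))
    return (mu - best) * cdf + sigma * pdf

def weighted_ei(mu, sigma, best, prob_feasible):
    """alpha_wEI = alpha_EI * h, where h is the predicted probability that x is feasible."""
    return expected_improvement(mu, sigma, best) * prob_feasible
```

Points that are promising but very likely infeasible are thus down-weighted rather than excluded outright.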


6.9 Derivative free optimization 


Derivative free optimization or DFO refers to a class of techniques for optimizing functions without using derivatives. This is useful for blackbox function optimization as well as discrete optimization. If the function is expensive to evaluate, we can use Bayesian optimization (Section 6.8). If the function is cheap to evaluate, we can use stochastic local search methods or evolutionary search methods, as we discuss below.


6.9.1 Local search 


In this section, we discuss heuristic optimization algorithms that try to find the global maximum in a discrete, unstructured search space. These algorithms replace the local gradient-based update, which has the form $\theta_{t+1} = \theta_t + \eta_t d_t$, with the following discrete analog:


$$x_{t+1} = \arg\max_{x \in \text{nbr}(x_t)} \mathcal{L}(x) \tag{6.234}$$
where $\text{nbr}(x_t) \subseteq \mathcal{X}$ is the set of neighbors of $x_t$. This is called hill climbing, steepest ascent, or greedy search.

If the “neighborhood” of a point contains the entire space, Equation (6.234) will return the global 
optimum in one step, but usually such a global neighborhood is too large to search exhaustively. 
Consequently we usually define local neighborhoods. For example, consider the 8-queens problem. 
Here the goal is to place queens on an 8 x 8 chessboard so that they don’t attack each other (see 
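Figure 6.14). The state space has the form $\mathcal{X} = 64^8$, since we have to specify the location of each queen on the grid. However, since two queens in the same column would always attack each other, we can restrict attention to states with one queen per column, of which there are only $8^8 \approx 17$M. We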
define the neighbors of a state to be all possible states generated by moving a single queen to another 
square in the same column, so each node has 8 x 7 = 56 neighbors. According to [RN10, p.123], if we 
start at a randomly generated 8-queens state, steepest ascent gets stuck at a local maximum 86% of 
the time, so it only solves 14% of problem instances. However, it is fast, taking an average of 4 steps 
when it succeeds and 3 when it gets stuck. 
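The steepest-ascent procedure described above can be sketched in Python as follows. A state is a tuple whose entry `c` is the row of the queen in column `c`, the objective is the number of non-attacking pairs (28 at a solution), and the neighbors move one queen within its column. The helper names are our own, and ties between equally good neighbors are broken by enumeration order rather than at random, so the measured success rate is only expected to be in the same ballpark as the 14% quoted above.

```python
import random
from itertools import combinations

def num_nonattacking(state):
    """Number of non-attacking queen pairs; state[c] is the row of the queen in column c."""
    n = len(state)
    attacking = sum(
        1
        for i, j in combinations(range(n), 2)
        if state[i] == state[j] or abs(state[i] - state[j]) == j - i
    )
    return n * (n - 1) // 2 - attacking

def neighbors(state):
    """All states obtained by moving a single queen to another square in its column."""
    n = len(state)
    for col in range(n):
        for row in range(n):
            if row != state[col]:
                yield state[:col] + (row,) + state[col + 1:]

def steepest_ascent(state):
    """Move to the best neighbor until no neighbor is strictly better (Equation 6.234)."""
    steps = 0
    while True:
        best = max(neighbors(state), key=num_nonattacking)
        if num_nonattacking(best) <= num_nonattacking(state):
            return state, steps  # local maximum (or plateau): stop
        state, steps = best, steps + 1

rng = random.Random(0)
n_trials = 300
successes = 0
for _ in range(n_trials):
    start = tuple(rng.randrange(8) for _ in range(8))
    final, _ = steepest_ascent(start)
    successes += num_nonattacking(final) == 28
print(f"success rate: {successes / n_trials:.2f}")
```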

In the sections below, we discuss slightly smarter algorithms that are less likely to get stuck in 
local maxima. 


6.9.1.1 Stochastic local search 


Hill climbing is greedy, since it picks the best point in its local neighborhood, by solving Equa- 
tion (6.234) exactly. One way to reduce the chance of getting stuck in local maxima is to approximately 
maximize this objective at each step. For example, we can define a probability distribution over the 
uphill neighbors, proportional to how much they improve, and then sample one at random. This is 
called stochastic hill climbing. If we gradually decrease the entropy of this probability distribution 
(so we become greedier over time), we get a method called simulated annealing, which we discuss in 
Section 12.9.1. 

Another simple technique is to use greedy hill climbing, but then whenever we reach a local maximum, we start again from a different random starting point. This is called random restart hill climbing. To see the benefit of this, consider again the 8-queens problem. If each hill-climbing search has a probability of $p \approx 0.14$ of success, then we expect to need $R = 1/p \approx 7$ restarts until we find a valid solution. The expected number of total steps can be computed as follows. Let $N_1 = 4$ be the average number of steps for successful trials, and $N_0 = 3$ be the average number of steps for failures. Then the total number of steps on average is $N_1 + (R - 1) N_0 = 4 + 6 \times 3 = 22$. Since each step is quick, the overall method is very fast. For example, it can solve an $n$-queens problem with $n = 1$M in under a minute.


Of course, solving the $n$-queens problem is not the most useful task in practice. However, it is typical of several real-world boolean satisfiability problems, which arise in problems ranging from AI planning to model checking (see e.g., [SLM92]). In such problems, simple stochastic local search (SLS) algorithms of the kind we have discussed work surprisingly well (see e.g., [HS05]).


6.9.1.2 Tabu search


Hill climbing will stop as soon as it reaches a local maximum or a plateau. Obviously one can perform a random restart, but this would ignore all the information that had been gained up to this point. A more intelligent alternative is called tabu search [GL97]. This is like hill climbing, except it allows moves that decrease (or at least do not increase) the scoring function, provided the move is to a new state that has not been seen before. We can enforce this by keeping a tabu list which tracks the $\tau$ most recently visited states; the length $\tau$ of this list is known as the “tabu tenure”. This forces the algorithm to explore new states, and increases the chances of escaping from local maxima. We continue to do this until we have made $c_{\max}$ consecutive steps without improving on the best state found so far. The pseudocode can be found in Algorithm 6. (If we set $c_{\max} = 1$, we get greedy hill climbing.)

Algorithm 6: Tabu search.

t := 0 // counts iterations
c := 0 // counts number of steps with no progress
Initialize $x_0$
$x^* := x_0$ // current best incumbent
while $c < c_{\max}$ do
    $x_{t+1} = \arg\max_{x \in \text{nbr}(x_t) \setminus \{x_{t-\tau}, \ldots, x_{t-1}\}} f(x)$
    if $f(x_{t+1}) > f(x^*)$ then
        $x^* := x_{t+1}$
        $c := 0$
    else
        $c := c + 1$
    $t := t + 1$
return $x^*$

For example, consider what happens when tabu search reaches a hill top, $x_t$. At the next step, it will move to one of the neighbors of the peak, $x_{t+1} \in \text{nbr}(x_t)$, which will have a lower score. At the next step, it will move to a neighbor of the previous state, $x_{t+2} \in \text{nbr}(x_{t+1})$; the tabu list prevents it from cycling back to $x_t$ (the peak), so it will be forced to pick a neighboring point at the same height or lower. It continues in this way, “circling” the peak, possibly being forced downhill to a lower level-set (an inverse basin flooding operation), until it finds a ridge that leads to a new peak, or until it exceeds a maximum number of non-improving moves.

According to [RN10, p.123], tabu search increases the percentage of 8-queens problems that can 
be solved from 14% to 94%, although this variant takes an average of 21 steps for each successful 
instance and 64 steps for each failed instance. 
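A minimal Python sketch of Algorithm 6 follows, on a toy one-dimensional landscape rather than 8-queens so that the behavior is easy to trace by hand. The tabu list is a `collections.deque` holding the last $\tau$ visited states. One assumption goes beyond the pseudocode: if every neighbor is tabu, the search simply stops. The function and variable names are hypothetical.

```python
from collections import deque

def tabu_search(f, neighbors, x0, tenure, c_max):
    """Maximize f following Algorithm 6; `tenure` is the length tau of the tabu list."""
    x, best = x0, x0
    tabu = deque(maxlen=tenure)  # the last `tenure` states x_{t-tau}, ..., x_{t-1}
    c = 0                        # consecutive steps without improving the incumbent
    while c < c_max:
        candidates = [y for y in neighbors(x) if y not in tabu]
        if not candidates:       # assumption: stop if every neighbor is tabu
            break
        x_next = max(candidates, key=f)  # best non-tabu neighbor, even if it is worse
        tabu.append(x)
        x = x_next
        if f(x) > f(best):
            best, c = x, 0
        else:
            c += 1
    return best

# Toy landscape on states 0..7: local maximum at 2, global maximum at 6.
values = [1, 3, 5, 4, 2, 6, 9, 7]

def f(i):
    return values[i]

def neighbors(i):
    return [j for j in (i - 1, i + 1) if 0 <= j < len(values)]

greedy_like = tabu_search(f, neighbors, x0=0, tenure=3, c_max=1)
with_tabu = tabu_search(f, neighbors, x0=0, tenure=3, c_max=3)
```

With `c_max=1` the search stops after its first non-improving move and returns the local maximum at index 2, matching the remark that this setting behaves like greedy hill climbing; with `c_max=3` it walks down through the valley, which the tabu list prevents it from retreating out of, and reaches the global maximum at index 6.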


6.9.1.3 Random search 


A surprisingly effective strategy in problems where we know nothing about the objective is to use random search. In this approach, each iterate $x_{n+1}$ is chosen uniformly at random from $\mathcal{X}$. This should always be tried as a baseline.

In [BB12], they applied this technique to the problem of hyper-parameter optimization for some ML models, where the objective is performance on a validation set. In their examples, the search space is continuous, $\mathcal{X} = (0, 1]^D$. It is easy to sample from this at random. The standard alternative approach is to quantize the space into a fixed set of values, and then to evaluate them all; this is known as grid search. (Of course, this is only feasible if the number of dimensions $D$ is small, since the number of grid points grows exponentially with $D$.)
They found that random search outperformed grid search. The intuitive reason for this is that many hyper-parameters do not make much difference to the objective function, as illustrated in Figure 6.12. Consequently it is a waste of time to place a fine grid along such unimportant dimensions: with a budget of $N$ trials in $D$ dimensions, a grid tries only $N^{1/D}$ distinct values of each parameter, whereas random search (almost surely) tries $N$ distinct values of each.

[Figure 6.12 panels: Grid Layout (left) and Random Layout (right); horizontal axis: important parameter; vertical axis: unimportant parameter.]

Figure 6.12: Illustration of grid search (left) vs random search (right). From Figure 1 of [BB12]. Used with kind permission of James Bergstra.
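The intuition behind Figure 6.12 can be checked with a few lines of Python: with the same budget of 9 trials in two dimensions, a 3 × 3 grid only ever tries 3 distinct values of the important parameter, while random search tries 9. The layout helpers are hypothetical, and coordinate 0 plays the role of the important parameter.

```python
import random

def grid_layout(points_per_dim):
    """Regular grid over (0, 1)^2 with points_per_dim values along each axis."""
    ticks = [(i + 0.5) / points_per_dim for i in range(points_per_dim)]
    return [(a, b) for a in ticks for b in ticks]

def random_layout(num_points, rng):
    """Points drawn uniformly at random from [0, 1)^2."""
    return [(rng.random(), rng.random()) for _ in range(num_points)]

rng = random.Random(0)
grid = grid_layout(3)          # 9 trials
rand = random_layout(9, rng)   # same budget of 9 trials

# Treat coordinate 0 as the important hyper-parameter.
distinct_grid = len({x[0] for x in grid})
distinct_rand = len({x[0] for x in rand})
print(distinct_grid, distinct_rand)  # prints: 3 9
```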

Random search (RS) has also been used to optimize the parameters of MDP policies, where the objective $f(x) = \mathbb{E}_{\tau \sim \pi_x}[R(\tau)]$ is the expected reward of trajectories $\tau$ generated by using a policy $\pi_x$ with parameters $x$. For policies with few free parameters, RS can outperform more sophisticated reinforcement learning methods described in Chapter 35, as shown in [MGR18]. In cases where the policy has a large number of parameters, it is sometimes possible to project them to a lower dimensional random subspace, and perform optimization (either grid search or random search) in this subspace [Li+18a].


6.9.2 Simulated annealing


Simulated annealing [KJV83; LA87] is a stochastic local search algorithm (Section 6.9.1.1) that attempts to find the global minimum of a black-box function $\mathcal{E}(x)$, where $\mathcal{E}(\cdot)$ is known as the energy function. The method works by converting the energy to an (unnormalized) probability distribution over states by defining $p(x) = \exp(-\mathcal{E}(x))$, and then using a variant of the Metropolis-Hastings algorithm to sample from a sequence of distributions $p_T(x) \propto \exp(-\mathcal{E}(x)/T)$ with a gradually decreasing temperature $T$, designed so that at the final step, the method samples from one of the modes of the distribution, i.e., it finds one of the most likely states, or lowest energy states. This approach can be used for both discrete and continuous optimization. See Section 12.9.1 for details.
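Below is a minimal simulated annealing sketch in Python (see Section 12.9.1 for the full treatment). It uses Metropolis moves targeting $p_T(x) \propto \exp(-\mathcal{E}(x)/T)$ under a geometric cooling schedule; the schedule, its constants, the toy energy list and the neighbor proposal are illustrative assumptions, not prescribed by the text.

```python
import math
import random

def simulated_annealing(energy, propose, x0, t0, cooling, num_steps, rng):
    """Minimize `energy` with Metropolis moves and a geometric cooling schedule."""
    x, e = x0, energy(x0)
    best_x, best_e = x, e
    temp = t0
    for _ in range(num_steps):
        y = propose(x, rng)
        e_new = energy(y)
        # Metropolis acceptance for p_T(x) ∝ exp(-E(x) / T): always accept downhill
        # moves, accept uphill moves with probability exp(-(E(y) - E(x)) / T).
        if e_new <= e or rng.random() < math.exp(-(e_new - e) / temp):
            x, e = y, e_new
            if e < best_e:
                best_x, best_e = x, e
        temp *= cooling  # lower the temperature, making the search greedier over time
    return best_x, best_e

# Toy energy over the states 0..7: a local minimum at 1 and the global minimum at 5.
energies = [5, 2, 4, 6, 3, 0, 4, 7]

def propose(i, rng):
    # Symmetric proposal: step to an adjacent state, staying put at the boundaries.
    j = i + rng.choice((-1, 1))
    return min(max(j, 0), len(energies) - 1)

best_x, best_e = simulated_annealing(lambda i: energies[i], propose, x0=0, t0=5.0,
                                     cooling=0.999, num_steps=5000, rng=random.Random(0))
```

A purely greedy descent started at state 0 would stop at the local minimum at state 1; the early, high-temperature phase lets the chain climb over the barrier at states 2 and 3.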


6.9.3 Evolutionary algorithms


Stochastic local search (SLS) maintains a single “best guess” at each step, $x_t$. If we run this for $T$ steps, and restart $K$ times, the total cost is $TK$. A natural alternative is to maintain a set or population of $K$ good candidates, $S_t$, which we try to improve at each step. This is called an evolutionary algorithm (EA). If we run this for $T$ steps, it also takes $TK$ time; however, it can often get better results than multi-restart SLS, since the search procedure explores more of the space in parallel, and information from different members of the population can be shared. Many versions of EA are possible, as we discuss below.


Since EA algorithms draw inspiration from the biological process of evolution, they also borrow a lot of its terminology. The fitness of a member of the population is the value of the objective function (possibly normalized across population members). The members of the population at step $t+1$ are called the offspring. These can be created by randomly choosing a parent from $S_t$ and applying a random mutation to it. This is like asexual reproduction. Alternatively we can create an offspring by choosing two parents from $S_t$, and then combining them in some way to make a child, as in sexual reproduction; combining the parents is called recombination. (It is often followed by mutation.)

[Figure 6.13 panels: (a) Initial Population, (b) Fitness Function, (c) Selection, (d) Crossover, (e) Mutation.]

Figure 6.13: Illustration of a genetic algorithm applied to the 8-queens problem. (a) Initial population of 4 strings. (b) We rank the members of the population by fitness, and then compute their probability of mating. Here the integer numbers represent the number of nonattacking pairs of queens, so the global maximum has a value of 28. We pick an individual $\theta$ with probability $p(\theta) = \mathcal{L}(\theta)/Z$, where $Z = \sum_{\theta \in \mathcal{P}} \mathcal{L}(\theta)$ sums the total fitness of the population. For example, we pick the first individual with probability $24/78 \approx 0.31$, the second with probability $23/78 \approx 0.29$, etc. In this example, we pick the first individual once, the second twice, the third one once, and the last one does not get to breed. (c) A split point on the “chromosome” of each parent is chosen at random. (d) The two parents swap their chromosome halves. (e) We can optionally apply pointwise mutation. From Figure 4.6 of [RN10]. Used with kind permission of Peter Norvig.

The procedure by which parents are chosen is called the selection function. In truncation 
selection, each parent is chosen from the fittest K members of the population (known as the elite 
set). In tournament selection, each parent is the fittest out of K randomly chosen members. In 
fitness proportionate selection, also called roulette wheel selection, each parent is chosen 
with probability proportional to its fitness relative to the others. We can also “kill off” the oldest 
members of the population, and then select parents based on their fitness; this is called regularized evolution [Rea+19].
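The three selection rules above can be sketched as follows, reusing the 8-queens population from Figure 6.13 and its non-attacking-pairs fitness. The function names and the tournament and elite sizes are hypothetical; `roulette_probs` reproduces the mating probabilities quoted in that figure's caption.

```python
import random
from itertools import combinations

def num_nonattacking(state):
    """8-queens fitness: number of non-attacking pairs (state[c] = row of queen in column c)."""
    n = len(state)
    attacking = sum(
        1
        for i, j in combinations(range(n), 2)
        if state[i] == state[j] or abs(state[i] - state[j]) == j - i
    )
    return n * (n - 1) // 2 - attacking

def roulette_probs(fitness):
    """Fitness proportionate (roulette wheel) selection probabilities."""
    total = sum(fitness)
    return [f / total for f in fitness]

def roulette_select(population, fitness, rng):
    return rng.choices(population, weights=fitness, k=1)[0]

def tournament_select(population, fitness, k, rng):
    """Return the fittest of k members chosen uniformly at random (without replacement)."""
    contestants = rng.sample(range(len(population)), k)
    return population[max(contestants, key=lambda i: fitness[i])]

def truncation_select(population, fitness, k, rng):
    """Return a uniformly chosen member of the elite set (the k fittest)."""
    elite = sorted(range(len(population)), key=lambda i: fitness[i], reverse=True)[:k]
    return population[rng.choice(elite)]

population = ["24748552", "32752411", "24415124", "32543213"]
fitness = [num_nonattacking([int(ch) for ch in s]) for s in population]
probs = roulette_probs(fitness)
rng = random.Random(0)
```

A tournament over the whole population always returns the fittest member, while smaller tournaments (or larger elite sets) weaken the selection pressure.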

In addition to the selection rule for parents, we need to specify the recombination and mutation 
rules. There are many possible choices for these heuristics. We briefly mention a few of them below. 


• In a genetic algorithm (GA) [Gol89; Hol92], we use mutation and a particular recombination method based on crossover. To implement crossover, we assume each individual is represented as a vector of integers or binary numbers, by analogy to chromosomes. We pick a split point along the chromosome for each of the two chosen parents, and then swap the strings, as illustrated in Figure 6.13.

• In genetic programming [Koz92], we use a tree-structured representation of individuals, instead of a bit string. This representation ensures that all crossovers result in valid children, as illustrated in Figure 6.15. Genetic programming can be useful for finding good programs as well as other structured objects, such as neural networks. In evolutionary programming, the structure of the tree is fixed and only the numerical parameters are evolved.

• In surrogate assisted EA, a surrogate function $\hat{f}(s)$ is used instead of the true objective function $f(s)$ in order to speed up the evaluation of members of the population (see [Jin11] for a survey). This is similar to the use of response surface models in Bayesian optimization (Section 6.8), except it does not deal with the explore-exploit tradeoff.

• In a memetic algorithm [MC03], we combine mutation and recombination with standard local search.

Figure 6.14: The 8-queens states corresponding to the first two parents in Figure 6.13(c) and their first child in Figure 6.13(d). We see that the encoding 32752411 means that the first queen is in row 3 (counting from the bottom left), the second queen is in row 2, etc. The shaded columns are lost in the crossover, but the unshaded columns are kept. From Figure 4.7 of [RN10]. Used with kind permission of Peter Norvig.

Figure 6.15: Illustration of crossover operator in a genetic program. (a-b) The two parents, representing $\sin(x) + (x + y)^2$ and $\sin(x) + \sqrt{x^2 + y}$. The red circles denote the two crossover points. (c-d) The two children, representing $\sin(x) + (x^2)^2$ and $\sin(x) + \sqrt{x + y + y}$. Adapted from Figure 9.2 of [Mit97].
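One-point crossover and pointwise mutation on the string encoding of Figure 6.14 can be sketched as below. With the split after the third digit, the parents 32752411 and 24748552 produce the children 32748552 and 24752411, which appear to match the first pair of offspring in Figure 6.13(d). The per-gene mutation rate and the helper names are arbitrary illustrative choices.

```python
import random

def crossover(parent_a, parent_b, split):
    """One-point crossover: the children swap everything after position `split`."""
    return parent_a[:split] + parent_b[split:], parent_b[:split] + parent_a[split:]

def mutate(child, rng, rate=1.0 / 8, alphabet="12345678"):
    """Pointwise mutation: each gene is resampled from `alphabet` with probability `rate`."""
    return "".join(rng.choice(alphabet) if rng.random() < rate else gene for gene in child)

rng = random.Random(0)
child_a, child_b = crossover("32752411", "24748552", split=3)
mutant = mutate(child_a, rng)
```

In a full GA the split point would itself be drawn at random, e.g. `rng.randrange(1, len(parent_a))`.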


Evolutionary algorithms have been applied to a large number of applications, including training neural networks (this combination is known as neuroevolution [Sta+19]). An efficient JAX-based library for (neuro)evolution can be found at https://github.com/google/evojax.


6.9.4 Estimation of distribution (EDA) algorithms 


EA methods maintain a population of good candidate solutions, which can be thought of as an 
implicit (nonparametric) density model over states with high fitness. [BC95] proposed to “remove 
the genetics from GAs”, by explicitly learning a probabilistic model over the configuration space that 
puts its mass on high scoring solutions. That is, the population becomes the set of parameters of a generative model, $\theta$.

One way to learn such a model is as follows. We start by creating a sample of $K' > K$ candidate solutions from the current model, $S_t = \{x_k \sim p(x|\theta_t)\}$. We then rank the samples using the fitness function, and then pick the most promising subset $S_t^*$ of size $K$ using a selection operator (this is known as truncation selection). Finally, we fit a new probabilistic model $p(x|\theta_{t+1})$ to $S_t^*$ using maximum likelihood estimation. This is called the estimation of distribution or EDA algorithm (see e.g., [LL02; PSCP06; Hau+11; PHL12; Hu+12; San17; Bal17]).

Note that EDA is equivalent to minimizing the cross-entropy between the empirical distribution defined by $S_t^*$ and the model distribution $p(x|\theta_{t+1})$. Thus EDA is related to the cross entropy method, as described in Section 6.9.5, although CEM usually assumes the special case where $p(x|\theta) = \mathcal{N}(x|\mu, \Sigma)$. EDA is also closely related to the EM algorithm, as discussed in [Bro+20a].

As a simple example, suppose the configuration space is bit strings of length $D$, and the fitness function is $f(x) = \sum_{d=1}^D x_d$, where $x_d \in \{0, 1\}$ (this is called the one-max function in the EA literature). A simple probabilistic model for this is a fully factored model of the form $p(x|\theta) = \prod_{d=1}^D \text{Ber}(x_d|\theta_d)$. Using this model inside of EDA results in a method called univariate marginal distribution algorithm or UMDA.

We can estimate the parameters of the Bernoulli model by setting $\theta_d$ to the fraction of samples in $S_t^*$ that have bit $d$ turned on. Alternatively, we can incrementally adjust the parameters. The population-based incremental learning (PBIL) algorithm [BC95] applies this idea to the factored Bernoulli model, resulting in the following update:

$$\theta_{d,t+1} = (1 - \eta_t)\,\theta_{d,t} + \eta_t\,\bar{\theta}_{d,t} \tag{6.235}$$

where $\bar{\theta}_{d,t} = \frac{1}{K} \sum_{k=1}^{K} \mathbb{I}(x_{k,d} = 1)$ is the MLE estimated from the $K = |S_t^*|$ samples generated in the current iteration, and $\eta_t$ is a learning rate.
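The UMDA/PBIL loop for the one-max problem can be sketched in a few lines of Python. The population size $K'$, elite size $K$, constant learning rate $\eta$ and iteration count are arbitrary illustrative choices, and `pbil_onemax` is a hypothetical name; setting `lr=1.0` replaces the parameters by the elite MLE at every step, which is the UMDA update.

```python
import random

def pbil_onemax(num_bits, pop_size, num_elite, lr, num_iters, rng):
    """PBIL with a fully factored Bernoulli model on the one-max fitness f(x) = sum_d x_d."""
    theta = [0.5] * num_bits
    for _ in range(num_iters):
        # Sample K' candidate bit strings from p(x | theta_t).
        samples = [[int(rng.random() < p) for p in theta] for _ in range(pop_size)]
        # Truncation selection: keep the K fittest strings.
        elite = sorted(samples, key=sum, reverse=True)[:num_elite]
        # Per-bit MLE on the elite set (theta_bar in Equation 6.235).
        theta_bar = [sum(x[d] for x in elite) / num_elite for d in range(num_bits)]
        # PBIL update: move theta part of the way towards the elite MLE.
        theta = [(1.0 - lr) * t + lr * tb for t, tb in zip(theta, theta_bar)]
    return theta

theta = pbil_onemax(num_bits=20, pop_size=50, num_elite=10, lr=0.1, num_iters=100,
                    rng=random.Random(0))
```

The learning rate smooths the noisy elite estimates across iterations, which makes it less likely that a bit gets stuck at the wrong value early on.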
Figure 6.16: Illustration of the BOA algorithm (EDA applied to a generative model structured as a Bayes net). Adapted from Figure 3 of [PHL12].


It is straightforward to use more expressive probability models that capture dependencies between the parameters (these are known as building blocks in the EA literature). For example, in the case of real-valued parameters, we can use a multivariate Gaussian, $p(x) = \mathcal{N}(x|\mu, \Sigma)$. The resulting method is called the estimation of multivariate normal algorithm or EMNA [LL02]. (See also Section 6.9.5.)


For discrete random variables, it is natural to use probabilistic graphical models (Chapter 4) to capture dependencies between the variables. [BD97] learn a tree-structured graphical model using the Chow-Liu algorithm (Supplementary Section 30.1.1); [BJV97] is a special case of this where the graph is a chain. We can also learn more general graphical model structures (see e.g., [LL02]). We typically use a Bayes net (Section 4.2), since we can use ancestral sampling (Section 4.2.5) to easily generate samples; the resulting method is therefore called the Bayesian Optimization Algorithm (BOA) [PGCP00].⁴ The hierarchical BOA (hBOA) algorithm [Pel05] extends this by using decision trees and decision graphs to represent the local CPTs in the Bayes net (as in [CHM97]), rather than using tables. In general, learning the structure of the probability model for use in EDA is called linkage learning, by analogy to how genes can be linked together if they can be co-inherited as a building block.


We can also use deep generative models to represent the distribution over good candidates. For example, [CSF16] use denoising autoencoders and NADE models (Section 22.2), [Bal17] uses a DNN regressor which is then inverted using gradient descent on the inputs, [PRG17] uses RBMs (Section 4.3.3.2), [GSM18] uses VAEs (Section 21.2), etc. Such models might take more data to fit (and therefore more function calls), but can potentially model the probability landscape more faithfully. (Whether that translates to better optimization performance is not clear, however.)

4. This should not be confused with the Bayesian optimization methods we discuss in Section 6.8, which use response surface modeling to model $p(f(x))$ rather than $p(x^*)$.


6.9.5 Cross-entropy method 


The cross-entropy method [Rub97; RK04; Boe+05] is a special case of EDA (Section 6.9.4) in which the population is represented by a multivariate Gaussian. In particular, we set $\mu_{t+1}$ and $\Sigma_{t+1}$ to the empirical mean and covariance of $S_t^*$, which are the top $K$ samples. This is closely related to the SMC algorithm for sampling rare events discussed in Section 13.6.4.
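A minimal CEM sketch in Python is shown below. To stay dependency-free it uses a diagonal Gaussian $\mathcal{N}(x|\mu, \text{diag}(\sigma^2))$, a simplification of the full covariance $\Sigma$ described above; the toy objective, population sizes and the name `cem_maximize` are illustrative assumptions.

```python
import random
import statistics

def cem_maximize(f, mu, sigma, num_samples, num_elite, num_iters, rng):
    """Cross-entropy method with a diagonal Gaussian p(x | theta) = N(x | mu, diag(sigma^2))."""
    dim = len(mu)
    for _ in range(num_iters):
        # Sample a population of K' candidates from the current Gaussian.
        samples = [[rng.gauss(mu[d], sigma[d]) for d in range(dim)] for _ in range(num_samples)]
        # Keep the K samples with the highest objective value (the elite set).
        elite = sorted(samples, key=f, reverse=True)[:num_elite]
        # Refit by maximum likelihood: empirical mean and per-dimension std of the elite set.
        mu = [statistics.mean([x[d] for x in elite]) for d in range(dim)]
        sigma = [statistics.pstdev([x[d] for x in elite]) for d in range(dim)]
    return mu, sigma

def objective(x):
    # Toy objective with its maximum at (1, -2).
    return -((x[0] - 1.0) ** 2 + (x[1] + 2.0) ** 2)

mu, sigma = cem_maximize(objective, mu=[0.0, 0.0], sigma=[3.0, 3.0],
                         num_samples=100, num_elite=20, num_iters=40, rng=random.Random(0))
```

Because the spread is refit only from the elite set, it shrinks quickly; in harder, multimodal problems this can collapse the search prematurely, which is why practical variants often add noise or a floor to $\sigma$.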

The CEM is sometimes used for model-based RL (Section 35.4), since it is simple and can find 
reasonably good optima of multi-modal objectives. It is also sometimes used inside of Bayesian 
optimization (Section 6.8), to optimize the multi-modal acquisition function (see [BK10]). 


6.9.5.1 Differentiable CEM 


The differentiable CEM method of [AY19] replaces the top K operator with a soft, differentiable 
approximation, which allows the optimizer to be used as part of an end-to-end differentiable pipeline. 
For example, we can use this to create a differentiable model predictive control (MPC) algorithm 
(Section 35.4.1), as described in Section 35.4.5.2. 

The basic idea is as follows. Let $S_t = \{x_{t,i} \sim p(x|\theta_t) : i = 1:K'\}$ represent the current population, with fitness values $v_{t,i} = f(x_{t,i})$. Let $v_t^K$ be the $K$'th largest value. In CEM, we compute the set of top $K$ samples, $S_t^* = \{i : v_{t,i} \geq v_t^K\}$, and then update the model based on these: $\theta_{t+1} = \arg\max_\theta \sum_{i=1}^{K'} p_t(i) \log p(x_{t,i}|\theta)$, where $p_t(i) = \mathbb{I}(i \in S_t^*)/|S_t^*|$. In the differentiable version, we replace the sparse distribution $p_t$ with the “soft” dense distribution $q_t = \Pi(p_t; \tau, K)$, where

$$\Pi(p; \tau, K) = \arg\min_{0 \leq q \leq 1} -p^\top q - \tau \mathbb{H}(q) \quad \text{s.t.} \quad \mathbf{1}^\top q = K \tag{6.236}$$

projects $p$ onto the polytope of vectors in $[0, 1]^{K'}$ which sum to $K$. (Here $\mathbb{H}(q) = -\sum_i \left[q_i \log(q_i) + (1 - q_i) \log(1 - q_i)\right]$ is the (binary) entropy, and $\tau > 0$ is a temperature parameter.) This projection operator (and hence the whole DCEM algorithm) can be backpropagated through using implicit differentiation [AKZK19].
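The projection in Equation (6.236) has a convenient structure: setting the gradient of its Lagrangian to zero gives $q_i = \sigma((p_i - \nu)/\tau)$, where $\sigma$ is the logistic sigmoid and the scalar $\nu$ is chosen so that the $q_i$ sum to $K$. The sketch below finds $\nu$ by bisection; it only computes the forward projection (no implicit-differentiation backward pass), takes a generic score vector as its first argument, and the function names are hypothetical.

```python
import math

def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def soft_top_k(scores, tau, k, num_bisect=200):
    """Solve argmin_{0<=q<=1} -scores^T q - tau * H(q) s.t. sum(q) = k (for 0 < k < len(scores)).

    The stationarity condition gives q_i = sigmoid((scores_i - nu) / tau); the scalar nu
    is found by bisection, since sum(q) decreases monotonically as nu increases.
    """
    lo = min(scores) - 50.0 * tau   # here sum(q) is close to len(scores) > k
    hi = max(scores) + 50.0 * tau   # here sum(q) is close to 0 < k
    for _ in range(num_bisect):
        nu = 0.5 * (lo + hi)
        total = sum(sigmoid((s - nu) / tau) for s in scores)
        if total > k:
            lo = nu   # too much mass: raise the threshold
        else:
            hi = nu
    nu = 0.5 * (lo + hi)
    return [sigmoid((s - nu) / tau) for s in scores]

scores = [3.0, 1.0, 2.0, 0.5]
q_sharp = soft_top_k(scores, tau=0.01, k=2)    # close to the hard top-2 indicator [1, 0, 1, 0]
q_smooth = soft_top_k(scores, tau=100.0, k=2)  # close to uniform [0.5, 0.5, 0.5, 0.5]
```

The temperature interpolates between the hard top-$K$ selection used by CEM (small $\tau$) and a uniform weighting (large $\tau$).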


6.9.6 Evolutionary strategies 


Evolution strategies [Wie+14] are a form of distribution-based optimization in which the distribution over the population is represented by a Gaussian, $p(x|\theta_t)$ (see e.g., [Sal+17b]). Unlike CEM, the parameters are updated using gradient ascent applied to the expected value of the objective, rather than using MLE on a set of elite samples. More precisely, consider the smoothed objective $\mathcal{L}(\theta) = \mathbb{E}_{p(x|\theta)}[f(x)]$. We can use the REINFORCE estimator (Section 6.5.3) to compute the gradient of this objective as follows:

$$\nabla_\theta \mathcal{L}(\theta) = \mathbb{E}_{p(x|\theta)}\left[f(x) \nabla_\theta \log p(x|\theta)\right] \tag{6.237}$$

This can be approximated by drawing Monte Carlo samples. We discuss how to compute this gradient below.
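As a minimal illustration of Equation (6.237), the sketch below uses a one-dimensional Gaussian $p(x|\theta) = \mathcal{N}(x|\mu, \sigma^2)$ with $\theta = \mu$ and $\sigma$ held fixed, for which $\nabla_\mu \log p(x|\theta) = (x - \mu)/\sigma^2$. The toy objective, step size and sample sizes are illustrative assumptions, and common variance-reduction tricks (such as baselines) are deliberately omitted.

```python
import random

def es_gradient(f, mu, sigma, num_samples, rng):
    """REINFORCE estimate of d/dmu E_{x ~ N(mu, sigma^2)}[f(x)] for scalar x."""
    total = 0.0
    for _ in range(num_samples):
        x = rng.gauss(mu, sigma)
        # Score function of the Gaussian mean: d/dmu log N(x | mu, sigma^2) = (x - mu) / sigma^2.
        total += f(x) * (x - mu) / sigma ** 2
    return total / num_samples

def f(x):
    return -(x - 2.0) ** 2  # maximized at x = 2

rng = random.Random(0)
# Smoothed objective: L(mu) = -(mu - 2)^2 - sigma^2, so the exact gradient at mu = 0 is 4.
g = es_gradient(f, mu=0.0, sigma=1.0, num_samples=50_000, rng=rng)

# Plain stochastic gradient ascent on mu, with sigma held fixed.
mu = 0.0
for _ in range(200):
    mu += 0.05 * es_gradient(f, mu, 1.0, 100, rng)
```

Only function values are used, never $\nabla_x f$, which is what makes the method derivative free.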
Figure 6.17: Illustration of the CMA-ES method applied to a simple 2d function. The dots represent members of the population, and the dashed orange ellipse represents the multivariate Gaussian. From https://en.wikipedia.org/wiki/CMA-ES. Used with kind permission of Wikipedia author Sentewolf.


6.9.6.1 Natural evolutionary strategies 


If the probability model is in the exponential family, we can compute the natural gradient (Section 6.4), 
rather than the “vanilla” gradient, which can result in faster convergence. Such methods are called 
natural evolution strategies [Wie+14]. 


6.9.6.2 CMA-ES 


The CMA-ES method of [Han16], which stands for “covariance matrix adaptation evolution strategy”, is a kind of NES. It is very similar to CEM except it updates the parameters in a special way. In particular, instead of computing the new mean and covariance using unweighted MLE on the elite set, we attach weights to the elite samples based on their rank. We then set the new mean to the weighted MLE of the elite set.

The update equations for the covariance are more complex. In particular, “evolutionary paths” are also used to accumulate the search directions across successive generations, and these are used to update the covariance. It can be shown that the resulting updates approximate the natural gradient of $\mathcal{L}(\theta)$ without explicitly modeling the Fisher information matrix [Oll+17].

Figure 6.17 illustrates the method in action.
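The rank-based weighting of the mean update can be sketched as follows. The particular weights, $w_i \propto \ln(K + 1/2) - \ln i$ for the $K$ best-ranked samples, are one common default from the CMA-ES literature and are used here as an assumption; the covariance and evolution-path updates are not shown, and the function names are hypothetical.

```python
import math

def rank_weights(num_elite):
    """Positive, decreasing weights for elite samples ranked 1 (best) to num_elite."""
    raw = [math.log(num_elite + 0.5) - math.log(i) for i in range(1, num_elite + 1)]
    total = sum(raw)
    return [w / total for w in raw]

def weighted_mean(elite_sorted, weights):
    """Weighted MLE of a Gaussian mean; elite_sorted[0] is the best-ranked sample."""
    dim = len(elite_sorted[0])
    return [sum(w * x[d] for w, x in zip(weights, elite_sorted)) for d in range(dim)]

weights = rank_weights(4)
new_mean = weighted_mean([[1.0, 0.0], [0.0, 0.0]], [0.75, 0.25])
```

In use, the elite set is sorted best-first and the new mean is `weighted_mean(elite, rank_weights(len(elite)))`; since the weights depend only on ranks, the update is invariant to monotone transformations of the objective.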


6.10 Optimal Transport 


This section is written by Marco Cuturi. 


In this section, we focus on optimal transport theory, a set of tools that have been proposed, starting with work by [Mon81], to compare two probability distributions. We start from a simple example involving only matchings, and work from there towards various extensions.
6.10.1 Warm-up: Matching optimally two families of points


Consider two families $(x_1, \ldots, x_n)$ and $(y_1, \ldots, y_n)$, each consisting of $n > 1$ distinct points taken from a set $\mathcal{X}$. A matching between these two families is a bijective mapping that assigns to each point $x_i$ another point $y_j$. Such an assignment can be encoded by pairing indices $(i, j) \in \{1, \ldots, n\}^2$ such that they define a permutation $\sigma$ in the symmetric group $S_n$. With that convention and given a permutation $\sigma$, $x_i$ would be assigned to $y_{\sigma_i}$, the $\sigma_i$-th element in the second family.

Matching costs. When matching a family with another, it is natural to consider the cost incurred when pairing any point $x_i$ with another point $y_j$, for all possible pairs $(i, j) \in \{1, \ldots, n\}^2$. For instance, $x_i$ might contain information on the current location of a taxi driver $i$, and $y_j$ that of a user $j$ who has just requested a taxi; in that case, $C_{ij} \in \mathbb{R}$ may quantify the cost (in terms of time, fuel or distance) required for taxi driver $i$ to reach user $j$. Alternatively, $x_i$ could represent a vector of skills held by a job seeker $i$ and $y_j$ a vector quantifying desirable skills associated with a job posting $j$; in that case $C_{ij}$ could quantify the number of hours required for worker $i$ to carry out job $j$. We will assume without loss of generality that the values $C_{ij}$ are obtained by evaluating a cost function $c : \mathcal{X} \times \mathcal{X} \to \mathbb{R}$ on the pair $(x_i, y_j)$, namely $C_{ij} = c(x_i, y_j)$. In many applications of optimal transport, such cost functions have a geometric interpretation and are typically distance functions on $\mathcal{X}$, as in Fig. 6.18, in which $\mathcal{X} = \mathbb{R}^2$, or as will be discussed later in Section 6.10.2.4.

Least-cost Matchings. Equipped with a cost function $c$, the optimal matching (or assignment) problem is that of finding a permutation that reaches the smallest total cost, as defined by the function

$$\min_{\sigma \in S_n} E(\sigma) = \sum_{i=1}^n c(x_i, y_{\sigma_i}) \tag{6.238}$$

The optimal matching problem is arguably one of the simplest combinatorial optimization problems, tackled as early as the 19th century [JB65]. Although a naive enumeration of all permutations would require evaluating the objective $E$ a total of $n!$ times, the Hungarian algorithm [Kuh55] was shown to provide the optimal solution in polynomial time [Mun57], and later refined to require $O(n^3)$ operations in the worst case.
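The naive enumeration mentioned above can be written directly from Equation (6.238); it is only practical for very small $n$, but it makes the objective concrete. The example points and the name `optimal_matching_bruteforce` are illustrative; with the cost $|x - y|$ on the real line, the optimal matching pairs the points in sorted order.

```python
from itertools import permutations

def optimal_matching_bruteforce(xs, ys, cost):
    """Solve min_sigma sum_i c(x_i, y_{sigma_i}) by enumerating all n! permutations."""
    n = len(xs)
    best_perm, best_cost = None, float("inf")
    for perm in permutations(range(n)):
        total = sum(cost(xs[i], ys[perm[i]]) for i in range(n))
        if total < best_cost:
            best_perm, best_cost = perm, total
    return best_perm, best_cost

xs = [0.0, 1.0, 2.0]
ys = [2.1, 0.1, 1.1]
perm, total = optimal_matching_bruteforce(xs, ys, lambda x, y: abs(x - y))
```

Here `perm[i]` is $\sigma_i$, so $x_1 \mapsto y_2$, $x_2 \mapsto y_3$, $x_3 \mapsto y_1$ (0-based: `(1, 2, 0)`). For anything beyond toy sizes, a polynomial-time assignment solver such as SciPy's `scipy.optimize.linear_sum_assignment` is the practical choice.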


6.10.2 From Optimal Matchings to Kantorovich and Monge formulations 


The optimal matching problem is relevant to many applications, but it suffers from a few limitations. 
One could argue that most of the optimal transport literature arises from the necessity to overcome 
these limitations and extend (6.238) to more general settings. An obvious issue arises when the 
number of points available in both families is not the same. The second limitation arises when
considering a continuous setting, namely when trying to match (or morph) two probability densities, 
rather than families of atoms (discrete measures). 


6.10.2.1 Mass splitting 


Figure 6.18: (left) Matching a family of 5 points to another is equivalent to considering a permutation of {1,...,n}. When each pair of points $\mathbf{x}_i, \mathbf{y}_j \in \mathbb{R}^2$ is associated with a cost equal to the distance $\|\mathbf{x}_i - \mathbf{y}_j\|$, the optimal matching problem involves finding a permutation $\sigma$ that minimizes the sum of the distances $\|\mathbf{x}_i - \mathbf{y}_{\sigma_i}\|$ over i in {1,2,3,4,5}. (middle) The Kantorovich formulation of optimal transport generalizes optimal matchings, and arises when comparing discrete measures, that is, families of weighted points that do not necessarily share the same size but do share the same total mass. The relevant variable is a matrix P of size n × m, which must satisfy row-sum and column-sum constraints, and which minimizes its dot product with the cost matrix C. (right) Another direct extension of the matching problem arises when, intuitively, the number n of points grows so large that the considered measures become continuous densities. In that setting, and unlike the Kantorovich setting, the goal is to seek a map $T: \mathcal{X} \to \mathcal{X}$ which associates to any point $\mathbf{x}$ in the support of the input measure $\mu$ a point $\mathbf{y} = T(\mathbf{x})$ in the support of $\nu$. The push-forward constraint $T_\sharp\mu = \nu$ ensures that $\nu$ is recovered by applying the map T to all points in the support of $\mu$; the optimal map $T^\star$ is the one that minimizes the distance between $\mathbf{x}$ and $T(\mathbf{x})$, averaged over $\mu$.

Suppose again that all points $\mathbf{x}_i$ and $\mathbf{y}_j$ describe skills, respectively held by a worker i and needed for a task j to be fulfilled in a factory. Since finding a matching is equivalent to finding a permutation of {1,...,n}, problem (6.238) cannot handle cases in which the number of workers is larger (or smaller) than the number of tasks. More problematically, the assumption that every single task is indivisible, or that workers are only able to dedicate themselves to a single task, is hardly realistic. Indeed, certain tasks may require more (or less) dedication than that provided by a single worker, whereas some workers may only be able to work part-time, or, on the contrary, be willing to put in extra hours. The rigid machinery of permutations falls short of handling such cases, since permutations are by definition one-to-one associations. The Kantorovich formulation allows for mass splitting, the idea that the effort provided by a worker or needed to complete a given task can be split. In practice, each of the n workers is associated, in addition to $\mathbf{x}_i$, with a positive number $a_i > 0$. That number represents the amount of time worker i is able to provide. Similarly, we introduce numbers $b_j > 0$ describing the amount of time needed to carry out each of the m tasks (n and m do not necessarily coincide). Worker i is therefore described as a pair $(a_i, \mathbf{x}_i)$, mathematically equivalent to a weighted Dirac measure $a_i\delta_{\mathbf{x}_i}$. The overall workforce available to the factory is described as a discrete measure $\sum_i a_i\delta_{\mathbf{x}_i}$, whereas its tasks are described by $\sum_j b_j\delta_{\mathbf{y}_j}$. If one assumes further that the factory has a balanced workload, namely that $\sum_i a_i = \sum_j b_j$, then the Kantorovich [Kan42] formulation of optimal transport is:

$$\mathrm{OT}_C(\mathbf{a},\mathbf{b}) \triangleq \min_{P \in \mathbb{R}_+^{n\times m},\; P\mathbf{1}_m = \mathbf{a},\; P^\top\mathbf{1}_n = \mathbf{b}} \langle P, C\rangle \triangleq \sum_{ij} P_{ij}C_{ij}. \quad (6.239)$$

The interpretation behind such matrices is simple: each coefficient $P_{ij}$ describes an allocation of time for worker i to spend on task j. The i-th row-sum must be equal to the total $a_i$ for the time constraint of worker i to be satisfied, whereas the j-th column-sum must be equal to $b_j$, reflecting that the time needed to complete task j has been budgeted.
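The constraints of (6.239) are easy to check mechanically. The sketch below, with made-up workloads `a`, `b` and costs `C` for illustration only, verifies that a matrix is a valid coupling and evaluates $\langle P, C\rangle$. It also shows two always-available feasible points: the rescaled product coupling $\mathbf{a}\mathbf{b}^\top / \sum_i a_i$, and, when $\mathbf{a} = \mathbf{b} = \mathbf{1}_n/n$, a permutation matrix divided by n, which is how (6.239) contains the matching problem (6.238).

```python
import numpy as np


def is_coupling(P, a, b, tol=1e-9):
    """Check P >= 0, P 1_m = a (row sums) and P^T 1_n = b (column sums)."""
    return (np.all(P >= -tol)
            and np.allclose(P.sum(axis=1), a, atol=tol)
            and np.allclose(P.sum(axis=0), b, atol=tol))


def transport_cost(P, C):
    """<P, C> = sum_ij P_ij C_ij, the objective of (6.239)."""
    return float(np.sum(P * C))


# Hypothetical factory: 3 workers (hours a_i), 2 tasks (hours b_j), balanced totals.
a = np.array([2.0, 1.0, 1.0])
b = np.array([3.0, 1.0])
C = np.array([[1.0, 4.0], [2.0, 1.0], [3.0, 2.0]])

# The rescaled product coupling a b^T / sum(a) is always feasible.
P_prod = np.outer(a, b) / a.sum()
# A hand-made allocation: worker 0 and worker 1 on task 0, worker 2 on task 1.
P_hand = np.array([[2.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
print(is_coupling(P_prod, a, b), transport_cost(P_prod, C))
print(is_coupling(P_hand, a, b), transport_cost(P_hand, C))

# With uniform weights 1/n, a permutation matrix scaled by 1/n is a coupling.
n = 3
P_perm = np.eye(n)[[1, 0, 2]] / n
print(is_coupling(P_perm, np.full(n, 1 / n), np.full(n, 1 / n)))
```

Both feasible couplings respect the time budgets, but they differ in cost; solving (6.239) means searching over all such matrices for the cheapest one.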
6.10.2.2 Monge formulation and optimal push-forward maps


By introducing mass splitting, the Kantorovich formulation of optimal transport allows for a far more general comparison between discrete measures of different sizes and weights (middle plot of Fig. 6.18). Naturally, this flexibility comes with a downside: one can no longer associate to each point $\mathbf{x}_i$ a single point $\mathbf{y}_j$ to which it is uniquely assigned, as was the case with the classical matching problem. Interestingly, this property can be recovered in the limit where the measures become densities. Indeed, the Monge [Mon81] formulation of optimal transport recovers precisely that property, on the condition (loosely speaking) that the measure $\mu$ admits a density. In that setting, the analogous mathematical object guaranteeing that $\mu$ is mapped onto $\nu$ is that of push-forward maps morphing $\mu$ to $\nu$, namely maps T such that for any measurable set $A \subset \mathcal{X}$, $\mu(T^{-1}(A)) = \nu(A)$. When T is differentiable, and $\mu, \nu$ have densities p and q w.r.t. the Lebesgue measure in $\mathbb{R}^d$, this statement is equivalent, thanks to the change of variables formula, to ensuring almost everywhere that:

$$q(T(\mathbf{x}))\,|J_T(\mathbf{x})| = p(\mathbf{x}), \quad (6.240)$$

where $|J_T(\mathbf{x})|$ stands for the absolute value of the determinant of the Jacobian matrix of T evaluated at $\mathbf{x}$.
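As a quick numerical sanity check of (6.240), consider an assumed one-dimensional example (not from the text): $\mu = \mathcal{N}(0, 1)$ and the affine map $T(x) = 2x + 1$, which pushes $\mu$ onto $\nu = \mathcal{N}(1, 2^2)$. In 1D the Jacobian determinant is simply $|T'(x)| = 2$, so $q(T(x)) \cdot 2 - p(x)$ should vanish everywhere.

```python
import math


def normal_pdf(x, mean, std):
    """Density of N(mean, std^2) at x."""
    return math.exp(-0.5 * ((x - mean) / std) ** 2) / (std * math.sqrt(2 * math.pi))


def T(x):
    """Affine map pushing N(0, 1) onto N(1, 4)."""
    return 2.0 * x + 1.0


def residual(x):
    """q(T(x)) |J_T(x)| - p(x) for p = N(0, 1), q = N(1, 4) and |J_T| = 2."""
    return normal_pdf(T(x), 1.0, 2.0) * 2.0 - normal_pdf(x, 0.0, 1.0)


for x in [-2.0, -0.5, 0.0, 1.3]:
    print(x, residual(x))  # zero up to floating-point rounding
```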
Writing $T_\sharp\mu = \nu$ when T does satisfy these conditions, the Monge [Mon81] problem consists in finding the best map T that minimizes the average cost between $\mathbf{x}$ and its displacement $T(\mathbf{x})$,

$$\inf_{T:\, T_\sharp\mu = \nu} \int_{\mathcal{X}} c(\mathbf{x}, T(\mathbf{x}))\, d\mu(\mathbf{x}). \quad (6.241)$$

T is therefore a map that pushes $\mu$ forward to $\nu$ globally, while incurring the smallest average cost. While very intuitive, the Monge problem turns out to be extremely difficult to solve in practice, since it is non-convex. Indeed, one can easily check that the constraint $\{T_\sharp\mu = \nu\}$ is not convex, since one can easily find counter-examples for which $T_\sharp\mu = \nu$ and $T'_\sharp\mu = \nu$ yet $(\tfrac{1}{2}T + \tfrac{1}{2}T')_\sharp\mu \neq \nu$. Luckily, Kantorovich's approach also works for continuous measures, and yields a comparatively much simpler linear program.
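One such counter-example fits in a few lines. Take (as an assumed toy case) $\mu = \nu$ = the uniform measure on the two points {0, 1}. The identity and the swap $x \mapsto 1 - x$ both push $\mu$ onto $\nu$, yet their average sends all the mass to 0.5. The helper below represents a uniform discrete measure by a `Counter` of point counts (each point carrying mass 1/len(points)).

```python
from collections import Counter

support = [0.0, 1.0]  # mu = nu = uniform measure on {0, 1}


def push_forward(T, points):
    """Push-forward of the uniform measure on `points` by T, as counts per image point."""
    return Counter(T(x) for x in points)


def T1(x):
    return x  # identity


def T2(x):
    return 1.0 - x  # swaps 0 and 1


def T_avg(x):
    return 0.5 * T1(x) + 0.5 * T2(x)  # convex combination of two feasible maps


target = Counter(support)
print(push_forward(T1, support) == target)  # True
print(push_forward(T2, support) == target)  # True
print(push_forward(T_avg, support))         # Counter({0.5: 2}): not nu
```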


6.10.2.3 Kantorovich formulation 


The Kantorovich problem (6.239) can also be extended to a continuous setting: instead of optimizing over a subset of matrices in $\mathbb{R}^{n\times m}$, consider $\Pi(\mu,\nu)$, the subset of joint probability distributions $\mathcal{P}(\mathcal{X}\times\mathcal{X})$ with marginals $\mu$ and $\nu$, namely

$$\Pi(\mu,\nu) = \{\pi \in \mathcal{P}(\mathcal{X}^2) : \forall A \subset \mathcal{X},\ \pi(A\times\mathcal{X}) = \mu(A) \text{ and } \pi(\mathcal{X}\times A) = \nu(A)\}. \quad (6.242)$$

Note that $\Pi(\mu,\nu)$ is not empty since it always contains the product measure $\mu\otimes\nu$. With this definition, the continuous formulation of (6.239) can be obtained as

$$\mathrm{OT}_c(\mu,\nu) = \inf_{\pi\in\Pi(\mu,\nu)} \int_{\mathcal{X}^2} c\, d\pi. \quad (6.243)$$

Notice that (6.243) directly subsumes (6.239), since one can check that they coincide when $\mu, \nu$ are discrete measures with respective probability weights $\mathbf{a}, \mathbf{b}$ and locations $(\mathbf{x}_1, \ldots, \mathbf{x}_n)$ and $(\mathbf{y}_1, \ldots, \mathbf{y}_m)$.
6.10.2.4 Wasserstein distances


When c is equal to a metric d raised to a power $p \geq 1$, the optimal value of the Kantorovich problem, raised to the power 1/p, is called the p-Wasserstein distance between $\mu$ and $\nu$:

$$W_p(\mu,\nu) \triangleq \left(\inf_{\pi\in\Pi(\mu,\nu)} \int_{\mathcal{X}^2} d(\mathbf{x},\mathbf{y})^p\, d\pi(\mathbf{x},\mathbf{y})\right)^{1/p}. \quad (6.244)$$

While symmetry and the fact that $W_p(\mu,\nu) = 0 \Leftrightarrow \mu = \nu$ are relatively easy to prove provided d is a metric, proving the triangle inequality is slightly more challenging, and builds on a result known as the gluing lemma ([Vil08, p.23]). The p-th power of $W_p(\mu,\nu)$ is often abbreviated as $W_p^p(\mu,\nu)$.
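For tiny examples, (6.244) can be evaluated exactly and its metric properties checked numerically. The sketch below assumes two uniform empirical measures with the same number n of atoms; in that case an optimal coupling can be taken to be a permutation matrix divided by n (the vertices of the set of such couplings are permutation matrices, by the Birkhoff-von Neumann theorem), so a brute-force search over permutations suffices. This is an illustration, not a practical solver.

```python
import itertools

import numpy as np


def wasserstein_uniform(x, y, p=2):
    """W_p between (1/n) sum_i delta_{x_i} and (1/n) sum_j delta_{y_j}, same n.

    Brute force over the n! permutation couplings; only suitable for tiny n.
    """
    n = len(x)
    D = np.linalg.norm(x[:, None, :] - y[None, :, :], axis=-1) ** p
    best = min(sum(D[i, s[i]] for i in range(n)) for s in itertools.permutations(range(n)))
    return (best / n) ** (1.0 / p)


rng = np.random.default_rng(1)
x, y, z = (rng.normal(size=(5, 2)) for _ in range(3))
print(wasserstein_uniform(x, y), wasserstein_uniform(y, x))  # symmetry
print(wasserstein_uniform(x, x))                              # zero for identical measures
print(wasserstein_uniform(x, z)
      <= wasserstein_uniform(x, y) + wasserstein_uniform(y, z))  # triangle inequality
```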


6.10.3 Solving optimal transport 
6.10.3.1 Duality and cost concavity 


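Both (6.239) and (6.243) are linear programs: their constraints and objective functions only involve summations. In that sense they admit a dual formulation (here, again, (6.246) subsumes (6.245)):

$$\max_{\mathbf{f}\in\mathbb{R}^n,\ \mathbf{g}\in\mathbb{R}^m,\ \mathbf{f}\oplus\mathbf{g} \leq C} \mathbf{f}^\top\mathbf{a} + \mathbf{g}^\top\mathbf{b} \quad (6.245)$$

$$\sup_{f\oplus g \leq c} \int_{\mathcal{X}} f\, d\mu + \int_{\mathcal{X}} g\, d\nu \quad (6.246)$$

where the sign $\oplus$ denotes tensor addition for vectors, $\mathbf{f}\oplus\mathbf{g} = [f_i + g_j]_{ij}$, or for functions, $f\oplus g: \mathbf{x},\mathbf{y} \mapsto f(\mathbf{x}) + g(\mathbf{y})$. In other words, the dual problem looks for a pair of vectors (or functions) that attain the highest possible expectation when summed against $\mathbf{a}$ and $\mathbf{b}$ (or integrated against $\mu, \nu$), subject to the constraint that their sum $f(\mathbf{x}) + g(\mathbf{y})$ never exceeds the cost $c(\mathbf{x},\mathbf{y})$ at any pair of points $\mathbf{x}, \mathbf{y}$.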
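Why the dual lower-bounds the primal can be seen in one line: for any coupling P and any feasible pair, $\sum_{ij} P_{ij}(f_i + g_j) = \mathbf{f}^\top\mathbf{a} + \mathbf{g}^\top\mathbf{b} \leq \sum_{ij} P_{ij}C_{ij}$. The sketch below checks this weak-duality inequality on random data (all values are illustrative); it makes arbitrary potentials feasible by shifting $\mathbf{g}$ down by the largest constraint violation.

```python
import numpy as np

rng = np.random.default_rng(2)
n, m = 4, 3
C = rng.uniform(size=(n, m))
a = np.full(n, 1.0 / n)
b = np.full(m, 1.0 / m)

# Arbitrary potentials, then shift g so that f_i + g_j <= C_ij holds for all i, j.
f = rng.normal(size=n)
g = rng.normal(size=m)
g -= np.max(f[:, None] + g[None, :] - C)  # now max_ij (f_i + g_j - C_ij) = 0

dual_value = f @ a + g @ b
P = np.outer(a, b)                # the product coupling is primal-feasible
primal_value = np.sum(P * C)
print(dual_value, "<=", primal_value)
```

Strong duality states that the best such lower bound actually reaches the optimal primal value.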


The dual problems of (6.239) and (6.243) have two variables. Focusing on the continuous formulation, a closer inspection shows that it is possible, given a function f for the first measure, to compute the best possible candidate for the function g. That function g should be as large as possible, yet satisfy the constraint that $g(\mathbf{y}) \leq c(\mathbf{x},\mathbf{y}) - f(\mathbf{x})$ for all $\mathbf{x}, \mathbf{y}$, making

$$\forall \mathbf{y}\in\mathcal{X},\quad \bar{f}(\mathbf{y}) \triangleq \inf_{\mathbf{x}} c(\mathbf{x},\mathbf{y}) - f(\mathbf{x}) \quad (6.247)$$

the optimal choice. $\bar{f}$ is called the c-transform of f. Naturally, one may choose to start instead from g, to define an alternative c-transform:

$$\forall \mathbf{x}\in\mathcal{X},\quad \underline{g}(\mathbf{x}) \triangleq \inf_{\mathbf{y}} c(\mathbf{x},\mathbf{y}) - g(\mathbf{y}). \quad (6.248)$$

Since these transformations can only improve solutions, one may even think of applying them alternately to an arbitrary f, to define $\bar{f}$, $\underline{\bar{f}}$ and so on. One can show, however, that this has little interest, since

$$\bar{\underline{\bar{f}}} = \bar{f}. \quad (6.249)$$
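On discrete measures the infima in (6.247) and (6.248) become minima over the rows or columns of C, which makes these transforms easy to experiment with. The sketch below (function names are our own) checks three facts on random data: $(\mathbf{f}, \bar{\mathbf{f}})$ is dual-feasible, the second transform can only increase f, and the identity (6.249) holds.

```python
import numpy as np


def c_transform(f, C):
    """Discrete (6.247): bar f_j = min_i C_ij - f_i."""
    return np.min(C - f[:, None], axis=0)


def cbar_transform(g, C):
    """Discrete (6.248): underline g_i = min_j C_ij - g_j."""
    return np.min(C - g[None, :], axis=1)


rng = np.random.default_rng(3)
C = rng.uniform(size=(5, 4))
f = rng.normal(size=5)

f_c = c_transform(f, C)          # bar f
f_cc = cbar_transform(f_c, C)    # underline(bar f)
f_ccc = c_transform(f_cc, C)     # bar(underline(bar f))

print(np.all(f[:, None] + f_c[None, :] <= C + 1e-12))  # (f, bar f) is feasible
print(np.all(f_cc >= f - 1e-12))                        # f can only improve
print(np.allclose(f_ccc, f_c))                          # identity (6.249)
```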
This remark nonetheless allows us to narrow down the set of candidate functions to those that have already undergone such transformations. This reasoning yields the so-called set of c-concave functions, $\mathcal{F}_c = \{f \mid \exists g: \mathcal{X}\to\mathbb{R},\ f = \underline{g}\}$, which can be shown, equivalently, to be the set of functions f such that $f = \underline{\bar{f}}$. One can therefore focus on c-concave functions to solve (6.246) using a so-called semi-dual formulation,

$$\sup_{f\in\mathcal{F}_c} \int_{\mathcal{X}} f\, d\mu + \int_{\mathcal{X}} \bar{f}\, d\nu. \quad (6.250)$$

Going from (6.246) to (6.250), we have removed a dual variable g and narrowed down the feasible set to $\mathcal{F}_c$, at the cost of introducing the highly non-linear transform $\bar{f}$. This reformulation is, however, very useful, in the sense that it allows us to restrict our attention to c-concave functions, notably for two important classes of cost functions c: distances and squared-Euclidean norms.


6.10.3.2 Kantorovich-Rubinstein duality and Lipschitz potentials 


A striking result illustrating the interest of c-concavity is provided when c is a metric d, namely when p = 1 in (6.244). In that case, one can prove (exploiting notably the triangle inequality of d) that a d-concave function f is 1-Lipschitz (one has $|f(\mathbf{x}) - f(\mathbf{y})| \leq d(\mathbf{x},\mathbf{y})$ for any $\mathbf{x}, \mathbf{y}$) and such that $\bar{f} = -f$. This result therefore translates into the following identity:

$$W_1(\mu,\nu) = \sup_{f\ \text{1-Lipschitz}} \int_{\mathcal{X}} f\,(d\mu - d\nu). \quad (6.251)$$

This result has numerous practical applications. This supremum over 1-Lipschitz functions can be efficiently approximated using wavelet coefficients of densities in low dimensions [SJ08], or heuristically in more general cases by training neural networks parameterized to be 1-Lipschitz [ACB17] using ReLU activation functions, and bounds on the entries of the weight matrices.
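Identity (6.251) implies that any single 1-Lipschitz function yields a lower bound $\mathbb{E}_\mu f - \mathbb{E}_\nu f \leq W_1(\mu,\nu)$; the neural approaches above search for the best such f. The sketch below checks the bound on the real line, where (as a simplifying assumption) both measures are uniform empirical measures with the same number of samples, so that the exact $W_1$ is obtained by matching sorted samples (the univariate closed form of Section 6.10.3.4). The candidate functions are arbitrary illustrations.

```python
import numpy as np


def w1_1d_equal_size(x, y):
    """Exact W_1 between two equal-size uniform empirical measures on R (sorted matching)."""
    return float(np.mean(np.abs(np.sort(x) - np.sort(y))))


rng = np.random.default_rng(4)
x = rng.normal(loc=0.0, scale=1.0, size=1000)
y = rng.normal(loc=0.5, scale=2.0, size=1000)
w1 = w1_1d_equal_size(x, y)

# A few 1-Lipschitz test functions; each gives E_mu f - E_nu f <= W_1(mu, nu).
candidates = {
    "identity": lambda t: t,
    "negation": lambda t: -t,
    "abs": np.abs,
    "clipped": lambda t: np.clip(t, -1.0, 1.0),
}
for name, f in candidates.items():
    print(name, float(np.mean(f(x)) - np.mean(f(y))), "<=", w1)
```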


6.10.3.3 Monge maps as gradients of convex functions: the Brenier theorem 


Another application of c-concavity lies in the case $c(\mathbf{x},\mathbf{y}) = \tfrac{1}{2}\|\mathbf{x} - \mathbf{y}\|_2^2$, which corresponds, up to the factor $\tfrac{1}{2}$, to the squared $W_2$ distance used between densities in a Euclidean space. The remarkable result, shown first by [Bre91], is that the Monge map solving (6.241) between two measures for that cost (provided $\mu$ is regular enough, here assumed to have a density w.r.t. the Lebesgue measure) exists and is necessarily the gradient of a convex function. In loose terms, one can show that

$$T^\star = \arg\min_{T:\, T_\sharp\mu = \nu} \int_{\mathcal{X}} \tfrac{1}{2}\|\mathbf{x} - T(\mathbf{x})\|_2^2\, d\mu(\mathbf{x}) \quad (6.252)$$

exists, and is the gradient of a convex function $u: \mathbb{R}^d \to \mathbb{R}$, namely $T^\star = \nabla u$. Conversely, for any convex function u, the optimal transport map between $\mu$ and the displacement $(\nabla u)_\sharp\mu$ is necessarily equal to $\nabla u$.
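A discrete sanity check of the converse statement: replace $\mu$ (as an assumption for illustration) by a uniform empirical measure on n points $\mathbf{x}_i$, and set $\mathbf{y}_i = \nabla u(\mathbf{x}_i)$ for the convex quadratic $u(\mathbf{x}) = \tfrac{1}{2}\mathbf{x}^\top A\mathbf{x} + \mathbf{b}^\top\mathbf{x}$ with A positive definite. Because gradients of convex functions are cyclically monotone, the matching $\mathbf{x}_i \mapsto \nabla u(\mathbf{x}_i)$ should be optimal among all n! permutations for the quadratic cost.

```python
import itertools

import numpy as np

rng = np.random.default_rng(5)
d, n = 2, 6
B = rng.normal(size=(d, d))
A = B @ B.T + 0.1 * np.eye(d)   # symmetric positive definite
b = rng.normal(size=d)


def grad_u(x):
    """Gradient of u(x) = 0.5 x^T A x + b^T x, applied row-wise (A is symmetric)."""
    return x @ A + b


x = rng.normal(size=(n, d))
y = grad_u(x)                   # support of the push-forward of mu by grad u

C = 0.5 * np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
costs = {s: sum(C[i, s[i]] for i in range(n)) for s in itertools.permutations(range(n))}
best = min(costs.values())
identity_cost = costs[tuple(range(n))]
print(identity_cost, best)      # equal: x_i -> grad u(x_i) is an optimal matching
```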

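We provide a sketch of the proof: one can always exploit, for any reasonable cost function c (e.g., lower bounded and lower semi-continuous), primal-dual relationships. Consider an optimal coupling $P^\star$ for (6.243), as well as an optimal c-concave dual function $f^\star$ for (6.250). This implies in particular that $(f^\star, g^\star = \overline{f^\star})$ is optimal for (6.246). Complementary slackness conditions for this pair of linear programs imply that if $(\mathbf{x}_0, \mathbf{y}_0)$ is in the support of $P^\star$, then necessarily (and sufficiently) $f^\star(\mathbf{x}_0) + \overline{f^\star}(\mathbf{y}_0) = c(\mathbf{x}_0, \mathbf{y}_0)$. Suppose therefore that $(\mathbf{x}_0, \mathbf{y}_0)$ is indeed in the support of $P^\star$. From the equality $f^\star(\mathbf{x}_0) + \overline{f^\star}(\mathbf{y}_0) = c(\mathbf{x}_0, \mathbf{y}_0)$ one immediately obtains that $\overline{f^\star}(\mathbf{y}_0) = c(\mathbf{x}_0, \mathbf{y}_0) - f^\star(\mathbf{x}_0)$. Yet, recall also that, by definition, $\overline{f^\star}(\mathbf{y}_0) = \inf_{\mathbf{x}} c(\mathbf{x}, \mathbf{y}_0) - f^\star(\mathbf{x})$. Therefore, $\mathbf{x}_0$ has the special property that it minimizes $\mathbf{x} \mapsto c(\mathbf{x}, \mathbf{y}_0) - f^\star(\mathbf{x})$. If, at this point, one recalls that c is assumed in this section to be $c(\mathbf{x},\mathbf{y}) = \tfrac{1}{2}\|\mathbf{x} - \mathbf{y}\|_2^2$, one has therefore that $\mathbf{x}_0$ verifies

$$\mathbf{x}_0 \in \arg\min_{\mathbf{x}} \tfrac{1}{2}\|\mathbf{x} - \mathbf{y}_0\|^2 - f^\star(\mathbf{x}). \quad (6.253)$$

Assuming $f^\star$ is differentiable, which one can prove by c-concavity, setting the gradient of this objective to zero at $\mathbf{x}_0$ yields the identity

$$\mathbf{x}_0 - \mathbf{y}_0 - \nabla f^\star(\mathbf{x}_0) = 0 \;\Rightarrow\; \mathbf{y}_0 = \mathbf{x}_0 - \nabla f^\star(\mathbf{x}_0) = \nabla\left(\tfrac{1}{2}\|\cdot\|^2 - f^\star\right)(\mathbf{x}_0). \quad (6.254)$$

Therefore, if $(\mathbf{x}_0, \mathbf{y}_0)$ is in the support of $P^\star$, $\mathbf{y}_0$ is uniquely determined by $\mathbf{x}_0$, which proves that $P^\star$ is in fact a Monge map "disguised" as a coupling, namely

$$P^\star = \left(\mathrm{Id}, \nabla\left(\tfrac{1}{2}\|\cdot\|^2 - f^\star\right)\right)_\sharp \mu. \quad (6.255)$$

The end of the proof can be worked out as follows: for any function $h: \mathcal{X} \to \mathbb{R}$, one can show, using the definitions of c-transforms and the Legendre transform, that $\tfrac{1}{2}\|\cdot\|^2 - h$ is convex if and only if h is c-concave. An intermediate step in that proof relies on showing that $\tfrac{1}{2}\|\cdot\|^2 - \bar{h}$ is equal to the Legendre transform of $\tfrac{1}{2}\|\cdot\|^2 - h$. The function $\tfrac{1}{2}\|\cdot\|^2 - f^\star$ above is therefore convex, by c-concavity of $f^\star$, and the optimal transport map is itself the gradient of a convex function.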

Knowing that an optimal transport map for the squared-Euclidean cost is necessarily the gradient of a convex function can prove very useful to solve (6.250). Indeed, this knowledge can be leveraged to restrict estimation to relevant families of functions, namely gradients of input-convex neural networks [AXK17], as proposed in [Mak+20] or [Kor+20], as well as arbitrary convex functions with desirable smoothness and strong-convexity constants [PdC20].


6.10.3.4 Closed forms for univariate and Gaussian distributions


Many metrics between probability distributions have closed-form expressions for simple cases. The Wasserstein distance is no exception, and can be computed in closed form in two important scenarios. When distributions are univariate and the cost $c(x,y)$ is either a convex function of the difference $x - y$, or satisfies $\partial^2 c/\partial x\partial y < 0$ a.e., then the Wasserstein distance is essentially a comparison between the quantile functions of $\mu$ and $\nu$. Recall that for a measure $\rho$, its quantile function $Q_\rho$ is a function that is defined on [0,1], takes values in the support of $\rho$, and corresponds to the (generalized) inverse map of $F_\rho$, the cumulative distribution function (cdf) of $\rho$. With these notations, one has that

$$\mathrm{OT}_c(\mu,\nu) = \int_0^1 c\left(Q_\mu(u), Q_\nu(u)\right) du. \quad (6.256)$$


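In particular, when c is $x, y \mapsto |x - y|$, then $\mathrm{OT}_c(\mu,\nu)$ corresponds to the area between the cdf of $\mu$ and that of $\nu$, $\int |F_\mu(x) - F_\nu(x)|\, dx$ (not to be confused with the Kolmogorov-Smirnov statistic, which is the supremum of $|F_\mu - F_\nu|$ rather than its integral). If c is $x, y \mapsto (x - y)^2$, we simply recover the squared Euclidean ($L^2$) norm of the difference between the quantile functions of $\mu$ and $\nu$. Note finally that the Monge map is also available in closed form, and is equal to $Q_\nu \circ F_\mu$.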
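For two uniform empirical measures with the same number of atoms (an assumption made here for simplicity), the quantile functions in (6.256) are step functions, so the integral reduces to averaging the cost over sorted pairs, and $Q_\nu \circ F_\mu$ sends the k-th smallest $x$ to the k-th smallest $y$. The sketch below compares this closed form with a brute-force search over all matchings, for both $|x - y|$ and $(x - y)^2$, and also checks that for $|x - y|$ it equals the area between the two empirical cdfs.

```python
import itertools

import numpy as np


def abs_cost(s, t):
    return np.abs(s - t)


def sq_cost(s, t):
    return (s - t) ** 2


def ot_1d_sorted(x, y, cost):
    """(6.256) for equal-size uniform empirical measures: match sorted samples."""
    return float(np.mean(cost(np.sort(x), np.sort(y))))


def ot_brute_force(x, y, cost):
    """Minimum over all n! matchings (tiny n only)."""
    return min(float(np.mean(cost(x, y[list(s)])))
               for s in itertools.permutations(range(len(x))))


def area_between_cdfs(x, y):
    """Integral of |F_x - F_y| for the two empirical (right-continuous) cdfs."""
    z = np.sort(np.concatenate([x, y]))
    Fx = np.searchsorted(np.sort(x), z[:-1], side="right") / len(x)
    Fy = np.searchsorted(np.sort(y), z[:-1], side="right") / len(y)
    return float(np.sum(np.abs(Fx - Fy) * np.diff(z)))


rng = np.random.default_rng(6)
x = rng.normal(size=6)
y = rng.exponential(size=6)
for cost in (abs_cost, sq_cost):
    print(ot_1d_sorted(x, y, cost), ot_brute_force(x, y, cost))
print(area_between_cdfs(x, y))

# Discrete analogue of the Monge map Q_nu o F_mu: k-th smallest x -> k-th smallest y.
ranks = np.argsort(np.argsort(x))
T_x = np.sort(y)[ranks]
```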
The second closed form applies to so-called elliptically contoured distributions, chief among them multivariate Gaussian distributions [Gel90]. For two Gaussians $\mathcal{N}(\mathbf{m}_1, \Sigma_1)$ and $\mathcal{N}(\mathbf{m}_2, \Sigma_2)$, their 2-Wasserstein distance decomposes as

$$W_2^2\left(\mathcal{N}(\mathbf{m}_1, \Sigma_1), \mathcal{N}(\mathbf{m}_2, \Sigma_2)\right) = \|\mathbf{m}_1 - \mathbf{m}_2\|^2 + \mathcal{B}^2(\Sigma_1, \Sigma_2) \quad (6.257)$$

where the Bures metric $\mathcal{B}$ reads:

$$\mathcal{B}^2(\Sigma_1, \Sigma_2) = \mathrm{trace}\left(\Sigma_1 + \Sigma_2 - 2\left(\Sigma_1^{1/2}\Sigma_2\Sigma_1^{1/2}\right)^{1/2}\right). \quad (6.258)$$

Notice in particular that these quantities are well defined even when the covariance matrices are not invertible, and that they collapse to the distance between means as both covariances become 0. When the first covariance matrix is invertible, the optimal Monge map is given by

$$T^\star: \mathbf{x} \mapsto A(\mathbf{x} - \mathbf{m}_1) + \mathbf{m}_2, \quad \text{where } A \triangleq \Sigma_1^{-1/2}\left(\Sigma_1^{1/2}\Sigma_2\Sigma_1^{1/2}\right)^{1/2}\Sigma_1^{-1/2}. \quad (6.259)$$

It is easy to show that $T^\star$ is indeed optimal: the fact that $T^\star_\sharp\mathcal{N}(\mathbf{m}_1, \Sigma_1) = \mathcal{N}(\mathbf{m}_2, \Sigma_2)$ follows from the knowledge that the affine push-forward of a Gaussian is another Gaussian. Here $T^\star$ is designed to push precisely the first Gaussian onto the second (and A is designed to recover random variables with covariance $\Sigma_2$ when starting from random variables with covariance $\Sigma_1$, since $A\Sigma_1 A = \Sigma_2$). The optimality of $T^\star$ can be recovered by simply noticing that it is the gradient of a convex quadratic form, since A is symmetric positive semi-definite, and closing this proof using the Brenier theorem above.
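Formulas (6.257) to (6.259) translate directly into NumPy. The sketch below computes matrix square roots of symmetric PSD matrices through an eigendecomposition (one choice among several), then checks on random covariances that $A\Sigma_1 A = \Sigma_2$, that A is symmetric positive definite, that the expected squared displacement of $T^\star$ equals $W_2^2$, and that the 1D case reduces to $(m_1 - m_2)^2 + (s_1 - s_2)^2$.

```python
import numpy as np


def sqrtm_psd(S):
    """Square root of a symmetric PSD matrix via its eigendecomposition."""
    w, V = np.linalg.eigh(S)
    return (V * np.sqrt(np.clip(w, 0.0, None))) @ V.T


def gaussian_w2_squared(m1, S1, m2, S2):
    """(6.257)-(6.258): ||m1 - m2||^2 + Bures^2(S1, S2)."""
    r1 = sqrtm_psd(S1)
    bures2 = np.trace(S1 + S2 - 2.0 * sqrtm_psd(r1 @ S2 @ r1))
    return float(np.sum((m1 - m2) ** 2) + bures2)


def gaussian_monge_matrix(S1, S2):
    """Matrix A of (6.259); assumes S1 is invertible."""
    r1 = sqrtm_psd(S1)
    r1_inv = np.linalg.inv(r1)
    return r1_inv @ sqrtm_psd(r1 @ S2 @ r1) @ r1_inv


rng = np.random.default_rng(7)
d = 3
B1, B2 = rng.normal(size=(d, d)), rng.normal(size=(d, d))
S1 = B1 @ B1.T + 0.5 * np.eye(d)
S2 = B2 @ B2.T + 0.5 * np.eye(d)
m1, m2 = rng.normal(size=d), rng.normal(size=d)

A = gaussian_monge_matrix(S1, S2)
w2sq = gaussian_w2_squared(m1, S1, m2, S2)
# E||x - T*(x)||^2 for x ~ N(m1, S1), with x - T*(x) = (I - A)(x - m1) + (m1 - m2).
I = np.eye(d)
displacement = float(np.sum((m1 - m2) ** 2) + np.trace((I - A) @ S1 @ (I - A).T))
print(np.allclose(A @ S1 @ A, S2), w2sq, displacement)
```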


6.10.3.5 Exact evaluation using linear program solvers 


We have hinted, using duality and c-concavity, that methods based on stochastic optimization over 1-Lipschitz or convex neural networks can be employed to estimate Wasserstein distances when c is the Euclidean distance or its square. These approaches are, however, non-convex and can only reach local optima. Apart from these two cases, and the closed forms provided above, the only reliable approach to compute Wasserstein distances arises when both $\mu$ and $\nu$ are discrete measures: in that case, one can instantiate and solve the discrete problem (6.239), or its dual formulation (6.245). The primal problem is a canonical example of network flow problems, and can be solved with the network-simplex method in $O(nm(n + m)\log(n + m))$ complexity [AMO88], or, alternatively, with the comparable auction algorithm [BC89]. These approaches suffer from computational limitations: their cubic cost is intractable for large-scale scenarios, and their combinatorial flavor makes it hard to parallelize the simultaneous computation of multiple optimal transport problems with a common cost matrix C.

An altogether different issue, arising from statistics, should further discourage users from using these LP formulations, notably in high-dimensional settings. Indeed, the bottleneck practitioners will most likely encounter when using (6.239) is that, in most scenarios, their goal will be to approximate the distance between two continuous measures $\mu, \nu$ using only i.i.d. samples contained in empirical measures $\hat{\mu}_n, \hat{\nu}_n$. Using (6.239) to approximate the corresponding (6.243) is doomed to fail, as various results [FG15] have shown in relevant settings (notably for measures in $\mathbb{R}^d$) that the sample complexity of the estimator provided by (6.239) to approximate (6.243) is of order $1/n^{1/d}$. In other words, the gap between $W_2(\mu,\nu)$ and $W_2(\hat{\mu}_n, \hat{\nu}_n)$ is large in expectation and decreases extremely slowly as n increases in high dimensions, so that solving the transport problem exactly between these samples spends compute power that is mostly wasted on overfitting. To address this curse of dimensionality, it is therefore extremely important in practice to approach (6.243) using a more careful strategy, one that involves regularizations that can leverage prior assumptions on $\mu$ and $\nu$. While all approaches outlined above using neural networks can be interpreted in this light, we focus in the following on a specific approach that results in a convex problem that is relatively simple to implement, embarrassingly parallel and with quadratic complexity.


6.10.3.6 Obtaining smoothness using entropic regularization 


A computational approach to speed up the resolution of (6.239) was proposed in [Cut13], building on earlier contributions [Wil69; KY94] and a connection to the Schrödinger bridge problem in the special case where $c = d^2$ [Léo14]. The idea rests upon regularizing the transportation cost by the Kullback-Leibler divergence of the coupling to the product measure of $\mu, \nu$,

$$W_\gamma(\mu,\nu) = \inf_{\pi\in\Pi(\mu,\nu)} \int_{\mathcal{X}^2} d(\mathbf{x},\mathbf{y})^2\, d\pi(\mathbf{x},\mathbf{y}) + \gamma D_{\mathrm{KL}}(\pi \,\|\, \mu\otimes\nu). \quad (6.260)$$

When instantiated on discrete measures, this problem is equivalent to the following $\gamma$-strongly convex problem on the set of transportation matrices (which should be compared to (6.239))

$$\mathrm{OT}_{C,\gamma}(\mathbf{a},\mathbf{b}) = \min_{P\in\mathbb{R}_+^{n\times m},\; P\mathbf{1}_m = \mathbf{a},\; P^\top\mathbf{1}_n = \mathbf{b}} \sum_{ij} P_{ij}C_{ij} - \gamma H(P) + \gamma\left(H(\mathbf{a}) + H(\mathbf{b})\right), \quad (6.261)$$


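which is itself equivalent to the following dual problem (which should be compared to (6.245))

$$\mathrm{OT}_{C,\gamma}(\mathbf{a},\mathbf{b}) = \max_{\mathbf{f}\in\mathbb{R}^n,\ \mathbf{g}\in\mathbb{R}^m} \mathbf{f}^\top\mathbf{a} + \mathbf{g}^\top\mathbf{b} - \gamma\,(e^{\mathbf{f}/\gamma})^\top K e^{\mathbf{g}/\gamma} + \gamma\left(1 + H(\mathbf{a}) + H(\mathbf{b})\right) \quad (6.262)$$

where $K \triangleq e^{-C/\gamma}$ is the elementwise exponential of $-C/\gamma$. This regularization has several benefits. Primal-dual relationships show an explicit link between the (unique) solution $P^\star_\gamma$ and a pair of optimal dual variables $(\mathbf{f}^\star, \mathbf{g}^\star)$ as

$$P^\star_\gamma = \mathrm{diag}(e^{\mathbf{f}^\star/\gamma})\, K\, \mathrm{diag}(e^{\mathbf{g}^\star/\gamma}). \quad (6.263)$$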


Problem (6.262) can be solved using a fairly simple strategy that has proved very sturdy in practice: a simple block-coordinate ascent (optimizing the objective alternately in $\mathbf{f}$ and then in $\mathbf{g}$), resulting in the famous Sinkhorn algorithm [Sin67]. Expressed here with log-sum-exp updates, it starts from an arbitrary initialization for $\mathbf{g}$ and carries out these two updates sequentially, until they converge:

$$\mathbf{f} \leftarrow \gamma\log\mathbf{a} - \gamma\log K e^{\mathbf{g}/\gamma}, \qquad \mathbf{g} \leftarrow \gamma\log\mathbf{b} - \gamma\log K^\top e^{\mathbf{f}/\gamma}. \quad (6.264)$$

The convergence of this algorithm has been amply studied (see [CK21] and references therein). Convergence is naturally slower as $\gamma$ decreases, reflecting the hardness of approaching LP solutions, as studied in [AWR17]. This regularization also has statistical benefits since, as argued in [Gen+19], the sample complexity of the regularized Wasserstein distance improves to an $O(1/\sqrt{n})$ regime, with, however, a constant in $1/\gamma^{d/2}$ that deteriorates as the dimension grows.
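A minimal NumPy sketch of the log-domain updates (6.264) follows; it is an illustration rather than a reference implementation (the book's own notebooks use JAX, and libraries such as POT provide tuned solvers). It relies on $\gamma\log(K e^{\mathbf{g}/\gamma})_i = \gamma\,\mathrm{logsumexp}_j\big((g_j - C_{ij})/\gamma\big)$, recovers the coupling with (6.263), and fixes a number of iterations instead of testing convergence. After each $\mathbf{g}$-update the column sums equal $\mathbf{b}$ exactly, while the row sums converge to $\mathbf{a}$.

```python
import numpy as np


def logsumexp(M, axis):
    """Numerically stable log(sum(exp(M), axis))."""
    M_max = np.max(M, axis=axis, keepdims=True)
    return np.squeeze(M_max, axis=axis) + np.log(np.sum(np.exp(M - M_max), axis=axis))


def sinkhorn_log(a, b, C, gamma=0.2, n_iters=2000):
    """Log-domain Sinkhorn iterations (6.264); returns (f, g, P) with P from (6.263)."""
    f = np.zeros_like(a)
    g = np.zeros_like(b)  # arbitrary initialization of g
    for _ in range(n_iters):
        f = gamma * np.log(a) - gamma * logsumexp((g[None, :] - C) / gamma, axis=1)
        g = gamma * np.log(b) - gamma * logsumexp((f[:, None] - C) / gamma, axis=0)
    P = np.exp((f[:, None] + g[None, :] - C) / gamma)
    return f, g, P


rng = np.random.default_rng(8)
xs = rng.normal(size=(5, 2))
ys = rng.normal(size=(4, 2)) + 1.0
C = np.sum((xs[:, None, :] - ys[None, :, :]) ** 2, axis=-1)
C = C / C.max()  # rescale costs to [0, 1] so that gamma has a clear scale
a = np.full(5, 1 / 5)
b = np.full(4, 1 / 4)
f, g, P = sinkhorn_log(a, b, C)
print(P.sum(axis=1), P.sum(axis=0), np.sum(P * C))
```

Working with $\mathbf{f}, \mathbf{g}$ in the log domain avoids the underflow of $K = e^{-C/\gamma}$ that plagues the plain matrix-scaling form when $\gamma$ is small.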


6.11 Submodular optimization 


This section was written by Jeff Bilmes. 


This section provides a brief overview of submodularity in machine learning.⁵ Submodularity has an extremely simple definition. However, the "simplest things are often the most complicated to understand fully" [Sam74], and while submodularity has been studied extensively over the years, it continues to yield new and surprising insights and properties, some of which are extremely relevant to data science, machine learning, and artificial intelligence. A submodular function operates on subsets of some finite ground set, V. Finding a guaranteed good subset of V would ordinarily require an amount of computation exponential in the size of V. Submodular functions, however, have certain properties that make optimization either tractable or approximable where otherwise neither would be possible. Moreover, these properties are quite natural, so submodular functions are both flexible and widely applicable to real problems. Submodularity involves an intuitive and natural diminishing returns property, stating that adding an element to a smaller set helps more than adding it to a larger set. Like convexity, submodularity allows one to efficiently find provably optimal or near-optimal solutions. In contrast to convexity, however, where little regarding maximization is guaranteed, submodular functions can be both minimized and (approximately) maximized. Submodular maximization and minimization, however, require very different algorithmic solutions and have quite different applications. It is sometimes said that submodular functions are a discrete form of convexity. This is not quite true: submodular functions share some properties with convex functions, some with concave functions, some with both at the same time, and some with neither. Convexity and concavity, for example, can be exhibited even by univariate functions. This is impossible for submodularity, as submodular functions are defined based only on the response of the function to changes amongst different variables in a multidimensional discrete space.


6.11.1 Intuition, Examples, and Background 


Let us define a set function f: 2^V → R as one that assigns a value to every subset of V. The notation 2^V denotes the power set of V, which has size 2^|V|; this means that f lives in the space R^(2^|V|), i.e., since there are 2^|V| possible subsets of V, f can return 2^|V| distinct values. We use the notation X + v as shorthand for X ∪ {v}. Also, the value of an element in a given context is so widely used a concept that we have a special notation for it: the incremental value gain of v in the context of X is defined as f(v|X) = f(X + v) − f(X). Thus, while f(v) is the value of element v, f(v|X) is the value of element v if you already have X. We also define the gain of set X in the context of Y as f(X|Y) = f(X ∪ Y) − f(Y).
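The gain notation maps directly onto code. The sketch below uses a hypothetical coverage function (each element of V = {0, 1, 2, 3} covers a few made-up items, and f(X) counts the distinct items covered); this particular f is not from the text and is only meant to make f(v|X) and f(X|Y) concrete.

```python
# Hypothetical coverage function on V = {0, 1, 2, 3}; the covered items are made up.
COVERS = {0: {"a", "b"}, 1: {"b", "c"}, 2: {"c", "d", "e"}, 3: {"a"}}


def f(X):
    """f(X) = number of distinct items covered by the elements of X (so f(empty) = 0)."""
    return len(set().union(*(COVERS[v] for v in X)))


def gain(v, X):
    """f(v | X) = f(X + v) - f(X), where X + v is shorthand for X ∪ {v}."""
    return f(set(X) | {v}) - f(X)


def set_gain(X, Y):
    """f(X | Y) = f(X ∪ Y) - f(Y)."""
    return f(set(X) | set(Y)) - f(Y)


print(gain(3, set()))         # 1: item "a" is new
print(gain(3, {0}))           # 0: item "a" is already covered by element 0
print(set_gain({1, 2}, {0}))  # 3: items "c", "d", "e" are new
```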


6.11.1.1 Coffee, Lemon, Milk, Tea 


As a simple example, we will explore the manner in which the value of everyday items may interact and combine, namely coffee, lemon, milk, and tea. Consider the value relationships amongst the four items coffee (c), lemon (l), milk (m), and tea (t) as shown in Figure 6.19. Suppose you just woke up, and there is a function f: 2^V → R that provides the average valuation for any subset of the items in V, where V = {c, l, m, t}. You can think of this function as giving the average price a typical person would be willing to pay for any subset of items. Since nothing should cost nothing, we would expect that f(∅) = 0. Clearly, one needs either coffee or tea in the morning, so f(c) > 0 and f(t) > 0, and coffee is usually more expensive than tea, so that f(c) > f(t) pound for pound. Also, more items cost more, so that, for example, 0 < f(c) < f(c, m) < f(c, m, t) < f(c, l, m, t). Thus, the function f is strictly monotone, or f(X) < f(Y) whenever X ⊊ Y.

5. A greatly extended version of the material in this section may be found at [Bil22].

Figure 6.19: The value relationships between coffee (c), lemon (l), milk (m), and tea (t). On the left, we first see a simple square showing the relationships between coffee and tea, and see that they are substitutive (or submodular). In this, and all of the shapes, the vertex label set is indicated in curly braces and the value at that vertex is a blue integer in a box. We next see a three-dimensional cube that adds lemon to the coffee and tea set. We see that tea and lemon are complementary (supermodular) but coffee and lemon are additive (modular, or independent). We next see a four-dimensional hypercube (tesseract) showing all of the value relationships described in the text. The four-dimensional hypercube is also shown as a lattice (on the right) showing the same relationships as well as two (red and green, also shown in the tesseract) of the eight three-dimensional cubes contained within.


The next thing we note is that coffee and tea may substitute for each other: they both have the same effect, waking you up. They are mutually redundant, and they decrease each other's value since once you have had a cup of coffee, a cup of tea is less necessary and less desirable. Thus, f(c, t) < f(c) + f(t), which is known as a subadditive relationship: the whole is less than the sum of the parts. On the other hand, some items complement each other. For example, milk and coffee are better combined together than when both are considered in isolation, or f(m, c) > f(m) + f(c), a superadditive relationship: the whole is more than the sum of the parts. A few of the items do not affect each other's price. For example, lemon and milk cost the same together as apart, so f(l, m) = f(l) + f(m), an additive or modular relationship; such a relationship is perhaps midway between a subadditive and a superadditive relationship, and can be seen as a form of independence.


Things become more interesting when we consider three or more items together. For example, once you have tea, lemon becomes less valuable when you acquire milk since there might be those who prefer milk to lemon in their tea. Similarly, milk becomes less valuable once you have acquired lemon since there are those who prefer lemon in their tea to milk. So, once you have tea, lemon and milk are substitutive; you would never use both, as the lemon would only curdle the milk. These are submodular relationships, f(l|m, t) < f(l|t) and f(m|l, t) < f(m|t), each of which implies that f(l, t) + f(m, t) > f(l, m, t) + f(t). The value of lemon (respectively milk) with tea decreases in the larger context of having milk (respectively lemon) with tea, typical of submodular relationships.

Not all of the items are in a submodular relationship, as sometimes the presence of an item can increase the value of another item. For example, once you have milk, then tea becomes still more valuable when you also acquire lemon, since tea with the choice of either lemon or milk is more valuable than tea with the option only of milk. Similarly, once you have milk, lemon becomes more valuable when you acquire tea, since lemon with milk alone is not nearly as valuable as lemon with tea, even if milk is at hand. This means that f(t|l, m) > f(t|m) and f(l|t, m) > f(l|m), implying f(l, m) + f(m, t) < f(l, m, t) + f(m). These are known as supermodular relationships, where the value increases as the context increases.

We have asked for a set of relationships amongst various subsets of the four items V = {c, l, m, t}. Is there a function that offers a value to each X ⊆ V that satisfies all of the above relationships? Figure 6.19 in fact shows such a function. On the left, we see a two-dimensional square whose vertices indicate the values over subsets of {c, t}, and we can quickly verify that the sum of the blue boxes on the north-west (corresponding to f({c})) and south-east corners (corresponding to f({t})) is greater than the sum of the north-east and south-west corners, expressing the required submodular relationship. Next, on the right, is a three-dimensional cube that adds the relationship with lemon. Now we have six squares, and we see that the values at each of the vertices all satisfy the above requirements; we verify this by considering the valuations at the four corners of every one of the six faces of the cube. Since |V| = 4, we need a four-dimensional hypercube to show all values, and this may be shown in two ways. It is first shown as a tesseract, a well-known three-dimensional projection of a four-dimensional hypercube. In the figure, all vertices are labeled both with subsets of V and with the function value f(X) as the blue number in a box. The figure on the right shows a lattice version of the four-dimensional hypercube, where corresponding three-dimensional cubes are shown in green and red.

We thus see that a set function is defined for all subsets of a ground set, and that it corresponds to valuations at all vertices of the hypercube. For the particular function over valuations of subsets of coffee, lemon, milk, and tea, we have seen submodular, supermodular, and modular relationships all in one function. Therefore, the overall function f defined in Figure 6.19 is neither submodular, supermodular, nor modular. For combinatorial auctions, there is often a desire to have a diversity of such manners of relationships [LLN06]; representation of these relationships can be handled by a difference of submodular functions [NB05; IB12] or a sum of a submodular and supermodular function [BB18] (further described below). In machine learning, however, most of the time we are interested in functions that are submodular (or modular, or supermodular) everywhere.


6.11.2 Submodular Basic Definitions 


For a function to be submodular, it must satisfy the submodular relationship for all subsets. We 
arrive at the following definition. 


Definition 6.11.1 (Submodular Function). A given set function f: 2^V → R is submodular if for all X, Y ⊆ V, we have the following inequality:

$$f(X) + f(Y) \geq f(X \cup Y) + f(X \cap Y) \quad (6.265)$$
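For a small ground set, Definition 6.11.1 can be checked by brute force over all pairs of subsets, at a cost of 4^|V| evaluations. The sketch below applies such a checker to three hypothetical functions on V = {0, 1, 2, 3}: a coverage function and the square root of the cardinality (both submodular), and the squared cardinality, which is supermodular and fails (6.265), e.g. for X = {0}, Y = {1}.

```python
import itertools


def powerset(V):
    """All subsets of V as frozensets."""
    V = list(V)
    return [frozenset(c) for r in range(len(V) + 1) for c in itertools.combinations(V, r)]


def is_submodular(f, V, tol=1e-12):
    """Check f(X) + f(Y) >= f(X ∪ Y) + f(X ∩ Y) for all X, Y ⊆ V (exponential cost)."""
    S = powerset(V)
    return all(f(X) + f(Y) >= f(X | Y) + f(X & Y) - tol for X in S for Y in S)


# Hypothetical example functions on V = {0, 1, 2, 3}.
COVERS = {0: {"a", "b"}, 1: {"b", "c"}, 2: {"c", "d", "e"}, 3: {"a"}}


def coverage(X):
    return len(set().union(*(COVERS[v] for v in X)))


def concave_of_size(X):
    return len(X) ** 0.5


def square_of_size(X):
    return len(X) ** 2


V = range(4)
print(is_submodular(coverage, V), is_submodular(concave_of_size, V),
      is_submodular(square_of_size, V))  # True True False
```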


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 
There are also many other equivalent definitions of submodularity [Bil22], some of which are more intuitive and easier to understand. For example, submodular functions are those set functions that satisfy the property of diminishing returns. If we think of a function f(X) as measuring the value of a set X that is a subset of a larger set of data items, X ⊆ V, then the submodular property means that the incremental "value" of adding a data item v to the set X decreases as the size of X grows. This gives us a second classic definition of submodularity.


Definition 6.11.2 (Submodular Function via Diminishing Returns). A given set function $f : 2^V \to \mathbb{R}$
is submodular if for all $X, Y \subseteq V$ with $X \subseteq Y$, and for all $v \notin Y$, we have the following
inequality (where $X + v$ is shorthand for $X \cup \{v\}$):

$$f(X + v) - f(X) \geq f(Y + v) - f(Y) \qquad (6.266)$$


The property that the incremental value of lemon once milk is already in the tea is less than the
incremental value of lemon with tea alone is equivalent to Equation 6.265 if we set $X = \{m, t\}$ and
$Y = \{l, t\}$ (i.e., $f(m, t) + f(l, t) \geq f(l, m, t) + f(t)$). It is naturally also equivalent to Equation 6.266
if we set $X = \{t\}$, $Y = \{m, t\}$, and $v = l$ (i.e., $f(l, t) - f(t) \geq f(l, m, t) - f(m, t)$).
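Both definitions can be checked by brute force on a tiny ground set. The sketch below is an illustration under stated assumptions: it enumerates all $2^{|V|}$ subsets (so it is only usable for very small $V$), it uses a floating-point tolerance, and the helper names and the two test functions $\sqrt{|A|}$ and $|A|^2$ are hypothetical choices rather than the values from Figure 6.19.

```python
from itertools import combinations
from math import sqrt


def subsets(ground):
    """Return every subset of `ground` as a frozenset (2^|ground| of them)."""
    items = list(ground)
    return [frozenset(c) for r in range(len(items) + 1) for c in combinations(items, r)]


def is_submodular_lattice(f, ground, tol=1e-12):
    """Definition 6.11.1: f(X) + f(Y) >= f(X | Y) + f(X & Y) for all X, Y."""
    S = subsets(ground)
    return all(f(X) + f(Y) >= f(X | Y) + f(X & Y) - tol for X in S for Y in S)


def is_submodular_dr(f, ground, tol=1e-12):
    """Definition 6.11.2: f(X + v) - f(X) >= f(Y + v) - f(Y) for X ⊆ Y, v not in Y."""
    S = subsets(ground)
    for X in S:
        for Y in S:
            if not X <= Y:
                continue
            for v in frozenset(ground) - Y:
                if f(X | {v}) - f(X) < f(Y | {v}) - f(Y) - tol:
                    return False
    return True


V = frozenset({"coffee", "lemon", "milk", "tea"})
concave_of_size = lambda A: sqrt(len(A))  # concave in |A|: submodular
convex_of_size = lambda A: len(A) ** 2    # convex in |A|: supermodular, not submodular

print(is_submodular_lattice(concave_of_size, V), is_submodular_dr(concave_of_size, V))  # True True
print(is_submodular_lattice(convex_of_size, V), is_submodular_dr(convex_of_size, V))    # False False
```

The two checks agree on both functions, as the equivalence of the definitions requires.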

There are many functions that are submodular, one famous one being Shannon entropy seen
as a function of subsets of random variables. We first point out that there are non-negative (i.e.,
$f(A) \geq 0$ for all $A$), monotone non-decreasing (i.e., $f(A) \leq f(B)$ whenever $A \subseteq B$) submodular functions
that are not entropic [Yeu91b; ZY97; ZY98], so submodularity is not just a trivial restatement
of the class of entropy functions. When a function is monotone non-decreasing, submodular, and
normalized so that $f(\emptyset) = 0$, it is often referred to as a polymatroid function. Thus, while the
entropy function is a polymatroid function, it does not encompass all polymatroid functions, even
though all polymatroid functions satisfy the properties Claude Shannon mentioned as being natural
for an “information” function (see Section 6.11.7).

A function $f$ is supermodular if and only if $-f$ is submodular. If a function is both submodular
and supermodular, it is known as a modular function. Modular functions can always be represented
by a vector-scalar pair $(m, c)$, where $m \in \mathbb{R}^V$ and $c \in \mathbb{R}$ is a constant, such that for any
$A \subseteq V$ we have $m(A) = c + \sum_{v \in A} m(v)$. If the modular function is normalized, so that
$m(\emptyset) = 0$, then $c = 0$ and the modular function can be seen simply as a vector $m \in \mathbb{R}^V$.
Hence, we sometimes say that the modular function $x \in \mathbb{R}^V$ offers a value for set $A$ as the partial
sum $x(A) = \sum_{v \in A} x(v)$. Many combinatorial problems use modular functions as objectives. For
example, the graph cut problem uses a modular function defined over the edges and judges a cut in a
graph as the modular function applied to the edges that comprise the cut.


As can be seen from the above, and by considering Figure 6.19, a submodular function, and in
fact any set function, $f : 2^V \to \mathbb{R}$ can be seen as a function defined only on the vertices of the
$n$-dimensional unit hypercube $[0, 1]^n$. Given any set $X \subseteq V$, we define $\mathbf{1}_X \in \{0, 1\}^n$ to be the
characteristic vector of set $X$, defined as $\mathbf{1}_X(v) = 1$ if $v \in X$ and $\mathbf{1}_X(v) = 0$ otherwise. This gives us
a way to map from any set $X \subseteq V$ to a binary vector $\mathbf{1}_X$. We also see that $\mathbf{1}_X$ is itself a modular
function, since $\mathbf{1}_X \in \{0, 1\}^V \subseteq \mathbb{R}^V$.
Submodular functions share a number of properties in common with both convex and concave
functions [Lov83], including wide applicability, generality, multiple representations, and closure
under a number of common operators (including mixtures, truncation, complementation, and certain
convolutions). There is one important submodular closure property that we state here: taking a
non-negative weighted (or conical) combination of submodular functions preserves submodularity.
In other words, if we have a set of $k$ submodular functions $f_i : 2^V \to \mathbb{R}$, $i \in [k]$, and we form
$f(X) = \sum_{i=1}^{k} w_i f_i(X)$ where $w_i \geq 0$ for all $i$, then Definition 6.11.1 immediately implies that
$f$ is also submodular. When we consider Definition 6.11.1, we see that submodular functions live in a
cone in $2^n$-dimensional space defined by the intersection of an exponential number of half-spaces, each
of which is defined by one of the inequalities of the form $f(X) + f(Y) \geq f(X \cup Y) + f(X \cap Y)$.
Each submodular function is therefore a point in that cone. It is therefore not surprising that taking
conical combinations of such points stays within this cone.
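The closure property can be spot-checked numerically. In this minimal sketch (hypothetical helper names and component functions), a non-negative combination of $\sqrt{|X|}$, a truncated count $\min(|X \cap \{0, 1\}|, 1)$, and a modular term passes the lattice check of Definition 6.11.1, while giving one component a negative weight breaks submodularity.

```python
from itertools import combinations
from math import sqrt


def subsets(ground):
    items = list(ground)
    return [frozenset(c) for r in range(len(items) + 1) for c in combinations(items, r)]


def is_submodular(f, ground, tol=1e-12):
    """Brute-force check of Definition 6.11.1 over all pairs of subsets."""
    S = subsets(ground)
    return all(f(X) + f(Y) >= f(X | Y) + f(X & Y) - tol for X in S for Y in S)


def conical_combination(fs, ws):
    """Return the set function X -> sum_i ws[i] * fs[i](X)."""
    return lambda X: sum(w * fi(X) for w, fi in zip(ws, fs))


V = frozenset(range(4))
f1 = lambda X: sqrt(len(X))                               # concave of cardinality
f2 = lambda X: min(len(X & {0, 1}), 1)                    # truncated count: submodular
f3 = lambda X: sum([2.0, -1.0, 0.5, 3.0][v] for v in X)   # modular, mixed signs

f = conical_combination([f1, f2, f3], [0.5, 2.0, 1.0])  # all weights >= 0
g = conical_combination([f1, f2], [1.0, -1.0])          # one negative weight
print(is_submodular(f, V), is_submodular(g, V))  # True False
```

For `g`, the pair $X = \{0\}$, $Y = \{1\}$ already violates Equation 6.265: $g(X) + g(Y) = 0 < \sqrt{2} - 1 = g(X \cup Y) + g(\emptyset)$.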


6.11.3 Example Submodular Functions 


As mentioned above, there are many functions that are submodular besides entropy. Perhaps the
simplest such function is $f(A) = \sqrt{|A|}$, which is the composition of the square-root function (which is
concave) with the cardinality $|A|$ of the set $A$. The gain function is $f(A + v) - f(A) = \sqrt{k + 1} - \sqrt{k}$
if $|A| = k$, which we know to be decreasing in $k$, thus establishing the submodularity of $f$. In
fact, if $\phi : \mathbb{R} \to \mathbb{R}$ is any concave function, then $f(A) = \phi(|A|)$ will be submodular for the same
reason.⁶ Generalizing this slightly further, a function defined as $f(A) = \phi(\sum_{a \in A} m(a))$ is also
submodular whenever $m(a) \geq 0$ for all $a \in V$. This yields a composition of a concave function
with a modular function, $f(A) = \phi(m(A))$, since $\sum_{a \in A} m(a) = m(A)$. We may take sums of
such functions as well as add a final modular function without losing submodularity, leading to
$f(A) = \sum_{u \in U} \phi_u(\sum_{a \in A} m_{u,a}) + \sum_{a \in A} m_\pm(a)$, where $\phi_u$ can be a distinct concave function for
each $u$, $m_{u,a}$ is a non-negative real value for all $u$ and $a$, and $m_\pm(a)$ is an arbitrary real number.
Therefore, $f(A) = \sum_{u \in U} \phi_u(m_u(A)) + m_\pm(A)$, where $m_u$ is a $u$-specific non-negative modular function
and $m_\pm$ is an arbitrary modular function. Such functions are sometimes known as feature-based
submodular functions [BB17] because $U$ can be a set of non-negative features (in the machine-learning
“bag-of-words” sense), and this function measures a form of dispersion over $A$ as determined by the
set of features $U$.
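A feature-based function is straightforward to evaluate. The sketch below assumes a hypothetical bag-of-words table of non-negative counts $m_{u,a}$, uses $\sqrt{\cdot}$ and $\log(1 + \cdot)$ as the concave $\phi_u$, and adds an arbitrary-sign modular term $m_\pm$. It checks submodularity by brute force and shows that a set drawing on distinct features scores higher than a set that repeats the same features.

```python
from itertools import combinations
from math import log1p, sqrt


def subsets(ground):
    items = list(ground)
    return [frozenset(c) for r in range(len(items) + 1) for c in combinations(items, r)]


def is_submodular(f, ground, tol=1e-12):
    S = subsets(ground)
    return all(f(X) + f(Y) >= f(X | Y) + f(X & Y) - tol for X in S for Y in S)


def feature_based(m, phi, m_pm):
    """f(A) = sum_u phi[u](sum_{a in A} m[u][a]) + sum_{a in A} m_pm[a].

    m[u][a] >= 0 plays the role of m_{u,a}; phi[u] is a concave function."""
    def f(A):
        coverage = sum(phi[u](sum(m[u].get(a, 0.0) for a in A)) for u in m)
        return coverage + sum(m_pm[a] for a in A)
    return f


# Hypothetical bag-of-words counts: feature (word) -> document -> count.
m = {
    "goal": {"d1": 2.0, "d2": 1.0, "d4": 1.0},
    "match": {"d1": 1.0, "d2": 2.0},
    "vote": {"d3": 2.0, "d4": 1.0},
    "poll": {"d3": 1.0},
}
phi = {"goal": sqrt, "match": sqrt, "vote": log1p, "poll": log1p}
m_pm = {"d1": 0.1, "d2": -0.2, "d3": 0.0, "d4": 0.05}  # arbitrary-sign modular term
f = feature_based(m, phi, m_pm)
V = frozenset(m_pm)

print(is_submodular(f, V))                # True
print(f({"d1", "d3"}) > f({"d1", "d2"}))  # True: distinct features score higher
```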

A function such as $f(A) = \sum_{u \in U} \phi_u(m_u(A))$ tends to award high diversity to a set $A$ that has a
high valuation by a distinct set of the features $U$. The reason is that, due to the concave nature of
$\phi_u$, any addition to the argument $m_u(A)$ by adding, say, $v$ to $A$ would diminish as $A$ gets larger. In
order to produce a set larger than $A$ that has a much larger valuation, one must use a feature $u' \neq u$
that has not yet diminished as much.

Facility location is another well-known submodular function; perhaps an appropriate nickname
would be the “k-means of submodular functions,” due to its applicability, utility, ease of use (it
needs only an affinity matrix), and similarity to k-medoids problems. The facility location function
is defined using an affinity matrix as follows: $f(A) = \sum_{v \in V} \max_{a \in A} \text{sim}(a, v)$ (with $f(\emptyset) = 0$), where
$\text{sim}(a, v)$ is a non-negative measure of the affinity (or similarity) between elements $a$ and $v$. Here, every
element $v \in V$ must have a representative within the set $A$, and the representative for each $v \in V$ is
chosen to be the element $a \in A$ most similar to $v$. This function is also a form of dispersion or diversity
function because, in order to maximize it, every element $v \in V$ must have some element similar to
it in $A$. The overall score is then the sum of the similarity between each element $v \in V$ and $v$'s
representative. This function is monotone (since as $A$ includes more elements to become $B \supseteq A$, the
most similar element in $B$ is at least as similar to any given $v$ as the most similar element in $A$).

6. While we will not be extensively discussing supermodular functions in this section, $f(A) = \phi(|A|)$ is supermodular
for any convex function $\phi$.

While the facility location function looks quite different from a feature-based function, it is possible
to precisely represent any facility location function with a feature-based function. Consider
just $\max_{a \in A} x_a$ and, without loss of generality, assume that $0 \leq x_1 \leq x_2 \leq \cdots \leq x_n$. Then
$\max_{a \in A} x_a = \sum_{i=1}^{n} y_i \min(|A \cap \{i, i+1, \ldots, n\}|, 1)$, where $y_i = x_i - x_{i-1} \geq 0$ and we set $x_0 = 0$. We note
that this is a sum of weighted concave functions composed with modular functions, since $\min(\alpha, 1)$ is concave
in $\alpha$ and $|A \cap \{i, i+1, \ldots, n\}|$ is a modular function in $A$. Thus, the facility location function, a
sum of these (one for each $v \in V$, with $x_a = \text{sim}(a, v)$), is merely a feature-based function.
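The telescoping identity behind this reduction can be verified directly. The sketch below uses 0-based indices (so $y_0 = x_0$), assumes the values are already sorted and non-negative, and checks the identity against $\max_{a \in A} x_a$ for every non-empty $A$. Applying it once per $v \in V$ with $x_a = \text{sim}(a, v)$ gives the facility location function; the function name and the values of `x` are hypothetical.

```python
from itertools import combinations


def max_via_features(x, A):
    """Compute max_{a in A} x[a] as sum_i y_i * min(|A ∩ {i, ..., n-1}|, 1).

    Assumes 0 <= x[0] <= x[1] <= ... (0-indexed), y_i = x[i] - x[i-1], x[-1] := 0.
    For A empty this returns 0, matching the convention f(empty set) = 0."""
    total, prev = 0.0, 0.0
    for i, xi in enumerate(x):
        y = xi - prev                            # non-negative since x is sorted, x[0] >= 0
        tail_hits = sum(1 for a in A if a >= i)  # modular in A: |A ∩ {i, ..., n-1}|
        total += y * min(tail_hits, 1)           # concave min(., 1) of a modular function
        prev = xi
    return total


x = [0.2, 0.5, 0.5, 1.3]  # hypothetical sorted similarities sim(a, v) for one fixed v
for r in range(1, len(x) + 1):
    for A in combinations(range(len(x)), r):
        assert abs(max_via_features(x, A) - max(x[a] for a in A)) < 1e-12
```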

Feature-based functions, in fact, are quite expressive, and can be used to represent many different
submodular functions, including set cover and graph-based functions. For example, we can define a set
cover function, given a set of sets $\{U_v\}_{v \in V}$, via $f(X) = |\bigcup_{v \in X} U_v|$. If $f(X) = |U|$, where $U = \bigcup_{v \in V} U_v$,
then $X$ indexes a set that fully covers $U$. This can also be represented as $f(X) = \sum_{u \in U} \min(1, m_u(X))$,
where $m_u(X)$ is a modular function with $m_u(v) = 1$ if and only if $u \in U_v$, and otherwise $m_u(v) = 0$.
We see that this is a feature-based submodular function, since $\min(1, \alpha)$ is concave in $\alpha$ and $U$ is a
set of features.
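The equivalence between the union form and the feature-based form of set cover can be checked on a small instance. The sets $U_v$ and the helper names below are hypothetical.

```python
from itertools import combinations


def set_cover(sets):
    """f(X) = |union_{v in X} U_v| for sets[v] = U_v."""
    return lambda X: len(set().union(*(sets[v] for v in X)))


def set_cover_as_features(sets):
    """Same value as sum_{u in U} min(1, m_u(X)), where m_u(v) = 1 iff u in U_v."""
    universe = set().union(*sets.values())
    return lambda X: sum(min(1, sum(1 for v in X if u in sets[v])) for u in universe)


sets = {"a": {1, 2, 3}, "b": {3, 4}, "c": {4, 5, 6}, "d": {1, 6}}  # hypothetical U_v
f, g = set_cover(sets), set_cover_as_features(sets)
for r in range(len(sets) + 1):
    for X in combinations(sets, r):
        assert f(X) == g(X)
print(f(("a", "c")))  # 6: {"a", "c"} covers U = {1, ..., 6}
```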

This construct can be used to produce the vertex cover function if we set $U = V$ to be the set of
vertices in a graph, and set $m_u(v) = 1$ if and only if vertices $u$ and $v$ are adjacent in the graph, and
otherwise set $m_u(v) = 0$. Similarly, the edge cover function can be expressed by setting $V$ to be the
set of edges in a graph, $U$ to be the set of vertices in the graph, and $m_u(v) = 1$ if and only if edge $v$ is
incident to vertex $u$.

A generalization of the set cover function is the probabilistic coverage function. Let $P[B_{u,v} = 1]$ be
the probability of the presence of feature (or concept) $u$ within element $v$. Here, we treat $B_{u,v}$ as a
Bernoulli random variable for each element $v$ and feature $u$, so that $P[B_{u,v} = 1] = 1 - P[B_{u,v} = 0]$.
Then we can define the probabilistic coverage function as $f(X) = \sum_{u \in U} f_u(X)$ where, for feature
$u$, we have $f_u(X) = 1 - \prod_{v \in X}(1 - P[B_{u,v} = 1])$, which indicates the degree to which feature $u$ is
“covered” by $X$. If we set $P[B_{u,v} = 1] = 1$ if and only if $u \in U_v$, and otherwise $P[B_{u,v} = 1] = 0$, then
$f_u(X) = \min(1, m_u(X))$ and the set cover function can be represented as $\sum_{u \in U} f_u(X)$. We can
generalize this in two ways. First, to make it softer and more probabilistic, we allow $P[B_{u,v} = 1]$ to
be any number between zero and one. Second, we allow each feature to have a non-negative weight. This
yields the general form of the probabilistic coverage function, which is defined by taking a weighted
combination over all features: $f(X) = \sum_{u \in U} w_u f_u(X)$, where $w_u \geq 0$ is a weight for feature $u$.
Observe that $1 - \prod_{v \in X}(1 - P[B_{u,v} = 1]) = 1 - \exp(-m_u(X)) = \phi(m_u(X))$, where $m_u$ is a modular
function with evaluation $m_u(X) = \sum_{v \in X} \log(1/(1 - P[B_{u,v} = 1]))$, and, for $\alpha \in \mathbb{R}$, $\phi(\alpha) = 1 - \exp(-\alpha)$
is a concave function. Thus, the probabilistic coverage function (and its set cover specialization) is
also a feature-based function.
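The two expressions for probabilistic coverage can be compared numerically. This sketch assumes hypothetical probabilities and weights, and it requires every $P[B_{u,v} = 1] < 1$ so that the logarithm in the modular form is finite.

```python
import math


def prob_coverage(P, w):
    """f(X) = sum_u w[u] * (1 - prod_{v in X} (1 - P[u][v]))."""
    return lambda X: sum(w[u] * (1.0 - math.prod(1.0 - P[u][v] for v in X)) for u in P)


def prob_coverage_concave(P, w):
    """Same function as sum_u w[u] * phi(m_u(X)), with phi(a) = 1 - exp(-a) and
    m_u(X) = sum_{v in X} log(1 / (1 - P[u][v])). Assumes every P[u][v] < 1."""
    def f(X):
        total = 0.0
        for u in P:
            m_u = sum(math.log(1.0 / (1.0 - P[u][v])) for v in X)  # modular in X
            total += w[u] * (1.0 - math.exp(-m_u))                  # concave phi
        return total
    return f


# Hypothetical P[B_{u,v} = 1] for two features and three elements.
P = {"sports": {0: 0.9, 1: 0.2, 2: 0.0}, "politics": {0: 0.1, 1: 0.7, 2: 0.5}}
w = {"sports": 1.0, "politics": 2.0}
f, g = prob_coverage(P, w), prob_coverage_concave(P, w)
for X in [set(), {0}, {1, 2}, {0, 1, 2}]:
    assert abs(f(X) - g(X)) < 1e-9
```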
Another common submodular function is the graph cut function. Here, we measure the value of a
subset of $V$ by the edges that cross between a set of nodes and all but that set of nodes. We are given
an undirected, non-negatively weighted graph $G = (V, E, w)$, where $V$ is the set of nodes, $E \subseteq V \times V$ is
the set of edges, and the non-negative edge weights $w$ form a symmetric matrix (so $w_{ij} = w_{ji}$). Each
$e \in E$ has the form $e = \{i, j\}$ for some $i, j \in V$ with $i \neq j$, and the graph cut function
$f : 2^V \to \mathbb{R}$ is defined as $f(X) = \sum_{i \in X, j \in \bar{X}} w_{ij}$, where $w_{ij} \geq 0$ is the weight of edge $e = \{i, j\}$
($w_{ij} = 0$ if the edge does not exist), and where $\bar{X} = V \setminus X$ is the complement of set $X$. Notice that
we can write the graph cut function as follows:

$$f(X) = \sum_{i \in X, j \in \bar{X}} w_{ij} = \sum_{i, j \in V} w_{ij} \mathbb{1}\{i \in X, j \in \bar{X}\} \qquad (6.267)$$

$$= \frac{1}{2} \sum_{i, j \in V} w_{ij} \min(|X \cap \{i, j\}|, 1) + \frac{1}{2} \sum_{i, j \in V} w_{ij} \min(|(V \setminus X) \cap \{i, j\}|, 1) - \frac{1}{2} \sum_{i, j \in V} w_{ij} \qquad (6.268)$$

$$= \tilde{f}(X) + \tilde{f}(V \setminus X) - \tilde{f}(V) \qquad (6.269)$$

where $\tilde{f}(X) = \frac{1}{2} \sum_{i, j \in V} w_{ij} \min(|X \cap \{i, j\}|, 1)$. Therefore, since $\min(\alpha, 1)$ is concave, and since
$m_{ij}(X) = |X \cap \{i, j\}|$ is modular, $\min(m_{ij}(X), 1)$ is submodular for all $i, j$, and hence so is $\tilde{f}(X)$.
Also, since $\tilde{f}(X)$ is submodular, so is $\tilde{f}(V \setminus X)$ (as a function of $X$). Therefore, the graph cut
function can be expressed as a sum of non-normalized feature-based functions. Note that here the second
modular function, $|(V \setminus X) \cap \{i, j\}|$, is not normalized and is non-increasing, and also that we subtract
the constant $\tilde{f}(V)$ to achieve equality.
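The decomposition in Equations 6.267 to 6.269 can be verified on every subset of a small graph. The weight matrix is a hypothetical symmetric example with a zero diagonal (no self-loops, which the derivation implicitly assumes), and `fbar` is an illustrative name for $\tilde{f}$.

```python
from itertools import combinations


def cut(W, X):
    """f(X) = sum_{i in X, j not in X} W[i][j] for symmetric non-negative W."""
    n = len(W)
    return sum(W[i][j] for i in X for j in range(n) if j not in X)


def fbar(W, X):
    """Feature-based piece: 1/2 * sum_{i,j} W[i][j] * min(|X ∩ {i, j}|, 1)."""
    n = len(W)
    return 0.5 * sum(W[i][j] * min(len(X & {i, j}), 1) for i in range(n) for j in range(n))


W = [[0, 1, 0, 2], [1, 0, 3, 0], [0, 3, 0, 1], [2, 0, 1, 0]]  # hypothetical weights
V = frozenset(range(4))
for r in range(len(V) + 1):
    for X in map(frozenset, combinations(V, r)):
        # Equation 6.269: cut(X) = fbar(X) + fbar(V \ X) - fbar(V)
        assert abs(cut(W, X) - (fbar(W, X) + fbar(W, V - X) - fbar(W, V))) < 1e-12
```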

Another way to view the graph cut function is to consider the non-negative weights as a modular
function defined over the edges. That is, we view $w \in \mathbb{R}_+^E$ as a modular function $w : 2^E \to \mathbb{R}_+$ where,
for every $A \subseteq E$, $w(A) = \sum_{e \in A} w(e)$ is the weight of the edges $A$, and $w(e)$ is the weight of edge
$e$. Then the graph cut function becomes $f(X) = w(\{(a, b) \in E : a \in X, b \in V \setminus X\})$. We view
$\{(a, b) \in E : a \in X, b \in V \setminus X\}$ as a set-to-set mapping function that maps subsets of nodes to
subsets of edges, and the edge weight modular function $w$ measures the weight of the resulting edges.
This immediately suggests that other functions can measure the weight of the resulting edges as
well, including non-modular functions. One example is to use a polymatroid function itself, leading to
$h(X) = g(\{(a, b) \in E : a \in X, b \in V \setminus X\})$, where $g : 2^E \to \mathbb{R}_+$ is a submodular function defined
on subsets of edges. The function $h$ is known as the cooperative cut function, and it is neither
submodular nor supermodular in general, but there are many useful and practical algorithms that
can be used to optimize it [JB16], thanks to its internal submodular structure, which is exposed and
thus available to exploit.

While feature-based functions are flexible and powerful, there is a strictly broader class of
submodular functions, unable to be expressed by feature-based functions, that are related to deep
neural networks. Here, we create a recursively nested composition of concave functions with sums
of compositions of concave functions. An example is $f(A) = \phi(\sum_{u \in U} w_u \phi_u(\sum_{a \in A} m_{u,a}))$, where $\phi$
is an outer concave function composed with a feature-based function, with $m_{u,a} \geq 0$ and $w_u \geq 0$.
This is known as a two-layer deep submodular function (DSF). A three-layer DSF has the form
$f(A) = \phi(\sum_{c \in C} \psi_c(\sum_{u \in U} w_{u,c} \phi_u(\sum_{a \in A} m_{u,a})))$. DSFs strictly expand the class of submodular
functions beyond feature-based functions, meaning that there are deep submodular functions, even
simple ones, that cannot be represented by feature-based functions [BB17].
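A two-layer DSF is easy to evaluate. In this sketch the weights are hypothetical, $\phi = \log(1 + \cdot)$ and $\phi_u = \sqrt{\cdot}$. Both are concave and also non-decreasing; that is an assumption made here, because a concave function that is not non-decreasing can destroy submodularity when composed with a submodular inner function.

```python
from itertools import combinations
from math import log1p, sqrt


def subsets(ground):
    items = list(ground)
    return [frozenset(c) for r in range(len(items) + 1) for c in combinations(items, r)]


def is_submodular(f, ground, tol=1e-12):
    S = subsets(ground)
    return all(f(X) + f(Y) >= f(X | Y) + f(X & Y) - tol for X in S for Y in S)


# Hypothetical non-negative weights m_{u,a} (features u, elements a) and w_u.
m = {"u1": {0: 1.0, 1: 2.0, 2: 0.0, 3: 1.0}, "u2": {0: 0.0, 1: 1.0, 2: 3.0, 3: 1.0}}
w = {"u1": 1.0, "u2": 0.5}


def dsf_two_layer(A):
    """phi(sum_u w_u * phi_u(sum_{a in A} m_{u,a})) with phi = log1p and phi_u = sqrt.

    Both concave functions are also non-decreasing, which keeps the nested
    composition submodular."""
    inner = sum(w[u] * sqrt(sum(m[u][a] for a in A)) for u in m)
    return log1p(inner)


print(is_submodular(dsf_two_layer, range(4)))  # True
```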


6.11.4 Submodular Optimization 


Submodular functions, while discrete, would not be very useful if it were not possible to optimize
over them efficiently. There are many natural problems in machine learning that can be cast as
submodular optimization and that can be addressed relatively efficiently.

When one wishes to encourage diversity, information, spread, high complexity, independence,
coverage, or dispersion, one usually will maximize a submodular function, in the form of $\max_{A \in \mathcal{C}} f(A)$,
where $\mathcal{C} \subseteq 2^V$ is a constraint set, a set of subsets we are willing to accept as feasible solutions (more
on this below).




Figure 6.20: Far left: cardinality-constrained (to ten) submodular maximization of a facility location function
over 1000 points in two dimensions. Similarities are based on a Gaussian kernel $\text{sim}(a, v) = \exp(-d(a, v))$,
where $d(\cdot, \cdot)$ is a distance. Selected points are green stars, and the greedy order is also shown next to each
selected point. Right three plots: different uniformly-at-random subsets of size ten.


Why is submodularity, in general, a good model for diversity? Submodular functions are such
that once you have some elements, any other elements not in your possession but that are similar
to, explained by, or represented by the elements in your possession become less valuable. Thus, in
order to maximize the function, one must choose other elements that are dissimilar to, or not well
represented by, the ones you already have. That is, the elements similar to the ones you own are
diminished in value relative to their original values, while the elements dissimilar to the ones you
have do not have diminished value relative to their original values. Thus, maximizing a submodular
function successfully involves choosing elements that are jointly dissimilar amongst each other, which
is a definition of diversity. Diversity in general is a critically important aspect of machine learning
and artificial intelligence. For example, bias in data science and machine learning can often be seen
as some lack of diversity somewhere. Submodular functions have the potential to encourage (and
even ensure) diversity, enhance balance, and reduce bias in artificial intelligence.


Note that in order for a submodular function to appropriately model diversity, it is important
for it to be instantiated appropriately. Figure 6.20 shows an example in two dimensions. The plot
compares the ten points chosen according to a facility location function instantiated with a Gaussian
kernel, along with three random samples of size ten. We see that the points selected by facility location
are more diverse and tend to cover the space much better than any of the randomly selected sets, each
of which misses large regions of the space and/or shows cases where points near each other are jointly
selected.


When one wishes for homogeneity, conformity, low complexity, coherence, or cooperation, one will
usually minimize a submodular function, in the form of $\min_{A \in \mathcal{C}} f(A)$. For example, if $V$ is a set of
pixels in an image, one might wish to choose a subset of pixels corresponding to a particular object
over which the properties (e.g., color, luminance, texture) are relatively homogeneous. A set $X$ of
size $k$, even if $k$ is large, need not have a large valuation $f(X)$; in fact, it could even have the
least valuation. Thus, semantic image segmentation could work even if the object being segmented
and isolated consists of the majority of image pixels.


6.11.4.1 Submodular Maximization 


While the cardinality-constrained submodular maximization problem is NP-complete [Fei98], it was
shown in [NWF78; FNW78] that the very simple and efficient greedy algorithm finds an approximate
solution guaranteed to be within $1 - 1/e \approx 0.63$ of the optimal solution. Moreover, the approximation
ratio achieved by the simple greedy algorithm is provably the best achievable in polynomial time,
assuming $P \neq NP$ [Fei98]. The greedy algorithm proceeds as follows: starting with $X_0 = \emptyset$, we
repeat the following greedy step for $i = 0, \ldots, k - 1$:

$$X_{i+1} = X_i \cup \left\{ \operatorname*{argmax}_{v \in V \setminus X_i} f(X_i \cup \{v\}) \right\} \qquad (6.270)$$


What the above approximation result means is that if $X^* \in \operatorname{argmax}\{f(X) : |X| \leq k\}$, and if $\tilde{X}$ is
the result of the greedy procedure, then $f(\tilde{X}) \geq (1 - 1/e) f(X^*)$.
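The greedy step of Equation 6.270 takes only a few lines. The sketch below is a plain (non-accelerated) version applied to a facility location function on hypothetical points; it compares the result with a brute-force optimum so that the $1 - 1/e$ guarantee can be observed on this small instance. Helper names are illustrative.

```python
import math
from itertools import combinations


def facility_location(sim):
    """f(A) = sum_v max_{a in A} sim[a][v], with f(empty set) = 0."""
    n = len(sim)
    return lambda A: sum(max(sim[a][v] for a in A) for v in range(n)) if A else 0.0


def greedy_max(f, ground, k):
    """Equation 6.270: k times, add the v not yet in X that maximizes f(X ∪ {v})."""
    X, order = set(), []
    for _ in range(k):
        best = max((v for v in ground if v not in X), key=lambda v: f(X | {v}))
        X.add(best)
        order.append(best)  # the greedy order, as annotated in Figure 6.20
    return X, order


# Hypothetical points in three loose clusters plus one in-between point.
pts = [(0, 0), (0.2, 0.1), (3, 3), (3.1, 2.9), (6, 0), (6.2, 0.3), (1.5, 1.4)]
sim = [[math.exp(-math.dist(p, q)) for q in pts] for p in pts]
f = facility_location(sim)

X, order = greedy_max(f, range(len(pts)), k=3)
opt = max(f(set(S)) for S in combinations(range(len(pts)), 3))  # brute force
print(order, f(X) / opt)  # the ratio is guaranteed to be >= 1 - 1/e ≈ 0.632
```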

The $1 - 1/e$ guarantee is a powerful constant-factor approximation result, since it holds regardless
of the size of the initial set $V$ and regardless of which polymatroid function $f$ is being optimized.
It is possible to make this algorithm run extremely fast using various acceleration tricks [FNW78;
NWF78; Min78].
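One well-known acceleration is the lazy evaluation idea associated with [Min78]. The sketch below is our own minimal version, not a reference implementation: it keeps stale gains in a max-heap as upper bounds (valid by submodularity) and re-evaluates only the top element. The points and helper names are hypothetical, and the plain greedy is included for comparison.

```python
import heapq
import math


def facility_location(sim):
    n = len(sim)
    return lambda A: sum(max(sim[a][v] for a in A) for v in range(n)) if A else 0.0


def greedy_max(f, ground, k):
    """Plain greedy (Equation 6.270), for comparison."""
    X = set()
    for _ in range(k):
        X.add(max((v for v in ground if v not in X), key=lambda v: f(X | {v})))
    return X


def lazy_greedy_max(f, ground, k):
    """Lazy greedy: a gain computed for an earlier, smaller X upper-bounds the
    current gain, so only the top of the max-heap needs re-evaluation."""
    X, fX = set(), f(set())
    heap = [(-(f({v}) - fX), i, v) for i, v in enumerate(ground)]  # (-bound, tiebreak, v)
    heapq.heapify(heap)
    while len(X) < k and heap:
        _, i, v = heapq.heappop(heap)
        gain = f(X | {v}) - fX  # refresh the stale bound for v
        if not heap or gain >= -heap[0][0]:
            X.add(v)  # no remaining element can beat this fresh gain
            fX += gain
        else:
            heapq.heappush(heap, (-gain, i, v))  # reinsert with a tighter bound
    return X


pts = [(0.0, 0.0), (0.3, 0.1), (2.9, 3.2), (3.4, 2.8), (6.1, 0.2), (5.7, 0.6), (1.6, 1.3), (4.4, 4.9)]
sim = [[math.exp(-math.dist(p, q)) for q in pts] for p in pts]
f = facility_location(sim)
lazy = lazy_greedy_max(f, range(len(pts)), 3)
plain = greedy_max(f, range(len(pts)), 3)
print(f(lazy), f(plain))  # same value, with fewer function evaluations for lazy
```

Both versions select the same elements unless two candidates have exactly tied gains.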

A minor bit of additional information about a polymatroid function, however, can improve
the approximation guarantee. Define the total curvature of the polymatroid function $f$ as
$\kappa = 1 - \min_{v \in V} f(v \mid V - v) / f(v)$, where $f(v \mid A) = f(A + v) - f(A)$ is the gain of adding $v$ to $A$, and
where we assume $f(v) > 0$ for all $v$ (if not, we may prune such elements from the ground set, since
they can never improve a polymatroid function valuation). We thus have $0 \leq \kappa \leq 1$, and [CC84]
showed that the greedy algorithm gives a guarantee of $\frac{1}{\kappa}(1 - e^{-\kappa}) \geq 1 - 1/e$.
In fact, this is an equality (and we get the same bound) when $\kappa = 1$, which is the fully curved case.
As $\kappa$ gets smaller, the bound improves, until we reach the $\kappa = 0$ case (taken as a limit) and the bound
becomes unity. Observe that $\kappa = 0$ if and only if the function is modular, in which case the greedy
algorithm is optimal for the cardinality-constrained maximization problem. In some cases, non-submodular
functions can be decomposed into components that each might be more amenable to approximation.
We see below that any set function can be written as a difference of submodular functions [NB05; IB12],
and sometimes (but not always) a given $h$ can be decomposed into a monotone submodular
plus a monotone supermodular function, or a BP function [BB18], i.e., $h = f + g$ where $f$ is
submodular and $g$ is supermodular. $g$ has an easily computed quantity called the supermodular
curvature $\kappa^g = 1 - \min_{v \in V} g(v) / g(v \mid V - v)$ that, together with the submodular curvature $\kappa_f$, can be
used to produce an approximation ratio of the form $\frac{1}{\kappa_f}(1 - e^{-\kappa_f (1 - \kappa^g)})$ for greedy maximization
of $h$.
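Curvature and the resulting bounds are cheap to compute for small examples. The sketch below assumes $f$ is a polymatroid function with $f(\{v\}) > 0$, evaluates the [CC84] bound, and also codes the BP-function bound stated above. The three test functions are illustrative choices whose curvatures are $0$ (modular), $\sqrt{3} - 1$ ($\sqrt{|A|}$ on four elements), and $1$ ($\min(|A|, 1)$).

```python
import math


def total_curvature(f, ground):
    """kappa = 1 - min_v f(v | V - v) / f(v), where f(v | A) = f(A + v) - f(A).

    Assumes f is a polymatroid function with f({v}) > 0 for every v."""
    V = frozenset(ground)
    return 1.0 - min((f(V) - f(V - {v})) / f(frozenset({v})) for v in V)


def curvature_bound(kappa):
    """Greedy guarantee (1/kappa)(1 - exp(-kappa)) of [CC84]; its kappa -> 0 limit is 1."""
    return 1.0 if kappa == 0 else (1.0 - math.exp(-kappa)) / kappa


def bp_bound(kappa_f, kappa_g):
    """(1/kappa_f)(1 - exp(-kappa_f (1 - kappa_g))); its kappa_f -> 0 limit is 1 - kappa_g."""
    if kappa_f == 0:
        return 1.0 - kappa_g
    return (1.0 - math.exp(-kappa_f * (1.0 - kappa_g))) / kappa_f


ground = range(4)
examples = [("modular", len),
            ("sqrt", lambda A: math.sqrt(len(A))),
            ("min(|A|,1)", lambda A: min(len(A), 1))]
for name, f in examples:
    k = total_curvature(f, ground)
    print(f"{name}: kappa={k:.3f}, bound={curvature_bound(k):.3f}")
```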


6.11.4.2 Discrete Constraints 


There are many other types of constraints one might desire besides a cardinality limitation. The next
simplest constraint allows each element $v$ to have a non-negative cost, say $m(v) \in \mathbb{R}_+$. In fact, this
means that the costs are modular, i.e., the cost of any set $X$ is $m(X) = \sum_{v \in X} m(v)$. A submodular
maximization problem subject to a knapsack constraint then takes the form $\max_{X \subseteq V : m(X) \leq b} f(X)$,
where $b$ is a non-negative budget. While the greedy algorithm does not solve this problem directly, a
slightly modified cost-scaled version of the greedy algorithm [Svi04] does solve this problem (with the
same $1 - 1/e$ approximation guarantee) for any set of knapsack costs. This has been used for various
multi-document summarization tasks [LB11; LB12].
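For intuition, here is a simplified cost-scaled greedy. It is an assumption-labeled sketch and not the exact procedure of [Svi04], which additionally enumerates small seed sets to obtain its guarantee. This version ranks elements by gain per unit cost, keeps an element only if it still fits in the budget, and finally compares against the best affordable single element. The coverage instances and helper names are hypothetical.

```python
def coverage(sets):
    """Set cover function f(X) = |union_{v in X} sets[v]|."""
    return lambda X: len(set().union(*(sets[v] for v in X)))


def cost_scaled_greedy(f, cost, budget):
    """Simplified cost-scaled greedy for max f(X) s.t. sum_{v in X} cost[v] <= budget.

    Assumes f is a polymatroid function and every cost is positive."""
    X, spent = set(), 0.0
    remaining = sorted(cost)  # deterministic tie-breaking
    while remaining:
        fX = f(X)
        v = max(remaining, key=lambda u: (f(X | {u}) - fX) / cost[u])  # gain per unit cost
        remaining.remove(v)
        if spent + cost[v] <= budget:  # keep v only if it still fits
            X.add(v)
            spent += cost[v]
    affordable = [v for v in sorted(cost) if cost[v] <= budget]
    best_single = max(affordable, key=lambda v: f({v}), default=None)
    if best_single is not None and f({best_single}) > f(X):
        return {best_single}
    return X


sets1 = {"a": {1, 2, 3, 4}, "b": {1, 2}, "c": {3}, "d": {5, 6}}
cost1 = {"a": 4.0, "b": 1.0, "c": 1.0, "d": 2.0}
X1 = cost_scaled_greedy(coverage(sets1), cost1, budget=4.0)

# A case where the ratio rule alone is poor and the best single element wins.
sets2 = {"x": {1}, "y": set(range(2, 11))}
cost2 = {"x": 1.0, "y": 10.0}
X2 = cost_scaled_greedy(coverage(sets2), cost2, budget=10.0)
print(sorted(X1), sorted(X2))  # ['b', 'c', 'd'] ['y']
```

The second instance shows why the singleton comparison matters: the cheap element has the best ratio but then blocks the much more valuable one.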

There is no single direct analogue of a convex set when one is optimizing over subsets of the set $V$,
but there are a few forms of discrete constraints that are both mathematically interesting and that
often occur repeatedly in applications.

The first form is the independent sets of a matroid. The independent sets of a matroid
are useful to represent a constraint set for submodular maximization [Cal+07; LSV09; Lee+10],
$\max_{X \in \mathcal{I}} f(X)$, and this can be useful in many ways. We can see this by showing a simple example
of what is known as a partition matroid. Consider a partition $\mathcal{V} = \{V_1, V_2, \ldots, V_m\}$ of $V$ into $m$
mutually disjoint subsets that we call blocks. Suppose also that for each of the $m$ blocks, there is a
positive integer limit $\ell_i$ for $i \in [m]$. Consider next the set of sets formed by taking all subsets of $V$
such that each subset has an intersection with $V_i$ of size no more than $\ell_i$ for each $i$, i.e., consider

$$\mathcal{I}_p = \{X : \forall i \in [m], |V_i \cap X| \leq \ell_i\}. \qquad (6.271)$$

Then $(V, \mathcal{I}_p)$ is a matroid. The corresponding submodular maximization problem is a natural
generalization of the cardinality constraint in that, rather than having a fixed number of elements
beyond which we are uninterested, the set of elements $V$ is organized into groups, and here we have
a fixed per-group limit beyond which we are uninterested. This is useful for fairness applications,
since the solution must be distributed over the blocks of the matroid. Still, there are many much
more powerful types of matroids that one can use [Oxl11; GM12].

Regardless of the matroid, the problem $\max_{X \in \mathcal{I}} f(X)$ can be approximately solved, with a 1/2
approximation factor, using the same greedy algorithm as above [NWF78; FNW78]. Indeed, the greedy
algorithm has an intimate relationship with submodularity, a fact that is well studied in some of the
seminal works on submodularity [Edm70; Lov83; Sch04]. It is also possible to define constraints consisting
of an intersection of matroids, meaning that the solution must be simultaneously independent in
multiple distinct matroids. Adding on to this, we might wish a set to be independent in multiple
matroids and also satisfy a knapsack constraint. Knapsack constraints are not matroid constraints,
since the maximal feasible solutions of a knapsack constraint need not all have the same size (whereas
all maximal independent sets of a matroid must). It is also possible to define discrete constraints using
level sets of another completely different submodular function [IB13]: given two submodular functions
$f$ and $g$, this leads to optimization problems of the form $\max_{X \subseteq V : g(X) \leq \alpha} f(X)$ (the submodular cost
submodular knapsack, or SCSK, problem) and $\min_{X \subseteq V : g(X) \geq \alpha} f(X)$ (the submodular cost submodular
cover, or SCSC, problem). Other examples include covering constraints [IN09] and cut constraints [JB16].
Indeed, the range of constraints on submodular maximization for which good and scalable algorithms
exist is quite vast, and still growing.
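A partition matroid gives a simple independence oracle, and the greedy algorithm for $\max_{X \in \mathcal{I}} f(X)$ only adds elements that keep the current set independent. The grouping, limits, weights, and helper names below are hypothetical. The example shows how the per-group limit forces the solution to spread over both blocks even though block A holds the heaviest elements.

```python
from math import sqrt


def partition_matroid(block_of, limits):
    """Independence oracle for I_p = {X : |V_i ∩ X| <= l_i for all i} (Equation 6.271)."""
    def independent(X):
        counts = {}
        for v in X:
            counts[block_of[v]] = counts.get(block_of[v], 0) + 1
        return all(cnt <= limits[b] for b, cnt in counts.items())
    return independent


def matroid_greedy(f, ground, independent):
    """Greedy for max_{X in I} f(X): repeatedly add the best element that keeps
    X independent, stopping when no element can be added."""
    X = set()
    while True:
        feasible = [v for v in ground if v not in X and independent(X | {v})]
        if not feasible:
            return X
        fX = f(X)
        X.add(max(feasible, key=lambda v: f(X | {v}) - fX))


weights = {"a1": 5.0, "a2": 4.0, "a3": 3.0, "b1": 1.0, "b2": 0.5}
block_of = {"a1": "A", "a2": "A", "a3": "A", "b1": "B", "b2": "B"}
limits = {"A": 1, "B": 2}  # at most one pick from block A, two from block B
f = lambda X: sqrt(sum(weights[v] for v in X))  # concave of modular: submodular
independent = partition_matroid(block_of, limits)

X = matroid_greedy(f, sorted(weights), independent)
print(sorted(X))  # ['a1', 'b1', 'b2']
```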


One last note on submodular maximization. In the above, the function $f$ has been assumed to be
a polymatroid function. There are many submodular functions that are not monotone [Buc+12].
We saw one example before, namely the graph cut function. Another example is the log of the
determinant (log-determinant) of a submatrix of a positive-definite matrix (which is the Gaussian
entropy plus a constant). Suppose that $M$ is an $n \times n$ symmetric positive-definite (SPD) matrix,
and that $M_X$ is a row-column submatrix (i.e., it is an $|X| \times |X|$ matrix consisting of the rows and
columns of $M$ indexed by the elements in $X$). Then the function defined as $f(X) = \log\det(M_X)$
is submodular but not necessarily monotone non-decreasing. In fact, the submodularity of the
log-determinant function is one of the reasons that determinantal point processes (DPPs), which
instantiate probability distributions over sets in such a way that high probability is given to those
subsets that are diverse according to $M$, are useful for certain tasks where we wish to probabilistically
model diversity [KT11]. Diversity of a set $X$ here is measured by the (squared) volume of the
parallelepiped spanned by the vectors associated with the elements of $X$, which is computed as the
determinant of the submatrix $M_X$, and taking the log of this volume makes the function submodular
in $X$. A DPP, in fact, is an example of a log-submodular probabilistic model (more in Section 6.11.10).
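Both claims about the log-determinant can be checked numerically. This sketch computes $\log\det(M_X)$ with a hand-written Cholesky factorization (to stay dependency-free), builds a hypothetical random SPD matrix $M = BB^\top + 0.1 I$, and verifies Definition 6.11.1 over all pairs of subsets. A small hypothetical $2 \times 2$ matrix then shows non-monotonicity.

```python
import math
import random
from itertools import combinations


def logdet_sub(M, X):
    """log det(M_X) for the principal submatrix indexed by X, via Cholesky.

    Assumes M is symmetric positive definite; log det of the empty matrix is 0."""
    idx = sorted(X)
    k = len(idx)
    L = [[0.0] * k for _ in range(k)]
    logdet = 0.0
    for i in range(k):
        for j in range(i + 1):
            s = M[idx[i]][idx[j]] - sum(L[i][t] * L[j][t] for t in range(j))
            if i == j:
                L[i][i] = math.sqrt(s)  # s > 0 because M_X is SPD
                logdet += 2.0 * math.log(L[i][i])
            else:
                L[i][j] = s / L[j][j]
    return logdet


random.seed(0)
n = 4
B = [[random.gauss(0.0, 1.0) for _ in range(n)] for _ in range(n)]
M = [[sum(B[i][t] * B[j][t] for t in range(n)) + (0.1 if i == j else 0.0)
      for j in range(n)] for i in range(n)]

S = [frozenset(c) for r in range(n + 1) for c in combinations(range(n), r)]
f = lambda X: logdet_sub(M, X)
print(all(f(X) + f(Y) >= f(X | Y) + f(X & Y) - 1e-9 for X in S for Y in S))  # True

# Not monotone: an element with a small diagonal entry lowers log det below f(empty) = 0.
M2 = [[0.5, 0.1], [0.1, 2.0]]
print(logdet_sub(M2, {0}) < logdet_sub(M2, set()))  # True, since log 0.5 < 0
```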


6.11.4.3 Submodular Function Minimization 


In the case of a polymatroid function, unconstrained minimization is again trivial. However, even in
the unconstrained case, the minimization of an arbitrary (i.e., not necessarily monotone) submodular
function, $\min_{X \subseteq V} f(X)$, might seem hopelessly intractable. Unconstrained submodular maximization
is NP-hard (albeit approximable), and this is not surprising given that there are an exponential
number of sets needing to be considered. Remarkably, submodular minimization does not require
exponential computation, is not NP-hard, and in fact, there are polynomial time algorithms for
doing so, something that is not at all obvious. This is one of the important characteristics that
submodular functions share with convex functions: their common amenability to minimization.
Starting in the very late 1960s, and spearheaded by individuals such as Jack Edmonds [Edm70],
there was a concerted effort in the discrete mathematics community in search of either an algorithm
that could minimize a submodular function in polynomial time or a proof that such a problem was
NP-hard. The nut was finally cracked in a classic paper [GLS81] on the ellipsoid algorithm that gave
a polynomial time algorithm for submodular function minimization (SFM). While the algorithm was
polynomial, it was a continuous algorithm, and it was not practical, so the search continued for a
purely combinatorial strongly polynomial time algorithm. Queyranne [Que98] then proved that an
algorithm [NI92] worked for this problem when the set function also satisfies a symmetry condition
(i.e., $\forall X \subseteq V$, $f(X) = f(V \setminus X)$), which only requires $O(n^3)$ time. The general result finally came
around the year 2000 using two mostly independent methods [IFF00; Sch00]. These algorithms, however,
were also impractical in that, while they are polynomial time, they have unrealistically high polynomial
degree (i.e., $O(n^7 \gamma + n^8)$ for [Sch00] and $O(n^7 \gamma)$ for [IFF00], where $\gamma$ is the time for one function
evaluation). This led to additional work on combinatorial algorithms for SFM, leading to algorithms that
could perform SFM in time $O(n^5 \gamma + n^6)$ in [Orl09]. Two practical algorithms for SFM include the
Fujishige-Wolfe procedure [Fuj05; Wol76]⁷ as well as the Frank-Wolfe procedure, each of which minimizes
the 2-norm on a polyhedron $B_f$ associated with the submodular function $f$, which is defined below
(it should also be noted that the Frank-Wolfe algorithm can also be used to minimize the convex
extension of the function, something that is relatively easy to compute via the Lovász extension [Lov83]).
More recent work on SFM is also based on continuous relaxations of the problem in some form or another,
leading to algorithms with strongly polynomial running time [LSW15] of $O(n^3 \log^2 n)$, for which it was
possible to drop the log factors, leading to a complexity of $O(n^3)$ in [Jia21]; weakly polynomial running
time [LSW15] of $O(n^2 \log M)$ (where $M \geq \max_{S \subseteq V} |f(S)|$); pseudo-polynomial running time [ALS20; Cha+17]
of $O(nM^2)$; and an $\epsilon$-approximate minimization with a linear running time [ALS20] of $O(n/\epsilon^2)$. There
have been other efforts to utilize parallelism to further improve SFM [BS20].
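Since the Lovász extension was mentioned above as the route to continuous relaxations of SFM, here is a minimal sketch of it. It assumes $f(\emptyset) = 0$ and uses a hypothetical graph-cut-plus-modular function; the helper names are illustrative. It checks that the extension agrees with $f$ on indicator vectors $\mathbf{1}_X$. It also checks that thresholding a fractional point $x \in [0, 1]^n$ at its own values (including the empty set) always yields a set whose value is no larger than the extension at $x$. This rounding step is what lets a continuous minimizer be turned back into a set, and the extension is convex exactly when $f$ is submodular.

```python
def lovasz_extension(f, x):
    """Visit coordinates in decreasing order of x and weight each marginal gain of
    the growing chain S_1 ⊂ S_2 ⊂ ... by the corresponding x value.
    Assumes f(frozenset()) == 0."""
    order = sorted(range(len(x)), key=lambda i: -x[i])
    value, S, prev = 0.0, set(), 0.0
    for i in order:
        S.add(i)
        fS = f(frozenset(S))
        value += x[i] * (fS - prev)
        prev = fS
    return value


def best_level_set(f, x):
    """Threshold x at each of its values and return the best level set (incl. empty)."""
    order = sorted(range(len(x)), key=lambda i: -x[i])
    return min((frozenset(order[:k]) for k in range(len(x) + 1)), key=f)


# Hypothetical submodular function: a graph cut plus a modular term.
W = [[0, 1, 0, 2], [1, 0, 3, 0], [0, 3, 0, 1], [2, 0, 1, 0]]
c = [-3.0, 1.0, -2.0, 0.5]


def f(X):
    cut = sum(W[i][j] for i in X for j in range(4) if j not in X)
    return cut + sum(c[i] for i in X)


# On indicator vectors 1_X, the extension agrees with f itself.
for mask in range(16):
    X = frozenset(i for i in range(4) if mask >> i & 1)
    x = [1.0 if i in X else 0.0 for i in range(4)]
    assert abs(lovasz_extension(f, x) - f(X)) < 1e-12
```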


7. This is the same Wolfe as the Wolfe in Frank-Wolfe but not the same algorithm. 
6.11.5 Applications of Submodularity in Machine Learning and AI


Submodularity arises naturally in applications in machine learning and artificial intelligence, but its
utility has not yet been as widely recognized and exploited as that of other techniques. For example,
while information-theoretic concepts like entropy and mutual information are extremely widely used
in machine learning (e.g., the cross-entropy loss for classification is ubiquitous), the submodularity
property of entropy is not nearly as widely explored.

Still, over the last several decades, submodularity has been increasingly studied and utilized in the
context of machine learning. Below, we provide only a brief survey of some of the major
sub-areas within machine learning that have been touched by submodularity. The list is not meant
to be exhaustive, or even extensive. It is hoped that the following offers, at least, a reasonable
introduction to how submodularity has been and can continue to be useful in machine learning and
artificial intelligence.


6.11.6 Sketching, CoreSets, Distillation, and Data Subset & Feature Selection 


A summary is a concise representation of a body of data that can be used as an effective and efficient 
substitute for that data. There are many types of summary and some are extremely simple. For 
example, the mean or median of a list of numbers summarizes some property (the central tendency) 
of that list. A random subset is also a form of summary. 

Any given summary, however, is not guaranteed to do a good job serving all purposes. Moreover, 
a summary usually involves at least some degree of approximation and fidelity loss relative to the 
original, and different summaries are faithful to the original in different ways and for different tasks. 
For these and other reasons, the field of summarization is rich and diverse, and summarization 
procedures are often very specialized. 

Several distinct names for summarization have been used over the past few decades, including
“sketches”, “coresets”, “summaries” (in the field of natural language processing), and “distillation.”


Sketches [Cor17; CY20; Cor+12] arose in the field of computer science and were based on the
acknowledgment that data is often too large to fit in memory and too large for an algorithm to run
on a given machine; a much smaller but still representative, and provably approximate, representation
of the data makes such computation possible.


Coresets are similar to sketches, and although some properties are more often associated with
coresets than with sketches, the distinction is sometimes a bit vague. The notion of a coreset
[BHPI02; AHP+05; BC08] comes from computational geometry, where one is interested in solving
certain geometric problems based on a set of points in $\mathbb{R}^d$. For a given geometric problem
and set of points, a coreset problem typically involves finding the smallest weighted subset of points
such that an algorithm run on the weighted subset produces approximately the same answer as when it
is run on the original large dataset. For example, given a set of points, one might wish to find the
diameter of the set, the radius of the smallest enclosing sphere, the narrowest annulus (ring)
containing the points, or a subset of points whose k-center clustering is approximately the same as
the k-center clustering of the whole [BHPI02].
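To make the k-center example concrete, the sketch below uses the classical farthest-first traversal (often attributed to Gonzalez), which picks $k$ of the points as centers and is known to achieve a k-center radius within a factor of 2 of the optimal one. It only illustrates the flavor of a subset-based approximation guarantee; it is not the coreset construction of [BHPI02], and the point set is made up.

```python
import itertools
import math


def k_center_radius(points, centers):
    """Largest distance from any point to its nearest center."""
    return max(min(math.dist(p, c) for c in centers) for p in points)


def farthest_first(points, k, start=0):
    """Greedily add the point that is farthest from the centers chosen so far."""
    centers = [points[start]]
    while len(centers) < k:
        farthest = max(points, key=lambda p: min(math.dist(p, c) for c in centers))
        centers.append(farthest)
    return centers


# Hypothetical data: three well-separated groups in R^2.
points = [(0, 0), (1, 0), (0, 1), (10, 10), (11, 10), (10, 11), (20, 0), (21, 1)]
k = 3
greedy = farthest_first(points, k)
# Brute-force optimum over all k-subsets of the points, for comparison only.
optimum = min(k_center_radius(points, c) for c in itertools.combinations(points, k))
print(greedy, k_center_radius(points, greedy), optimum)
```

The brute-force comparison is exponential in $k$ and is there only to check the factor-of-2 behavior on this toy input.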


Document summarization became one of the most important problems in natural language processing
(NLP) in the 1990s, although the idea of computing a summary of a text goes back much further, to
the 1950s [Luh58; Edm69], coincidentally around the same time that the CliffsNotes [Wik21]
organization began. There are two main forms of document summarization [YWX17]. With extractive
summarization [NM12], a set of sentences (or phrases) is extracted from the documents being
summarized, and the resulting subset of sentences, perhaps appropriately ordered, comprises the
summary.

With abstractive summarization [LN19], on the other hand, the goal is to produce an “abstract” of 
the documents, where one is not constrained to have any of the sentences in the abstract correspond 
to any of the sentences in the original documents. With abstractive summarization, therefore, the 
goal is to synthesize a small set of new pseudo sentences that represent the original documents. 
CliffsNotes, for example, are abstractive summaries of the literature being represented. 

Another form of summarization that has more recently become popular in the machine learning
community is data distillation [SG06b; Wan+20c; Suc+20; BYH20; NCL20; SS21; Ngu+21], or
equivalently dataset condensation [ZMB21; ZB21]. With data distillation,⁸ the goal is to produce a
small set of synthetic pseudo-samples that can be used, for example, to train a model. The key here
is that the samples in the reduced dataset need not be the same as, or a subset of, the samples in the
original dataset.

All of the above should be contrasted with data compression, which in some sense is the most
extreme data reduction method. With compression, either lossless or lossy, there is no longer any
obligation that the reduced form of the data be usable by, or recognizable to, any algorithm or entity
other than the decoder (decompression) algorithm.


6.11.6.1 Summarization Algorithm Design Choices 


It is the author’s contention that the notions of summarization, coresets, sketching, and distillation 
are certainly analogous and quite possibly synonymous, and they are all different from compression. 
The different names for summarization are simply different nomenclatures for the same language 
game. What matters is not what you call it but the choices one makes when designing a procedure 
for summarization. And indeed, there are many choices. 

Submodularity offers an essentially infinite number of ways to perform data sketching and coresets.
When we view the submodular function as an information function (as we discuss in Section 6.11.7),
where $f(X)$ is the information contained in set $X$ and $f(V)$ is the maximum available information,
finding the small $X$ that maximizes $f(X)$ (i.e., $X^* \in \operatorname{argmax}\{f(X) : |X| \le k\}$) is a form
of coreset computation that is parameterized by the function $f$, which has $2^n$ parameters since $f$
lives in a $2^n$-dimensional cone. Performing this maximization will then minimize the residual
information $f(V \setminus X \mid X)$ about anything not present in the summary, i.e., in $V \setminus X$:
since $f(V) = f(X \cup (V \setminus X)) = f(V \setminus X \mid X) + f(X)$ and $f(V)$ is fixed, maximizing
$f(X)$ minimizes $f(V \setminus X \mid X)$. Moreover, for every $f$ the same algorithm (e.g., the greedy
algorithm) can be used to produce the summary, and in every case there is an approximation
guarantee relative to the current $f$, as mentioned in earlier sections, as long as $f$ stays submodular.
Hence, submodularity provides a universal framework for summarization, coresets, and sketches, to
the extent that the space of submodular functions is itself sufficiently diverse and spans different
coreset problems.
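A small worked instance of this view, assuming a set-cover (coverage) function as the information measure $f$ (one simple polymatroid; the items and concepts below are hypothetical): the standard greedy algorithm builds $X$, and the residual $f(V \setminus X \mid X)$ is exactly what $f(X)$ fails to capture of $f(V)$.

```python
def coverage(concepts, X):
    """f(X): number of distinct concepts covered by the items in X (a polymatroid)."""
    return len(set().union(*(concepts[i] for i in X)))


def greedy_max(f, V, k):
    """Standard greedy for max f(X) s.t. |X| <= k: repeatedly add the largest-gain item."""
    X = []
    for _ in range(k):
        best = max((v for v in V if v not in X), key=lambda v: f(X + [v]) - f(X))
        X.append(best)
    return X


# Hypothetical items, each "covering" a few concepts.
concepts = {0: {"a", "b", "c"}, 1: {"c", "d"}, 2: {"d", "e", "f"}, 3: {"a", "f"}, 4: {"g"}}
V = list(concepts)


def f(X):
    return coverage(concepts, X)


X = greedy_max(f, V, k=2)
rest = [v for v in V if v not in X]
residual = f(X + rest) - f(X)  # f(V \ X | X): information not captured by the summary
print(X, f(X), residual, f(V))
```

Swapping in any other monotone submodular `f` leaves `greedy_max` unchanged, which is the point made in the paragraph above.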

Overall, the coreset or sketching problem, when using submodular functions, therefore becomes
a problem of “submodular design.” That is, how do we construct a submodular function that, for a
particular problem, acts as a good coreset producer when the function is maximized? There are three
general approaches to producing an $f$ that works well as a summarization objective: (1) a pragmatic
approach, where the function is constructed by hand using heuristics; (2) a learning approach, where
all or part of the submodular function is inferred by an optimization procedure; and (3) a mathematical
approach, where a given submodular function, when optimized, offers a coreset property.

8. Data distillation is distinct from the notion of knowledge distillation [HVD14; BC14; BCNM06] or model distillation,
where the “knowledge” contained in a large model is distilled or reduced down into a different smaller model.

When the primary goal is a practical and scalable algorithm that can produce an extractive
summary that works well on a variety of data types, and if one is comfortable with heuristics
that work well in practice, a good option is to specify a submodular function by hand. For example,
given a similarity matrix, it is easy to instantiate a facility location function and maximize it to
produce a summary. If there are multiple similarity matrices, one can construct multiple facility
location functions and maximize their convex combination. Such an approach is viable and practical
and has been used successfully many times in the past for producing good summaries. One of the
earliest examples of this is the algorithm presented in [KKT03], which shows how a submodular model
can be used to select the most influential nodes in a social network. Perhaps the earliest example of
this approach used for data subset selection for machine learning is [LB09], which utilizes a submodular
facility location function based on Fisher kernels (gradients w.r.t. parameters of log probabilities)
and applies it to unsupervised speech selection to reduce transcription costs. Other examples of
this approach include: [LB10a; LB11], which developed submodular functions for query-focused
document summarization; [KB14b], which computes a subset of training data in the context of
transductive learning in a statistical machine translation system; [LB10b; Wei+13; Wei+14], which
develop submodular functions for speech data subset selection (the former, incidentally, is the first
use of a deep submodular function, and the latter does this in an unsupervised, label-free fashion);
[SS18a], which is a form of robust submodularity for producing coresets for training CNNs; [Kau+19],
which uses a facility location function to facilitate diversity selection in active learning; [Bai+15; CTN17],
which develop a mixture of submodular functions for document summarization where the mixture
coefficients are also included in the hyperparameter set; and [Xu+15], which uses a symmetrized
submodular function for video summarization.
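The hand-specified recipe above (facility location functions on several similarity matrices, combined convexly and maximized greedily) fits in a few lines. A minimal sketch, assuming nonnegative similarities and two made-up similarity "views" of four items; the brute-force optimum is computed only to compare against the classical $(1 - 1/e)$ greedy guarantee for monotone submodular maximization.

```python
import itertools
import math


def facility_location(S, A):
    """f(A) = sum over rows v of max_{a in A} S[v][a]; f(empty) = 0 for nonnegative S."""
    return sum(max((row[a] for a in A), default=0.0) for row in S)


def mixture(Ss, weights):
    """Convex combination of facility location functions, one per similarity matrix."""
    def f(A):
        return sum(w * facility_location(S, A) for w, S in zip(weights, Ss))
    return f


def greedy(f, n, k):
    """Greedy maximization of f subject to |A| <= k."""
    A = []
    for _ in range(k):
        candidates = [v for v in range(n) if v not in A]
        A.append(max(candidates, key=lambda v: f(A + [v]) - f(A)))
    return A


# Two hypothetical similarity "views" of the same four items.
S1 = [[1.0, 0.9, 0.1, 0.1], [0.9, 1.0, 0.1, 0.1], [0.1, 0.1, 1.0, 0.8], [0.1, 0.1, 0.8, 1.0]]
S2 = [[1.0, 0.2, 0.2, 0.7], [0.2, 1.0, 0.6, 0.2], [0.2, 0.6, 1.0, 0.2], [0.7, 0.2, 0.2, 1.0]]
f = mixture([S1, S2], [0.5, 0.5])
A = greedy(f, n=4, k=2)
best = max(f(list(c)) for c in itertools.combinations(range(4), 2))
guarantee = (1 - 1 / math.e) * best
print(A, f(A), best, guarantee)
```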

The learnability and identifiability of submodular functions have received a good amount of study
from a theoretical perspective. In the strictest learning settings, the problem looks rather dire. For
example, [SF08; Goe+09] show that if one is restricted to making a polynomial number of queries
(i.e., training pairs of the form $(S, f(S))$) of a monotone submodular function, then it is not possible
to approximate $f$ everywhere with a multiplicative approximation factor better than
$\Omega(\sqrt{n}/\log n)$. In [BH11], goodness is judged multiplicatively, meaning that for a set
$A \subseteq V$ we wish that $\hat{f}(A) \le f(A) \le g(n)\hat{f}(A)$ for some function $g(n)$, and this is
typically a probabilistic condition (i.e., measured with respect to a distribution:
$\hat{f}(A) \le f(A) \le g(n)\hat{f}(A)$ should hold on at least a $1 - \beta$ fraction of the points).
Alternatively, goodness may be measured by an additive approximation error, say by a norm. That is,
defining $\mathrm{err}_p(f, \hat{f}) = \|f - \hat{f}\|_p = (\mathbb{E}_{A \sim \Pr}[|f(A) - \hat{f}(A)|^p])^{1/p}$, we may
wish that $\mathrm{err}_p(f, \hat{f}) \le \epsilon$ for $p = 1$ or $p = 2$. In the PAC (probably approximately
correct) model, we probably ($\delta > 0$) approximately ($\epsilon > 0$ or $g(n) > 1$) learn ($\beta = 0$)
with a sample or algorithmic complexity that depends on $\delta$ and $g(n)$. In the PMAC (probably
mostly approximately correct) model [BH11], we also “mostly” ($\beta > 0$) learn. In some cases we wish
to learn the best submodular approximation to a non-submodular function. In other cases, we are
allowed to deviate from submodularity as long as the error is small. Special cases that have been
studied include coverage functions [FK14; FK13a], low-degree polynomials [FV15], curvature-limited
functions [IJB13], functions with a limited “goal” [DHK14; Bac+18], functions that are Fourier sparse
[Wen+20a], functions from a family called “juntas” [FV16], functions from families other than
submodular [DFF21], and still others [BRS17; FKV14; FKV17; FKV20; FKV13; YZ19]. Other results
show that one cannot minimize a submodular function by first learning it from samples [BS17]. The
essential strategy of learning is to construct an approximation $\hat{f}$ of an underlying submodular
function $f$ by querying the latter only a small number of times. The overall gist of these results is
that it is hard to learn everywhere and accurately.

In the machine learning community, learning can be performed extremely efficiently in practice,
although without the types of guarantees found above. For example, given a mixture of submodular
components of the form $f(A) = \sum_i \alpha_i f_i(A)$, if each $f_i$ is considered fixed, then learning
occurs only over the mixture coefficients $\alpha_i$ (which must stay nonnegative for $f$ to remain
submodular). This can be solved as a linear regression problem. Alternatively, such functions can be
learnt in a max-margin setting where the goal is primarily to adjust $\alpha_i$ to ensure that $f(A)$ is
large on certain subsets [SSJ12; LB12; Tsc+14]. Even here there are practical challenges, however,
since in general it is hard in practice to obtain a training set of pairs $\{(S_i, f(S_i))\}_i$. Alternatively,
one can also “learn” a submodular function in a reinforcement learning setting [CKK17] by optimizing
the implicit function directly from gain vectors queried from an environment. In general, such
practical learning algorithms have been used for image summarization [Tsc+14], document
summarization [LB12], and video summarization [GGG15; Vas+17a; Gon+14; SGS16; SLG17]. While
none of these learning approaches claims to approximate some true underlying submodular function,
in practice they do perform better than the hand-crafted submodular functions mentioned
above.
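As a minimal sketch of the regression view, assume the component functions are fixed and that we are handed exact pairs $(S, f(S))$ for every subset of a tiny ground set (an idealization; as noted above, such data are hard to obtain). The mixture weights then follow from ordinary least squares via the normal equations. The three components and the "true" weights are hypothetical, and clipping negative weights to zero is a crude stand-in for a proper nonnegative least-squares solver.

```python
import itertools
import math


def solve(M, b):
    """Solve M w = b by Gauss-Jordan elimination with partial pivoting."""
    n = len(b)
    aug = [list(row) + [bi] for row, bi in zip(M, b)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(aug[r][col]))
        aug[col], aug[piv] = aug[piv], aug[col]
        for r in range(n):
            if r != col:
                factor = aug[r][col] / aug[col][col]
                aug[r] = [x - factor * y for x, y in zip(aug[r], aug[col])]
    return [aug[i][n] / aug[i][i] for i in range(n)]


def fit_mixture(components, pairs):
    """Least-squares fit of f(S) = sum_i alpha_i f_i(S) from (S, f(S)) pairs.

    Negative weights are clipped to zero so the fitted f stays submodular (a heuristic)."""
    X = [[fi(S) for fi in components] for S, _ in pairs]
    y = [t for _, t in pairs]
    m = len(components)
    gram = [[sum(row[i] * row[j] for row in X) for j in range(m)] for i in range(m)]
    rhs = [sum(row[i] * t for row, t in zip(X, y)) for i in range(m)]
    return [max(0.0, w) for w in solve(gram, rhs)]


# Hypothetical fixed components on a 4-item ground set.
concepts = {0: {"a", "b"}, 1: {"b", "c"}, 2: {"d"}, 3: {"a", "d", "e"}}
components = [
    lambda S: math.sqrt(len(S)),                            # concave function of cardinality
    lambda S: len(set().union(*(concepts[v] for v in S))),  # set cover
    lambda S: min(len(S), 2),                               # truncated cardinality
]

# Synthetic training pairs from an assumed "true" mixture, using every subset.
true_alpha = [0.5, 2.0, 1.0]
pairs = [
    (S, sum(a * fi(S) for a, fi in zip(true_alpha, components)))
    for r in range(5)
    for S in itertools.combinations(range(4), r)
]
alpha = fit_mixture(components, pairs)
print(alpha)
```

With noise-free targets and linearly independent components, the fit recovers the assumed weights up to floating-point error.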

By a submodularity based coreset, we mean one where the direct optimization of a submodular 
function offers a theoretical guarantee for some specific problem. This is distinct from above where 
the submodular function is used as a surrogate heuristic objective function and for which, even if the 
submodular function is learnt, optimizing it is only a heuristic for the original problem. In some 
limited cases, it can be shown that the function we wish to approximate is already submodular, 
e.g., in the case of certain naive Bayes and k-NN classifiers [WIB15] where the training accuracy, 
as a function of the training data subset, can be shown to be submodular. Hence, maximizing this 
function offers the same guarantee on the training accuracy as it does on the submodular function. 
Unfortunately, the accuracy function for many models is not submodular, although it does have a
difference of submodular [NB05; IB12] decomposition.

In other cases, it can be shown that certain desirable coreset objectives are inherently submodular. 
For example, in [MBL20], it is shown that the normed difference between the overall gradient 
(from summing over all samples in the training data) and an approximate gradient (from summing 
over only samples in a summary) can be upper bounded with a supermodular function that, when 
converted to a submodular facility location function and maximized, will select a set that reduces 
this difference, and will lead to similar convergence rates to an approximate optimum solution in the 
convex case. A similar example of this in a DPP context is shown in [TBA19]. In other cases, selection
of a subset of the training data and training occur simultaneously using a continuous-discrete optimization
framework, where the goal is to minimize the loss on diverse and challenging samples measured by a 
submodular objective [ZB18]. In still other cases, bi-level objectives related to but not guaranteed to 
be submodular can be formed where a set is selected from a training set with the deliberate purpose 
of doing well on a validation set [Kil+20; BMK20]. 
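The gradient-matching idea attributed above to [MBL20] can be illustrated with a toy. This is a simplified sketch, not the authors' algorithm: we assume per-sample gradients are given (made-up 2-D vectors), define similarity as a constant minus Euclidean distance, select samples by greedy facility location, and weight each selected sample by how many samples are closest to it. By the triangle inequality, the error between the full gradient sum and the weighted subset sum is at most the sum of distances from each sample to its nearest selected sample, which is exactly the quantity that maximizing this facility location objective drives down.

```python
import math


def gradient_coreset(grads, k):
    """Greedy facility location on gradient similarities; weights = cluster sizes."""
    n = len(grads)
    dist = [[math.dist(gi, gj) for gj in grads] for gi in grads]
    big = max(max(row) for row in dist)  # shift so that similarities are nonnegative
    sim = [[big - d for d in row] for row in dist]

    def fl(A):
        return sum(max((sim[i][a] for a in A), default=0.0) for i in range(n))

    S = []
    for _ in range(k):
        S.append(max((j for j in range(n) if j not in S), key=lambda j: fl(S + [j]) - fl(S)))
    nearest = [min(S, key=lambda j: dist[i][j]) for i in range(n)]
    weights = {j: nearest.count(j) for j in S}
    return S, weights, dist


# Hypothetical per-sample gradients of a 2-parameter model.
grads = [(1.0, 0.0), (1.1, 0.1), (0.9, -0.1), (-1.0, 2.0), (-1.2, 2.1), (0.0, -3.0)]
S, weights, dist = gradient_coreset(grads, k=3)
full = [sum(g[c] for g in grads) for c in range(2)]
approx = [sum(weights[j] * grads[j][c] for j in S) for c in range(2)]
error = math.dist(full, approx)
# Triangle inequality: error <= sum of distances to the nearest selected gradient.
bound = sum(min(dist[i][j] for j in S) for i in range(len(grads)))
print(S, weights, error, bound)
```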

The methods above have focused on reducing the number of samples in a training set. Considering
the transpose of a design matrix, however, all of the above methods can be used for reducing the
features of a machine learning procedure as well. Specifically, any of the extractive summarization,
subset selection, or coreset methods can be seen as feature selection, while any of the abstractive
summarization, sketching, or distillation approaches can be seen as dimensionality reduction.


6.11.7 Combinatorial Information Functions


The entropy function over a set of random variables $X_1, X_2, \dots, X_n$ is defined as
$H(X_1, X_2, \dots, X_n) = -\sum_{x_1, \dots, x_n} p(x_1, \dots, x_n) \log p(x_1, \dots, x_n)$. From this we can
define a three-set-argument conditional mutual information function as $I_H(A; B \mid C) = I(X_A; X_B \mid X_C)$,
where the latter is the mutual information between the variables indexed by $A$ and $B$ given the
variables indexed by $C$. This mutual information expresses the residual information between $X_A$
and $X_B$ that is not explained by their common information with $X_C$.
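The set-argument view can be computed directly from a joint distribution using the standard identity $I_H(A; B \mid C) = H(A \cup C) + H(B \cup C) - H(C) - H(A \cup B \cup C)$. The sketch below assumes a toy joint pmf in which $X_2 = X_0 \oplus X_1$ for independent fair bits $X_0, X_1$; it also shows that, for entropy, conditioning on $X_C$ can increase mutual information as well as decrease it.

```python
import math
from collections import defaultdict

# Joint pmf over (X0, X1, X2): X0, X1 independent fair bits and X2 = X0 XOR X1.
pmf = {(x0, x1, x0 ^ x1): 0.25 for x0 in (0, 1) for x1 in (0, 1)}


def H(A):
    """Entropy (in bits) of the variables indexed by the set A."""
    marginal = defaultdict(float)
    for x, p in pmf.items():
        marginal[tuple(x[i] for i in sorted(A))] += p
    return -sum(p * math.log2(p) for p in marginal.values() if p > 0)


def I(A, B, C=()):
    """I_H(A; B | C) = H(A ∪ C) + H(B ∪ C) - H(C) - H(A ∪ B ∪ C)."""
    A, B, C = set(A), set(B), set(C)
    return H(A | C) + H(B | C) - H(C) - H(A | B | C)


print(I({0}, {1}))       # X0 and X1 are independent: 0 bits
print(I({0}, {1}, {2}))  # given X2, each determines the other: 1 bit
```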

As mentioned above, we may view any polymatroid function as a type of information function over
subsets of $V$. That is, $f(A)$ is the information in set $A$; to the extent that this is true, this property
justifies $f$'s use as a summarization objective, as mentioned above. The reason $f$ may be viewed as an
information function stems from $f$ being normalized, nonnegative, and monotone, and from the
property that further conditioning reduces valuation (i.e., $f(A \mid B) \ge f(A \mid B, C)$, which is identical
to the submodularity property). These properties were outlined as being essential to the entropy
function in Shannon's original work [Sha48], but they are true of any polymatroid function as well.
Hence, given any polymatroid function $f$, it is possible to define a combinatorial mutual information
function [Iye+21] in a similar way. Specifically, we can define the combinatorial (submodular) conditional
mutual information (CCMI) as $I_f(A; B \mid C) = f(A \cup C) + f(B \cup C) - f(C) - f(A \cup B \cup C)$, which has
been known as the connectivity function [Cun83], amongst other names. If $f$ is the entropy function,
then this yields the standard entropic mutual information, but here the mutual information can be
defined for any submodular information measure $f$. For an arbitrary polymatroid $f$, therefore,
$I_f(A; B \mid C)$ can be seen as an $A, B$ set-pair similarity score that ignores, neglects, or discounts any
common similarity between the $A, B$ pair that is due to $C$.
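For a set-cover polymatroid the CCMI has an especially transparent meaning: $I_f(A; B \mid C)$ counts the concepts covered by both $A$ and $B$ but not by $C$. A short sketch, with hypothetical documents described by the topics they mention:

```python
def coverage(concepts, A):
    """Set-cover polymatroid: number of distinct concepts covered by the items in A."""
    return len(set().union(*(concepts[v] for v in A)))


def ccmi(f, A, B, C=frozenset()):
    """Combinatorial conditional mutual information I_f(A; B | C)."""
    A, B, C = set(A), set(B), set(C)
    return f(A | C) + f(B | C) - f(C) - f(A | B | C)


# Hypothetical documents described by the topics they mention.
concepts = {
    "doc1": {"python", "jax", "gpu"},
    "doc2": {"jax", "gpu", "tpu"},
    "doc3": {"gpu", "cuda"},
}


def f(S):
    return coverage(concepts, S)


print(ccmi(f, {"doc1"}, {"doc2"}))            # shared topics jax and gpu: 2
print(ccmi(f, {"doc1"}, {"doc2"}, {"doc3"}))  # gpu is already explained by doc3: 1
```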

Historical use of a special case of CCMI, i.e., $I_f(A; B)$ where $C = \emptyset$, occurred in a number of
circumstances. For example, in [GKS05] the function $g(A) = I_f(A; V \setminus A)$ (which, incidentally,
is both symmetric, i.e., $g(A) = g(V \setminus A)$ for all $A$, and submodular) was optimized using the
greedy procedure, which has a guarantee as long as $g(A)$ is monotone up to $2k$ elements whenever
one wishes for a summary of size $k$. This was done for $f$ being the entropy function, but it can be
used for any polymatroid function. In similar work where $f$ is Shannon entropy, [KG05] demonstrated
that $g_C(A) = I_f(A; C)$ for a fixed set $C$ is not, in general, submodular in $A$, but that if the elements
of $V$ are independent given $C$, then submodularity is preserved. This is easily seen as a consequence
of the assumption, which gives $I_f(A; C) = f(A) - f(A \mid C) = f(A) - \sum_{a \in A} f(a \mid C)$, where the
second equality is due to the conditional independence property. In this case, $I_f$ is the difference
between a submodular and a modular function, which preserves submodularity for any polymatroid $f$.


On the other hand, it would be useful for $g_{B,C}(A) = I_f(A; B \mid C)$, where $B$ and $C$ are fixed, to
be possible to optimize in terms of $A$. One can view this function as one that, when maximized,
chooses $A$ to be similar to $B$ in a way that neglects or discounts any common similarity that $A$ and
$B$ have with $C$. One option for optimizing this function is to use difference of submodular [NB05; IB12]
optimization, as mentioned earlier. A more recent result shows that in some cases $g_{B,C}(A)$ is still
submodular in $A$. Define the second-order partial derivative of a submodular function $f$ as
$f(i, j \mid S) = f(j \mid S \cup \{i\}) - f(j \mid S)$. Then if $f(i, j \mid S)$ is monotone non-decreasing in $S$ for
$S \subseteq V \setminus \{i, j\}$, $I_f(A; B \mid C)$ is submodular in $A$ for fixed $B$ and $C$. It may be thought that
only esoteric functions have this property, but in fact [Iye+21] shows that this is true for a number of
submodular functions widely used in practice, including the facility location function, which results
in the form

$$I_f(A; B \mid C) = \sum_{v \in V} \max\Big(\min\big(\max_{a \in A} \mathrm{sim}(v, a), \max_{b \in B} \mathrm{sim}(v, b)\big) - \max_{c \in C} \mathrm{sim}(v, c),\ 0\Big).$$

This function was used [Kot+22] to produce summaries $A$ that were particularly relevant to a query
given by $B$ but that neglect information in $C$, which can be considered “private” information to
avoid.
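The facility-location form of the CCMI can be checked numerically against the definition $I_f(A; B \mid C) = f(A \cup C) + f(B \cup C) - f(C) - f(A \cup B \cup C)$ and then maximized greedily. This is a small sketch under stated assumptions: similarities are nonnegative, the maximum over an empty set is taken as 0, and the similarity matrix (item 0 as the query $B$, item 4 as the "private" set $C$) is made up.

```python
import itertools


def fl(sim, A):
    """Facility location f(A) = sum_v max_{a in A} sim[v][a] (max over empty set = 0)."""
    return sum(max((row[a] for a in A), default=0.0) for row in sim)


def ccmi_definition(sim, A, B, C):
    A, B, C = set(A), set(B), set(C)
    return fl(sim, A | C) + fl(sim, B | C) - fl(sim, C) - fl(sim, A | B | C)


def ccmi_closed_form(sim, A, B, C):
    total = 0.0
    for row in sim:
        a = max((row[i] for i in A), default=0.0)
        b = max((row[i] for i in B), default=0.0)
        c = max((row[i] for i in C), default=0.0)
        total += max(min(a, b) - c, 0.0)
    return total


def greedy_ccmi(sim, B, C, k):
    """Greedily pick k items relevant to the query B while discounting what C explains."""
    A = []
    candidates = [v for v in range(len(sim)) if v not in B and v not in C]
    for _ in range(k):
        v = max((u for u in candidates if u not in A),
                key=lambda u: ccmi_closed_form(sim, A + [u], B, C))
        A.append(v)
    return A


# Hypothetical items: 0 = query, 4 = "private"; item 2 resembles both 0 and 4.
sim = [
    [1.0, 0.8, 0.8, 0.1, 0.3],
    [0.8, 1.0, 0.5, 0.1, 0.1],
    [0.8, 0.5, 1.0, 0.1, 0.9],
    [0.1, 0.1, 0.1, 1.0, 0.1],
    [0.3, 0.1, 0.9, 0.1, 1.0],
]
B, C = {0}, {4}
# Sanity check: the closed form matches the CCMI definition on every A ⊆ {1, 2, 3}.
for r in range(4):
    for A in itertools.combinations([1, 2, 3], r):
        assert abs(ccmi_definition(sim, A, B, C) - ccmi_closed_form(sim, A, B, C)) < 1e-9
print(greedy_ccmi(sim, B, C, k=1))  # prefers item 1 over the "private-looking" item 2
```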


6.11.8 Clustering, Data Partitioning, and Parallel Machine Learning 


There is an almost unlimited number of clustering algorithms and a plethora of reviews on their
variants. Any given submodular function can also instantiate a clustering procedure, and there are
several ways to do this. Here we offer only a brief outline of the approach. In the last section, we
defined $I_f(A; V \setminus A)$ as the CCMI between $A$ and everything but $A$. When we view this as a
function of $A$, then $g(A) = I_f(A; V \setminus A)$ is a symmetric submodular function that can be
minimized (over nonempty proper subsets $A$, since $g(\emptyset) = g(V) = 0$) using Queyranne's
algorithm [Que98; NI92]. Once this is done, the resulting $A$ is such that it is least similar to $V \setminus A$
according to $I_f(A; V \setminus A)$, and hence forms a 2-clustering. This process can then be recursively
applied, where we form two new functions $g_A(B) = I_f(B; A \setminus B)$ for $B \subseteq A$ and
$g_{V \setminus A}(B) = I_f(B; (V \setminus A) \setminus B)$ for $B \subseteq V \setminus A$. These are two symmetric submodular
functions on different ground sets that can also be minimized using Queyranne's algorithm. This
recursive bisection algorithm then repeats until the desired number of clusters is formed. Hence, the
CCMI function can be used as a top-down recursive bisection clustering procedure, which has been
called Q-clustering [NJB05; NB06]. It should be noted that such forms of clustering often generalize
forming a multi-way cut in an undirected graph, in which case the objective becomes the graph-cut
function that, as we saw above, is also submodular. In some cases, the number of clusters need
not be specified in advance [NKI10]. Another submodular approach to clustering can be found
in [Wei+15b], where the goal is to minimize the maximum valued block in a partitioning, which can
lead to submodular load balancing or minimum makespan scheduling [HS88; LST90].
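Q-clustering can be sketched for tiny ground sets by replacing Queyranne's algorithm with brute-force minimization of $g$ over nonempty proper subsets (exponential, so for illustration only). We assume $f$ is a set-cover function on made-up concept sets, so that $I_f(B; W \setminus B)$ counts the concepts shared across a split; the rule of always splitting the currently largest cluster is our own simple choice.

```python
import itertools

# Hypothetical items: {0, 1, 2} share concepts a, b, c; {3, 4, 5} share x, y, z;
# item 5 also mentions c, a weak link between the two groups.
concepts = {0: {"a", "b"}, 1: {"b", "c"}, 2: {"a", "c"},
            3: {"x", "y"}, 4: {"y", "z"}, 5: {"x", "z", "c"}}
V = sorted(concepts)


def f(A):
    return len(set().union(*(concepts[v] for v in A)))


def ccmi_split(B, W):
    """I_f(B; W minus B) with C empty: f(B) + f(W minus B) - f(empty) - f(W)."""
    B = set(B)
    rest = set(W) - B
    return f(B) + f(rest) - f(set()) - f(set(W))


def best_bisection(W):
    """Brute-force stand-in for Queyranne's algorithm on the ground set W."""
    W = sorted(W)
    best, best_val = None, float("inf")
    for r in range(1, len(W)):
        for B in itertools.combinations(W, r):
            val = ccmi_split(B, W)
            if val < best_val:
                best, best_val = set(B), val
    return best, set(W) - best


def q_cluster(V, k):
    clusters = [set(V)]
    while len(clusters) < k:
        clusters.sort(key=len)
        if len(clusters[-1]) < 2:
            break  # every cluster is a singleton; nothing left to split
        clusters.extend(best_bisection(clusters.pop()))
    return clusters


print(q_cluster(V, 2))
print(q_cluster(V, 3))
```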

Yet another form of clustering can be seen via the simple cardinality constrained submodular 
maximization process itself which can be compared to a k-medoids process whenever the objective f 
is the facility location function. Hence, any such submodular function can be seen as a submodular- 
function-parameterized form of finding the k “centers” among a set of data items. There have been 
numerous applications of submodular clustering. For example, using these techniques it is possible to 
identify parcellations of the human brain [Sal+17a]. Other applications include partitioning data for 
more effective, more accurate, and lower-variance distributed machine learning training [Wei+15a], and
for better mini-batch construction for training deep neural networks [Wan+19b].


6.11.9 Active and Semi-Supervised Learning 


Suppose we are given a data set $\{(x_i, y_i)\}_{i \in V}$ consisting of $|V| = n$ samples of $(x, y)$ pairs,
but where the labels are unknown. Samples are labeled one at a time or one mini-batch at a time, and
after each labeling step $t$ each remaining unlabeled sample is given a score $s_t(x_i)$ that indicates the
potential benefit of acquiring a label for that sample. Examples include the entropy of the model's
output distribution on $x_i$, or a margin-based score built from the difference between the top and
the second-from-the-top posterior probability (a small difference indicating high uncertainty). This
produces a modular function on the unlabeled samples, $m_t(A) = \sum_{a \in A} s_t(x_a)$, where $A \subseteq V$.
It is simple to use this modular function to produce a mini-batch active learning procedure where at
each stage we form $A_t \in \operatorname{argmax}_{A \subseteq U_t, |A| = k} m_t(A)$, where $U_t$ is the set of unlabeled
samples at stage $t$. Then $A_t$ is a set of size $k$ that gets labeled; we form $U_{t+1} = U_t \setminus A_t$,
update $s_{t+1}(x_a)$ for $a \in U_{t+1}$, and repeat. This is called active learning (Section 34.7).
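Because $m_t$ is modular, its maximizer over $|A| = k$ is simply the top-$k$ scores. A minimal sketch of the loop, assuming made-up predicted class probabilities that we keep fixed between rounds (a real system would retrain the model and rescore after each labeling step); the margin score is written as one minus the margin so that larger means more uncertain, which is our own convention.

```python
import math


def entropy(probs):
    return -sum(p * math.log(p) for p in probs if p > 0)


def margin_uncertainty(probs):
    """1 - (top probability - second probability): larger means more uncertain."""
    top, second = sorted(probs, reverse=True)[:2]
    return 1.0 - (top - second)


def select_topk(scores, unlabeled, k):
    """argmax of the modular m_t(A) = sum_{a in A} s_t(x_a) over |A| = k is the top-k."""
    return sorted(unlabeled, key=lambda i: scores[i], reverse=True)[:k]


# Hypothetical predicted class probabilities for six unlabeled samples.
probs = {0: [0.5, 0.5], 1: [0.9, 0.1], 2: [0.55, 0.45],
         3: [0.99, 0.01], 4: [0.6, 0.4], 5: [0.7, 0.3]}
U = set(probs)
for t in range(2):
    scores = {i: entropy(probs[i]) for i in U}  # retraining + rescoring would go here
    A_t = select_topk(scores, U, k=2)
    U = U - set(A_t)  # U_{t+1} = U_t minus A_t
    print(t, A_t)
```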

The reason for using active learning with mini-batches of size greater than one is that it is often
inefficient to ask for a single label at a time. The problem with such a mini-batch strategy, however,
is that the set $A_t$ can be redundant. The uncertainty about every sample in $A_t$ could be owing to
the same underlying cause: even though the model is most uncertain about the samples in $A_t$, once
one sample in $A_t$ is labeled, it may not be optimal to label the remaining samples in $A_t$ due to this
redundancy. Utilizing submodularity, therefore, can help reduce this redundancy. Suppose $f_t(A)$ is
a submodular diversity model over samples at step $t$. At each stage, choosing the set of samples to
label becomes $A_t \in \operatorname{argmax}_{A \subseteq U_t, |A| = k} m_t(A) + f_t(A)$, so $A_t$ is selected based on a
combination of both uncertainty (via $m_t(A)$) and diversity (via $f_t(A)$). This is precisely the
submodular active learning approach taken in [WIB15; Kau+19].
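The redundancy problem and its submodular fix can be seen on a four-sample toy. This sketch assumes a facility location diversity term $f_t$ with Gaussian similarity $\exp(-\|x_u - x_v\|^2)$, a trade-off weight `lam` of our own choosing, and made-up positions and uncertainty scores in which samples 0 and 1 are near-duplicates.

```python
import math


def facility_location(sim, pool, A):
    """Diversity term f_t(A) = sum_{v in pool} max_{a in A} sim(v, a)."""
    return sum(max((sim(v, a) for a in A), default=0.0) for v in pool)


def greedy_uncertainty_diversity(scores, points, k, lam=1.0):
    """Greedy maximization of m_t(A) + lam * f_t(A) over the unlabeled pool."""
    pool = list(points)

    def sim(u, v):
        return math.exp(-math.dist(points[u], points[v]) ** 2)

    def objective(A):
        return sum(scores[a] for a in A) + lam * facility_location(sim, pool, A)

    A = []
    for _ in range(k):
        A.append(max((v for v in pool if v not in A), key=lambda v: objective(A + [v])))
    return A


# Hypothetical pool: samples 0 and 1 are near-duplicates that the model is unsure about.
points = {0: (0.0, 0.0), 1: (0.05, 0.0), 2: (3.0, 0.0), 3: (0.0, 3.0)}
scores = {0: 0.9, 1: 0.89, 2: 0.6, 3: 0.3}
top_k = sorted(scores, key=scores.get, reverse=True)[:2]  # uncertainty only
print(top_k, greedy_uncertainty_diversity(scores, points, k=2))
```

Uncertainty alone picks both near-duplicates; adding the diversity term swaps the second pick for a sample from a different region.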

Another, quite different, approach is a form of submodular “batch” active learning in which a
batch $L$ of samples is selected all at once for labeling, and these labels are then used to label the rest
of the unlabeled samples. This also allows the remaining unlabeled samples to be utilized in a
semi-supervised framework [GB09; GB11]. In this setting, we start with a graph $G = (V, E)$ where the
nodes $V$ need to be given a binary $\{0,1\}$-valued label, $y \in \{0,1\}^V$. For any $A \subseteq V$, let
$y_A \in \{0,1\}^A$ be the labels just for node set $A$. We also define $V(y) \subseteq V$ as
$V(y) = \{v \in V : y_v = 1\}$. Hence $V(y)$ are the graph nodes labeled 1 by $y$, and $V \setminus V(y)$ are the
nodes labeled 0. Given a submodular objective $f$, we form its symmetric CCMI variant
$I_f(A) = I_f(A; V \setminus A)$; note that $I_f(A)$ is always submodular in $A$. This allows $I_f(V(y))$ to
determine the “smoothness” of a given candidate labeling $y$. For example, if $I_f$ is the weighted graph
cut function, where each weight corresponds to an affinity between the corresponding two nodes,
then $I_f(V(y))$ would be small if $V(y)$ (the 1-labeled nodes) does not have strong affinity with
$V \setminus V(y)$ (the 0-labeled nodes). In general, however, $I_f$ can be any symmetric submodular function.
Let $L \subseteq V$ be any candidate set of nodes to be labeled, and define
$\Psi(L) = \min_{T \subseteq V \setminus L : T \neq \emptyset} I_f(T)/|T|$. Then $\Psi(L)$ measures the “strength” of $L$, in
that if $\Psi(L)$ is small, an adversary can label nodes other than $L$ without being too unsmooth
according to $I_f$, while if $\Psi(L)$ is large, an adversary can do no such thing. Then [GB11] showed that,
given a node set $L$ to be queried and the corresponding correct labels $y_L$, if the remaining labels are
completed (in a semi-supervised fashion) according to
$y' = \operatorname{argmin}_{\hat{y} \in \{0,1\}^V : \hat{y}_L = y_L} I_f(V(\hat{y}))$, then this results in the bound
$\|y - y'\|^2 \le 2 I_f(V(y))/\Psi(L)$ on the true labeling, suggesting that we can find a good set to query
by maximizing $\Psi(L)$ over $L$; this holds for any submodular function. Of course, it is necessary to
find an underlying submodular function $f$ that fits a given problem, and this is discussed in
Section 6.11.6.
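On a tiny graph everything in this argument can be brute-forced: $\Psi(L)$, the completed labeling $y'$, and the bound. The sketch below assumes $I_f$ is a weighted graph cut on a made-up graph (two triangles joined by one weak edge) and assumed true labels; it is exponential in $n$ and meant only to make the quantities concrete, not to reproduce the algorithms of [GB11].

```python
import itertools

# Hypothetical graph: two triangles joined by one weak edge (2, 3).
n = 6
edges = {(0, 1): 1.0, (1, 2): 1.0, (0, 2): 1.0,
         (3, 4): 1.0, (4, 5): 1.0, (3, 5): 1.0, (2, 3): 0.1}


def cut(A):
    """Weighted graph cut, a symmetric submodular function playing the role of I_f."""
    A = set(A)
    return sum(w for (u, v), w in edges.items() if (u in A) != (v in A))


def psi(L):
    """Psi(L) = min over nonempty T in V minus L of cut(T) / |T| (brute force)."""
    rest = [v for v in range(n) if v not in L]
    return min(cut(T) / len(T)
               for r in range(1, len(rest) + 1)
               for T in itertools.combinations(rest, r))


def complete_labels(L, y):
    """Minimum-cut labeling that agrees with the revealed labels y_L (brute force)."""
    rest = [v for v in range(n) if v not in L]
    best, best_val = None, float("inf")
    for bits in itertools.product((0, 1), repeat=len(rest)):
        y_hat = list(y)  # entries outside L are overwritten just below
        for v, b in zip(rest, bits):
            y_hat[v] = b
        val = cut([v for v in range(n) if y_hat[v] == 1])
        if val < best_val:
            best, best_val = y_hat, val
    return best


y = [1, 1, 1, 0, 0, 0]  # assumed true labels
L_best = max(itertools.combinations(range(n), 2), key=psi)
for L in [(0, 1), L_best]:
    y_hat = complete_labels(L, y)
    errors = sum(a != b for a, b in zip(y, y_hat))
    bound = 2 * cut([v for v in range(n) if y[v] == 1]) / psi(L)
    print(L, round(psi(L), 3), errors, round(bound, 3))
```

Querying two nodes from the same triangle gives a small $\Psi(L)$ and a loose bound, whereas the $\Psi$-maximizing pair straddles the weak edge and the completion recovers $y$ exactly.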


6.11.10 Probabilistic Modeling


Graphical models are often used to describe factorization requirements on families of probability
distributions. Factorization is not the only way, however, to describe restrictions on such families.
In a graphical model, graphs describe only which random variables may directly interact with other
random variables. An entirely different strategy produces families of often-tractable probabilistic
models without requiring any factorization property at all. Considering an energy function $E(x)$
where $p(x) \propto \exp(-E(x))$, factorizations correspond to there being cliques in the graph such that
the graph's tree-width often is limited. On the other hand, finding $\max_x p(x)$ is the same as finding
$\min_x E(x)$, something that can be done if $E(x) = f(V(x))$ is a submodular function (using the
earlier notation $V(x)$ to map from binary vectors to subsets of $V$). Even a submodular function as
simple as $f(A) = \sqrt{|A|} - m(A)$, where $m$ is modular, has tree-width of $n - 1$, and this leads to an
energy function $E(x)$ that allows $\max_x p(x)$ to be solved in polynomial time using submodular
function minimization (see Section 6.11.4.3). Such restrictions to $E(x)$ therefore are not of the form
of which random variables are allowed to interact directly with which, but rather of the manner in
which the random variables interact. Such potential function restrictions can also be combined with
direct interaction restrictions, and this has been widely used in computer vision, where graph-cut and
graph-cut-like “move making” algorithms (such as the alpha-beta swap and alpha-expansion
algorithms) are used in attractive models [BVZ99; BK01; BVZ01; SWW08]. In fact, the culmination of
these efforts [KZ02] led to a rediscovery of submodularity (or the “regular” property) as the essential
ingredient for when Markov random fields can be solved using graph cut minimization, which is a
special case of submodular function minimization.
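For the simplest case, binary labels with nonnegative unary costs and pairwise terms $w_{ij}[x_i \neq x_j]$ with $w_{ij} \ge 0$ (a standard example of a submodular, or "regular", energy), the MAP configuration can be found exactly with a single s-t minimum cut. The sketch below is a small, unoptimized illustration of that classical construction (Edmonds-Karp max flow on a dense capacity matrix), checked against brute force; it does not cover general submodular energies, and the costs are made up.

```python
import itertools
from collections import deque


def energy(x, unary, pair):
    """E(x) = sum_i unary[i][x_i] + sum_{(i,j)} w_ij [x_i != x_j]."""
    return (sum(u[xi] for u, xi in zip(unary, x))
            + sum(w for (i, j), w in pair.items() if x[i] != x[j]))


def graph_cut_map(unary, pair):
    """Exact minimizer of E via an s-t minimum cut; assumes all costs are >= 0."""
    n = len(unary)
    s, t = n, n + 1
    cap = [[0.0] * (n + 2) for _ in range(n + 2)]
    for i, (cost0, cost1) in enumerate(unary):
        cap[s][i] += cost1  # cut iff node i ends on the sink side (x_i = 1)
        cap[i][t] += cost0  # cut iff node i stays on the source side (x_i = 0)
    for (i, j), w in pair.items():
        cap[i][j] += w
        cap[j][i] += w

    def bfs_parents():
        """BFS tree over residual edges; parent[v] == -1 means v is unreachable from s."""
        parent = [-1] * (n + 2)
        parent[s] = s
        queue = deque([s])
        while queue:
            u = queue.popleft()
            for v in range(n + 2):
                if parent[v] == -1 and cap[u][v] > 1e-12:
                    parent[v] = u
                    queue.append(v)
        return parent

    while True:
        parent = bfs_parents()
        if parent[t] == -1:
            break  # no augmenting path left, so the flow is maximum
        flow, v = float("inf"), t
        while v != s:
            flow = min(flow, cap[parent[v]][v])
            v = parent[v]
        v = t
        while v != s:
            cap[parent[v]][v] -= flow
            cap[v][parent[v]] += flow
            v = parent[v]
    # Nodes still reachable from s in the residual graph form the source side (x_i = 0).
    return [0 if parent[i] != -1 else 1 for i in range(n)]


# Hypothetical energy on a 5-node cycle: (cost if x_i = 0, cost if x_i = 1).
unary = [(0.0, 2.0), (1.0, 1.2), (1.5, 0.2), (2.0, 0.0), (0.9, 1.0)]
pair = {(0, 1): 0.8, (1, 2): 0.5, (2, 3): 0.8, (3, 4): 0.7, (0, 4): 0.3}
x_cut = graph_cut_map(unary, pair)
x_brute = min(itertools.product((0, 1), repeat=len(unary)), key=lambda x: energy(x, unary, pair))
print(x_cut, energy(x_cut, unary, pair), energy(x_brute, unary, pair))
```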

The above model can be seen as log-supermodular, since $\log p(x) = -E(x) - \log Z$ is a
supermodular function. These are all distributions that put high probability on configurations that
yield small valuation by a submodular function. Therefore, these distributions have high probability
when $x$ is homogeneous, and for this reason they are useful for computer vision segmentation
problems (e.g., within a segment of an image, nearby pixels should be roughly homogeneous, as
that is often what defines an object). The DPPs we saw above, however, are an example of a
log-submodular probability distribution, since $f(X) = \log\det(M_X)$ is submodular. These models
have high probability for diverse sets.

More generally, $E(x)$ being either a submodular or a supermodular function can produce
log-submodular or log-supermodular distributions, covering both cases above, where the partition
function takes the form $Z = \sum_{A \subseteq V} \exp(f(A))$ for objective $f$. Moreover, we often wish to
perform tasks beyond just finding the most probable random variable assignments. These include
marginalization, computing the partition function, constrained maximization, and so on. Unfortunately,
many of these more general probabilistic inference problems do not have polynomial-time solutions
even though the objectives are submodular or supermodular. On the other hand, such structure has
opened the doors to an assortment of new probabilistic inference procedures that exploit it [DK14;
DK15a; DTK16; ZDK15; DJK18]. Most of these methods are of the variational sort and offer
bounds on the partition function $Z$, sometimes making use of the fact that submodular functions
have easily computable semi-gradients [IB15; Fuj05], which are modular upper and lower bounds on a
submodular or supermodular function that are tight at one or more subsets. Given a submodular (or
supermodular) function $f$ and a set $A$, it is possible to easily construct (in linear time) a modular
upper bound $m^A : 2^V \to \mathbb{R}$ and a modular lower bound $m_A : 2^V \to \mathbb{R}$ having the
properties that $m_A(X) \le f(X) \le m^A(X)$ for all $X \subseteq V$ and that are tight at $X = A$, meaning
$m_A(A) = f(A) = m^A(A)$ [IB15]. For any modular function $m$, the probability function for a
characteristic vector $x = \mathbf{1}_A$ becomes
$p(\mathbf{1}_A) = \frac{1}{Z}\exp(m(A)) = \prod_{a \in A} \sigma(m(a)) \prod_{a \notin A} \sigma(-m(a))$,
where $\sigma$ is the logistic function. Thus, a modular approximation of a submodular function is like a
mean-field approximation of the distribution, and makes the assumption that all random variables
are independent. Such an approximation can then be used to compute quantities such as upper and
lower bounds on the partition function, and much else.
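Both semi-gradient bounds, and the factorization of $p \propto \exp(m)$ for a modular $m$, can be checked by enumeration on a small example. This sketch assumes a made-up set-cover function; the lower bound uses marginal gains along a chain that lists $A$ first (tight at $A$), and the upper bound is the form $f(A) - \sum_{j \in A \setminus X} f(j \mid A \setminus \{j\}) + \sum_{j \in X \setminus A} f(j \mid \emptyset)$, which is modular up to an additive constant (the constant cancels under normalization).

```python
import itertools
import math

# Hypothetical set-cover polymatroid on four items.
concepts = {0: {"a", "b"}, 1: {"b", "c"}, 2: {"c", "d", "e"}, 3: {"a"}}
V = sorted(concepts)


def f(X):
    return len(set().union(*(concepts[v] for v in X)))


def gain(j, X):
    """f(j | X) = f(X ∪ {j}) - f(X)."""
    return f(set(X) | {j}) - f(X)


def lower_bound_weights(A):
    """Modular lower bound tight at A: marginal gains along a chain that lists A first."""
    order = sorted(A) + [v for v in V if v not in A]
    h, prefix = {}, []
    for v in order:
        h[v] = gain(v, prefix)
        prefix.append(v)
    return h


def upper_bound(A, X):
    """Modular (plus a constant) upper bound tight at A."""
    A, X = set(A), set(X)
    return (f(A) - sum(gain(j, A - {j}) for j in A - X)
            + sum(gain(j, set()) for j in X - A))


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


A = {0, 2}
h = lower_bound_weights(A)
subsets = [set(X) for r in range(len(V) + 1) for X in itertools.combinations(V, r)]
for X in subsets:
    assert sum(h[v] for v in X) <= f(X) <= upper_bound(A, X)

# p(X) ∝ exp(m(X)) for the modular m = h factorizes into independent logistic terms.
Z = sum(math.exp(sum(h[v] for v in X)) for X in subsets)
for X in subsets:
    p_enum = math.exp(sum(h[v] for v in X)) / Z
    p_prod = math.prod(sigmoid(h[v]) if v in X else sigmoid(-h[v]) for v in V)
    assert abs(p_enum - p_prod) < 1e-12
print(h, f(A), upper_bound(A, A))
```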


6.11.11 Structured Norms and Loss Functions 


Convex norms are used ubiquitously in machine learning, often as complexity-penalizing regularizers
(e.g., the ubiquitous $p$-norms for $p \ge 1$) and also sometimes as losses (e.g., squared error). Identifying
new useful structured and possibly learnable sparse norms is an interesting and useful endeavor,
and submodularity can help here as well. Firstly, recall that the $\ell_0$ or counting norm $\|x\|_0$ simply
counts the number of nonzero entries in $x$. When we wish for a sparse solution, we may wish to
regularize using $\|x\|_0$, but this both leads to an intractable combinatorial optimization problem and
yields an objective that is not differentiable. The usual approach is to find the closest convex relaxation
of this norm, which is the one-norm $\|x\|_1$. This is convex in $x$ and has a sub-gradient structure, and
hence can be combined with a loss function to produce an optimizable machine learning objective,
for example the lasso. On the other hand, $\|x\|_1$ has no structure, as each element of $x$ is penalized
based on its absolute value irrespective of the state of any of the other elements. There have thus
been efforts to develop group norms that penalize groups or subsets of elements of $x$ together, such
as the group lasso [HTW15].

It turns out that there is a way to use a submodular function as the regularizer. Penalizing x via $\|x\|_0$ is identical to penalizing it via $|\mathrm{supp}(x)|$, where $\mathrm{supp}(x) \subseteq V$ is the set of indices of the nonzero entries of x; note that $m(A) = |A|$ is a modular function. Instead, we could penalize x via $f(\mathrm{supp}(x))$ for a submodular function f. Here, any element of x being nonzero would reduce the additional penalty for other elements of x being nonzero, all according to the submodular function, and such cooperative penalties can be obtained via a submodular parameterization. As with the zero-norm $\|x\|_0$, however, continuous optimization of x with a penalty term of the form $f(\mathrm{supp}(x))$ leads to the same combinatorial problem. To address this, we can use the Lovász extension $\breve{f}(x)$ of f evaluated on a vector x. This function is convex but it is not a norm; however, if we consider the construct defined as $\|x\|_f = \breve{f}(|x|)$, it can be shown that it satisfies all the properties of a norm for all non-trivial submodular functions [PG98; Bac+13] (i.e., those normalized submodular functions for which $f(\{v\}) > 0$ for all $v \in V$). In fact, the group lasso mentioned above is a special case for a particularly simple feature-based submodular function (a sum of min-truncated cardinality functions). In principle, the same submodular design strategies mentioned in Section 6.11.6 can be used to produce a submodular function that instantiates an appropriate convex structured norm for a given machine learning problem.
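The norm $\|x\|_f = \breve{f}(|x|)$ can be evaluated with the standard sorting formula for the Lovász extension. The formula sorts the coordinates of |x| in decreasing order and weights each one by the marginal gain of f when that coordinate is added. The sketch below is a minimal pure-Python illustration, not a library API; the helper names and the two-group example are hypothetical. With f(A) = |A| it recovers $\|x\|_1$. With f(A) = Σ_g min(|A ∩ g|, 1) it gives the sum over groups of the largest |x_i| in each group, which is a group ℓ1/ℓ∞ penalty.

```python
from typing import Callable, FrozenSet, Sequence

SetFn = Callable[[FrozenSet[int]], float]


def lovasz_extension(f: SetFn, w: Sequence[float]) -> float:
    """Lovász extension of a normalized set function f (f(empty) = 0) at w.

    Visit coordinates in decreasing order of w and accumulate
    w[i] * (f(S_k) - f(S_{k-1})), where S_k holds the k largest coordinates.
    """
    order = sorted(range(len(w)), key=lambda i: -w[i])
    total, prev_set = 0.0, frozenset()
    prev_val = f(prev_set)
    for i in order:
        cur_set = prev_set | {i}
        cur_val = f(cur_set)
        total += w[i] * (cur_val - prev_val)   # marginal gain times weight
        prev_set, prev_val = cur_set, cur_val
    return total


def submodular_norm(f: SetFn, x: Sequence[float]) -> float:
    """||x||_f = Lovász extension of f evaluated at |x|."""
    return lovasz_extension(f, [abs(v) for v in x])


def cardinality(A: FrozenSet[int]) -> float:
    """Modular f(A) = |A|: the induced norm is the l1 norm."""
    return float(len(A))


# Hypothetical groups over 4 coordinates; f(A) = sum_g min(|A ∩ g|, 1).
GROUPS = [frozenset({0, 1}), frozenset({2, 3})]


def group_cover(A: FrozenSet[int]) -> float:
    return float(sum(min(len(A & g), 1) for g in GROUPS))


x = [0.5, -2.0, 0.0, 1.0]
print(submodular_norm(cardinality, x))   # l1 norm: 0.5 + 2 + 0 + 1
print(submodular_norm(group_cover, x))   # max(0.5, 2) + max(0, 1)
```

Swapping in a different submodular f only changes the set function passed in, which is what makes this construction convenient for designing structured norms.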


6.11.12 Conclusions


We have only barely touched the surface of submodularity and how it applies to and can benefit machine learning. For more details, see [Bil22] and the many references contained therein. Much like the definition of convexity, the innocuous-looking submodular inequality belies much of its complexity while opening the gates to wide and worthwhile avenues for machine learning exploration.
Inference


7 Inference algorithms: an overview


7.1 Introduction 


In the probabilistic approach to machine learning, all unknown quantities — be they predictions 
about the future, hidden states of a system, or parameters of a model — are treated as random 
variables, and endowed with probability distributions. The process of inference corresponds to 
computing the posterior distribution over these quantities, conditioning on whatever data is available. 

Let h represent the unknown variables, and D represent the known variables. Given a likelihood 
p(D|h) and a prior p(h), we can compute the posterior p(h|D) using Bayes’ rule: 


$$ p(h|D) = \frac{p(h)\,p(D|h)}{p(D)} \tag{7.1} $$

The main computational bottleneck is computing the normalization constant in the denominator, which requires solving the following high dimensional integral:

$$ p(D) = \int p(D|h)\,p(h)\,dh \tag{7.2} $$

This is needed to convert the unnormalized joint probability of some parameter value, p(h, D), to a normalized probability, p(h|D), which takes into account all the other plausible values that h could have. Similarly, computing posterior marginals also requires computing integrals:

$$ p(h_i|D) = \int p(h_i, h_{-i}|D)\,dh_{-i} \tag{7.3} $$

Thus integration is at the heart of Bayesian inference, whereas differentiation is at the heart of optimization.

In this chapter, we give a high level summary of algorithmic techniques for computing (approximate) posteriors. We will give more details in the following chapters. Note that most of these methods are independent of the specific model. This allows problem solvers to focus on creating the best model possible for the task, and then rely on some inference engine to do the rest of the work; this latter process is sometimes called “turning the Bayesian crank”. For more details, see e.g., [Gel+14a; MKL21; MFR20].
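When h is discrete and low dimensional, the integrals in Equations (7.1) to (7.3) become sums that can be computed by enumeration. The sketch below uses a hypothetical two-component h = (h_1, h_2), each binary, with made-up prior and likelihood tables for one fixed dataset D.

```python
# Hypothetical prior p(h) over h = (h1, h2), each binary.
prior = {(0, 0): 0.4, (0, 1): 0.1, (1, 0): 0.2, (1, 1): 0.3}
# Hypothetical likelihood p(D | h) of one fixed, observed dataset D.
lik = {(0, 0): 0.05, (0, 1): 0.6, (1, 0): 0.3, (1, 1): 0.5}


def posterior(prior, lik):
    """Bayes' rule (7.1), with the evidence (7.2) computed as a sum."""
    joint = {h: prior[h] * lik[h] for h in prior}   # p(h, D) = p(h) p(D | h)
    evidence = sum(joint.values())                  # p(D)
    return {h: v / evidence for h, v in joint.items()}, evidence


def marginal(post, i):
    """Posterior marginal p(h_i | D), summing out the other component as in (7.3)."""
    out = {}
    for h, p in post.items():
        out[h[i]] = out.get(h[i], 0.0) + p
    return out


post, evidence = posterior(prior, lik)
print(evidence)            # approximately 0.29 = 0.02 + 0.06 + 0.06 + 0.15
print(marginal(post, 0))   # p(h1 | D)
```

The number of terms grows exponentially with the number of components of h, which is why the rest of the chapter is about avoiding explicit enumeration.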


7.2 Common inference patterns 


There are many kinds of posterior we may want to compute, but we can identify 3 main patterns, as we discuss below. These give rise to different types of inference algorithms, as we will see in later chapters.


Figure 7.1: Graphical models with (a) Global hidden variables for representing the Bayesian discriminative model $p(y_{1:N}, \theta_y | x_{1:N}) = p(\theta_y) \prod_{n=1}^N p(y_n | x_n; \theta_y)$; (b) Local hidden variables for representing the generative model $p(x_{1:N}, z_{1:N} | \theta) = \prod_{n=1}^N p(z_n | \theta_z)\, p(x_n | z_n, \theta_x)$; (c) Local and global hidden variables for representing the Bayesian generative model $p(x_{1:N}, z_{1:N}, \theta) = p(\theta_z)\, p(\theta_x) \prod_{n=1}^N p(z_n | \theta_z)\, p(x_n | z_n, \theta_x)$. Shaded nodes are assumed to be known (observed), unshaded nodes are hidden.


7.2.1 Global latents 


The first pattern arises when we need to perform inference in models which have global latent variables, such as the parameters θ of a model, which are shared across all N observed training cases. This is shown in Figure 7.1a, and corresponds to the usual setting for supervised or discriminative learning, where the joint distribution has the form

$$ p(y_{1:N}, \theta | x_{1:N}) = p(\theta) \prod_{n=1}^N p(y_n | x_n, \theta) \tag{7.4} $$

The goal is to compute the posterior $p(\theta | x_{1:N}, y_{1:N})$. Most of the Bayesian supervised learning models discussed in Part III follow this pattern.


7.2.2 Local latents


The second pattern arises when we need to perform inference in models which have local latent variables, such as hidden states $z_{1:N}$; we assume the model parameters θ are known. This is shown in Figure 7.1b. Now the joint distribution has the form

$$ p(x_{1:N}, z_{1:N} | \theta) = \prod_{n=1}^N p(x_n | z_n, \theta_x)\, p(z_n | \theta_z) \tag{7.5} $$

The goal is to compute $p(z_n | x_n, \theta)$ for each n. This is the setting we consider for most of the PGM inference methods in Chapter 9.


If the parameters are not known (which is the case for most latent variable models, such as mixture models), we may choose to estimate them by some method (e.g., maximum likelihood), and then plug in this point estimate. The advantage of this approach is that, conditional on θ, all the latent variables are conditionally independent, so we can perform inference in parallel across the data. This lets us use methods such as expectation maximization (Section 6.6.3), in which we infer $p(z_n | x_n, \theta_t)$ in the E step for all n simultaneously, and then update $\theta_t$ in the M step. If the inference of $z_n$ cannot be done exactly, we can use variational inference, a combination known as variational EM (Section 6.6.6.1).

Alternatively, we can use a minibatch approximation to the likelihood, marginalizing out $z_n$ for each example in the minibatch to get

$$ \log p(D_t | \theta_t) = \sum_{n \in D_t} \log \sum_{z_n} p(x_n, z_n | \theta_t) \tag{7.6} $$

where $D_t$ is the minibatch at step t. If the marginalization cannot be done exactly, we can use variational inference, a combination known as stochastic variational inference or SVI (Section 10.3.1). We can also learn an inference network $q_\phi(z | x; \theta)$ to perform the inference for us, rather than running an inference engine for each example n in each batch t; the cost of learning φ can be amortized across the batches. This is called amortized SVI, and is commonly used to train deep latent variable models such as VAEs (Section 21.2).
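As a concrete instance of Equation (7.6), suppose (hypothetically) that the model is a 1d Gaussian mixture. Then $z_n \in \{1, \ldots, K\}$ indexes the component and $p(x_n, z_n = k | \theta) = \pi_k \mathcal{N}(x_n | \mu_k, \sigma_k^2)$. The inner sum over $z_n$ is a log-sum-exp, which the sketch below computes in a numerically stable way. The function names are illustrative.

```python
import math


def log_normal(x, mu, var):
    """log N(x | mu, var) for a scalar Gaussian with variance var."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mu) ** 2 / var)


def logsumexp(vals):
    m = max(vals)
    return m + math.log(sum(math.exp(v - m) for v in vals))


def minibatch_log_lik(batch, weights, means, variances):
    """sum_{n in batch} log sum_k p(x_n, z_n = k | theta) for a 1d Gaussian mixture."""
    total = 0.0
    for x in batch:
        total += logsumexp([math.log(w) + log_normal(x, m, v)
                            for w, m, v in zip(weights, means, variances)])
    return total


# Hypothetical minibatch D_t and current parameters theta_t.
print(minibatch_log_lik([0.1, 2.3, -1.0], [0.3, 0.7], [0.0, 2.0], [1.0, 0.5]))
```

Subtracting the maximum before exponentiating avoids underflow when every component assigns a small density to a data point.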


7.2.3 Global and local latents 


The third pattern arises when we need to perform inference in models which have local and global latent variables. This is shown in Figure 7.1c, and corresponds to the following joint distribution:

$$ p(x_{1:N}, z_{1:N}, \theta) = p(\theta_x)\, p(\theta_z) \prod_{n=1}^N p(x_n | z_n, \theta_x)\, p(z_n | \theta_z) \tag{7.7} $$


This is essentially a Bayesian version of the latent variable model in Figure 7.1b, where now we model uncertainty in both the local variables $z_n$ and the shared global variables θ. This approach is less common in the ML community, since it is often assumed that the uncertainty in the parameters θ is negligible compared to the uncertainty in the local variables $z_n$. The reason for this is that the parameters are “informed” by all N data cases, whereas each local latent $z_n$ is only informed by a single data point, namely $x_n$. Nevertheless, there are advantages to being “fully Bayesian”, and modeling uncertainty in both local and global variables. We will see some examples of this later in the book.


7.3 Exact inference algorithms 


In some cases, we can perform exact posterior inference in a tractable manner. In particular, if the prior is conjugate to the likelihood, the posterior will be analytically tractable. In general, this will be the case when the prior and likelihood are from the same exponential family (Section 2.3). In particular, if the unknown variables are represented by θ, then we assume

$$ p(\theta) \propto \exp(\lambda_0^\top T(\theta)) \tag{7.8} $$
$$ p(y_i | \theta) \propto \exp(\lambda_i(y_i)^\top T(\theta)) \tag{7.9} $$

where T(θ) are the sufficient statistics, and λ are the natural parameters. We can then compute the posterior by just adding the natural parameters:

$$ p(\theta | y_{1:N}) \propto \exp(\lambda_*^\top T(\theta)) \tag{7.10} $$
$$ \lambda_* = \lambda_0 + \sum_{n=1}^N \lambda_n(y_n) \tag{7.11} $$

See Section 3.2 for details.
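For the beta-Bernoulli model used later in this chapter, Equations (7.8) to (7.11) reduce to counting. Take $T(\theta) = (\log\theta, \log(1-\theta))$. A Beta(a, b) prior then has $\lambda_0 = (a-1, b-1)$, and a Bernoulli observation contributes $\lambda_n(y_n) = (y_n, 1-y_n)$. A minimal sketch of the update, with an illustrative function name:

```python
def beta_bernoulli_posterior(a0, b0, ys):
    """Conjugate update by adding natural parameters, Eq. (7.11).

    Prior Beta(a0, b0) has lambda_0 = (a0 - 1, b0 - 1) for T(theta) = (log theta, log(1 - theta)).
    Returns the posterior Beta parameters (a, b).
    """
    lam = [a0 - 1.0, b0 - 1.0]
    for y in ys:
        lam[0] += y          # lambda_n(y_n) = (y_n, 1 - y_n)
        lam[1] += 1 - y
    return lam[0] + 1.0, lam[1] + 1.0


# 10 heads and 1 tail with a uniform Beta(1, 1) prior gives Beta(11, 2).
print(beta_bernoulli_posterior(1.0, 1.0, [1] * 10 + [0]))
```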

Another setting where we can compute the posterior exactly arises when the D unknown variables are all discrete, each with K states; in this case, the integral for the normalizing constant becomes a sum with $K^D$ terms. In many cases, $K^D$ will be too large to be tractable. However, if the distribution satisfies certain conditional independence properties, as expressed by a probabilistic graphical model (PGM), then we can write the joint as a product of local terms (see Chapter 4). This lets us use dynamic programming to make the computation tractable. See Chapter 9 for details.
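To see why dynamic programming helps, consider a hypothetical chain-structured model over D discrete variables. It has unary potentials φ(h_d) and pairwise potentials ψ(h_{d-1}, h_d), shared across positions for brevity. The brute-force normalizer sums $K^D$ terms, whereas passing messages along the chain costs O(DK²). The sketch checks that both give the same answer. The potentials are made up for illustration.

```python
import itertools


def brute_force_Z(unary, pair, D):
    """Sum the unnormalized chain joint over all K**D configurations."""
    K = len(unary)
    total = 0.0
    for h in itertools.product(range(K), repeat=D):
        p = unary[h[0]]
        for d in range(1, D):
            p *= pair[h[d - 1]][h[d]] * unary[h[d]]
        total += p
    return total


def chain_Z(unary, pair, D):
    """Same normalizer by forward message passing, O(D * K**2) work."""
    K = len(unary)
    msg = list(unary)  # msg[k]: partial sum over h_1..h_{d-1} with h_d = k
    for _ in range(1, D):
        msg = [unary[k] * sum(msg[j] * pair[j][k] for j in range(K)) for k in range(K)]
    return sum(msg)


# Hypothetical potentials for K = 3 states.
unary = [1.0, 2.0, 0.5]
pair = [[1.0, 0.5, 0.2],
        [0.3, 1.0, 0.7],
        [0.9, 0.1, 1.0]]
print(brute_force_Z(unary, pair, 6), chain_Z(unary, pair, 6))
```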


7.4 Approximate inference algorithms 


For most probability models, we will not be able to compute marginals or posteriors exactly, so we 
must resort to using approximate inference. There are many different algorithms, which trade off 
speed, accuracy, simplicity, and generality. We briefly discuss some of these algorithms below. We 
give more detail in the following chapters. 


7.4.1 MAP estimation 
The simplest approximate inference method is to compute the MAP estimate

$$ \hat{\theta} = \arg\max_\theta p(\theta | D) = \arg\max_\theta \left[\log p(\theta) + \log p(D | \theta)\right] \tag{7.12} $$

and then to assume that the posterior puts 100% of its probability on this single value:

$$ p(\theta | D) \approx \delta(\theta - \hat{\theta}) \tag{7.13} $$


The advantage of this approach is that we can compute the MAP estimate using a variety of optimization algorithms, which we discuss in Chapter 6. The disadvantages of the MAP approximation are discussed in Section 3.1.5.
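For the beta-Bernoulli example used below, the MAP estimate can be found with plain gradient ascent on $\log p(\theta) + \log p(D|\theta)$. This is a minimal sketch. It assumes a Beta(a, b) prior with a, b ≥ 1, so that the objective is concave on (0, 1), and an illustrative fixed step size. The closed-form mode $(N_1 + a - 1)/(N + a + b - 2)$ serves as a check.

```python
def map_beta_bernoulli(n_heads, n_tails, a=1.0, b=1.0, lr=1e-3, steps=20000):
    """Gradient ascent on log Beta(theta | a, b) + log p(D | theta), up to constants."""
    theta = 0.5
    for _ in range(steps):
        grad = (n_heads + a - 1) / theta - (n_tails + b - 1) / (1 - theta)
        # Keep the iterate strictly inside (0, 1).
        theta = min(max(theta + lr * grad, 1e-6), 1 - 1e-6)
    return theta


print(map_beta_bernoulli(10, 1))   # close to the closed-form mode 10/11
```

Whatever the optimizer, the result is a single point, so it carries none of the posterior uncertainty discussed in the following sections.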


7.4.2 Grid approximation


If we want to capture uncertainty, we need to allow for the fact that θ may have a range of possible values, each with non-zero probability. The simplest way to capture this property is to partition the space of possible values into a finite set of regions, call them $r_1, \ldots, r_K$, each representing a region of parameter space of volume Δ centered on $\theta_k$. This is called a grid approximation. The probability of being in each region is given by $p(\theta \in r_k | D) \approx p_k \Delta$, where

$$ p_k = \frac{\tilde{p}_k}{\sum_{k'=1}^K \tilde{p}_{k'} \Delta} \tag{7.14} $$
$$ \tilde{p}_k = p(D | \theta_k)\, p(\theta_k) \tag{7.15} $$


Figure 7.2: Approximating the posterior of a beta-Bernoulli model. (a) Grid approximation using 20 grid points. (b) Laplace approximation. Generated by laplace_approx_beta_binom.ipynb.


As K increases, we decrease the size of each grid cell. Thus the denominator is just a simple numerical approximation of the integral

$$ p(D) = \int p(D | \theta)\, p(\theta)\, d\theta \approx \sum_{k=1}^K \tilde{p}_k \Delta \tag{7.16} $$


As a simple example, we will use the problem of approximating the posterior of a beta-Bernoulli model. Specifically, the goal is to approximate

$$ p(\theta | D) \propto \left[\prod_{n=1}^N \mathrm{Ber}(y_n | \theta)\right] \mathrm{Beta}(\theta | 1, 1) \tag{7.17} $$


where D consists of 10 heads and 1 tail (so the total number of observations is N = 11), with 
a uniform prior. Although we can compute this posterior exactly using the method discussed in 
Section 3.2.1, this serves as a useful pedagogical example since we can compare the approximation to 
the exact answer. Also, since the target distribution is just 1d, it is easy to visualize the results. 

In Figure 7.2a, we illustrate the grid approximation applied to our 1d problem. We see that it 
is easily able to capture the skewed posterior (due to the use of an imbalanced sample of 10 heads 
and 1 tail). Unfortunately, this approach does not scale to problems in more than 2 or 3 dimensions, 
because the number of grid points grows exponentially with the number of dimensions. 
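A minimal sketch of the grid approximation for this example (10 heads, 1 tail, uniform prior) follows. It uses K equal cells on [0, 1] with centres $\theta_k$. It returns the cell probabilities $p(\theta \in r_k | D) \approx p_k \Delta$ and the numerical estimate of p(D). For this data the exact values are p(D) = B(11, 2) = 1/132 and a posterior mean of 11/13, which the test uses as a check. The function name is illustrative.

```python
def grid_posterior(n_heads, n_tails, K=20):
    """Grid approximation on [0, 1] with K equal cells and a uniform prior."""
    delta = 1.0 / K                                    # cell volume
    thetas = [(k + 0.5) * delta for k in range(K)]     # cell centres theta_k
    # p_tilde_k = p(D | theta_k) p(theta_k); the uniform prior density is 1.
    p_tilde = [t ** n_heads * (1 - t) ** n_tails for t in thetas]
    total = sum(p_tilde)
    cell_probs = [p / total for p in p_tilde]          # p(theta in r_k | D)
    evidence = total * delta                           # approximates p(D)
    return thetas, cell_probs, evidence


thetas, cell_probs, evidence = grid_posterior(10, 1, K=20)
print(max(zip(cell_probs, thetas)))   # most probable cell and its centre
```

With one parameter, 20 or even 1000 cells are cheap; with D parameters the same resolution needs $K^D$ cells, which is the scaling problem noted above.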


7.4.3 Laplace (quadratic) approximation 


In this section, we discuss a simple way to approximate the posterior using a multivariate Gaussian; this is known as a Laplace approximation or quadratic approximation (see e.g., [TK86; RMC09]). Suppose we write the posterior as follows:

$$ p(\theta | D) = \frac{1}{Z} e^{-\mathcal{E}(\theta)} \tag{7.18} $$

where $\mathcal{E}(\theta) = -\log p(\theta, D)$ is called an energy function, and Z = p(D) is the normalization constant. Performing a Taylor series expansion around the mode $\hat{\theta}$ (i.e., the lowest energy state) we get

$$ \mathcal{E}(\theta) \approx \mathcal{E}(\hat{\theta}) + (\theta - \hat{\theta})^\top g + \frac{1}{2} (\theta - \hat{\theta})^\top H (\theta - \hat{\theta}) \tag{7.19} $$


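where g is the gradient at the mode, and H is the Hessian. Since $\hat{\theta}$ is the mode, the gradient term is zero. Hence

$$ \hat{p}(\theta, D) = e^{-\mathcal{E}(\hat{\theta})} \exp\left[-\frac{1}{2} (\theta - \hat{\theta})^\top H (\theta - \hat{\theta})\right] \tag{7.20} $$
$$ \hat{p}(\theta | D) = \frac{1}{Z} \hat{p}(\theta, D) = \mathcal{N}(\theta | \hat{\theta}, H^{-1}) \tag{7.21} $$
$$ Z = e^{-\mathcal{E}(\hat{\theta})} (2\pi)^{D/2} |H|^{-1/2} \tag{7.22} $$

The last line follows from the normalization constant of the multivariate Gaussian.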

The Laplace approximation is easy to apply, since we can leverage existing optimization algorithms 
to compute the MAP estimate, and then we just have to compute the Hessian at the mode. (In high 
dimensional spaces, we can use a diagonal approximation.) 

In Figure 7.2b, we illustrate this method applied to our 1d problem. Unfortunately, we see that it is not a particularly good approximation. This is because the posterior is skewed, whereas a Gaussian is symmetric. In addition, the parameter of interest lies in the constrained interval θ ∈ [0, 1], whereas the Gaussian assumes an unconstrained space, $\theta \in \mathbb{R}^D$. Fortunately, we can solve this latter problem by using a change of variable. For example, in this case we can apply the Laplace approximation to α = logit(θ). This is a common trick to simplify the job of inference.
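For the beta-Bernoulli example, both versions of the Laplace approximation have closed forms. The following minimal sketch assumes a uniform prior and at least one head and one tail. In θ-space, the mode is $\hat\theta = N_1/N$ and the precision is $H = N_1/\hat\theta^2 + N_0/(1-\hat\theta)^2$. After the change of variable α = logit(θ), the log joint, including the log-Jacobian log[θ(1 − θ)], is $(N_1+1)\log\sigma(\alpha) + (N_0+1)\log(1-\sigma(\alpha))$. Its mode and curvature are also analytic. The function names are illustrative.

```python
import math


def laplace_theta(n_heads, n_tails):
    """N(theta_hat, 1/H) in the original theta space (uniform prior)."""
    theta_hat = n_heads / (n_heads + n_tails)
    # H = -d^2/dtheta^2 [n_heads log theta + n_tails log(1 - theta)] at the mode.
    H = n_heads / theta_hat ** 2 + n_tails / (1 - theta_hat) ** 2
    return theta_hat, 1.0 / H


def laplace_logit(n_heads, n_tails):
    """N(alpha_hat, 1/H) for alpha = logit(theta), Jacobian included."""
    a, b = n_heads + 1, n_tails + 1          # p(alpha, D) ∝ sigmoid(alpha)^a (1 - sigmoid(alpha))^b
    alpha_hat = math.log(a / b)              # stationary point: sigmoid(alpha_hat) = a / (a + b)
    theta = a / (a + b)
    H = (a + b) * theta * (1 - theta)        # negative second derivative at the mode
    return alpha_hat, 1.0 / H


print(laplace_theta(10, 1))   # mode 10/11, variance 1/133.1
print(laplace_logit(10, 1))   # mode log(11/2), variance 13/22
```

Pushing the logit-space Gaussian back through the sigmoid gives an approximation that stays inside [0, 1] and is skewed toward smaller θ, which the θ-space Gaussian cannot represent.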


7.4.4 Variational inference


In Section 7.4.3, we discussed the Laplace approximation, which uses an optimization procedure to find the MAP estimate, and then approximates the curvature of the posterior at that point based on the Hessian. In this section, we discuss variational inference (VI), also called variational Bayes (VB). This is another optimization-based approach to posterior inference, but one which has much more modeling flexibility (and thus can give a much more accurate approximation).

VI attempts to approximate an intractable probability distribution, such as p(θ|D), with one that is tractable, q(θ), so as to minimize some discrepancy D between the distributions:

$$ q^* = \arg\min_{q \in \mathcal{Q}} D(q, p) \tag{7.23} $$

where $\mathcal{Q}$ is some tractable family of distributions (e.g., fully factorized distributions). Rather than optimizing over functions q, we typically optimize over the parameters of the function q; we denote these variational parameters by ψ.


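It is common to use the KL divergence (Section 5.1) as the discrepancy measure, which is given by

$$ D(q, p) = D_{\mathrm{KL}}(q(\theta | \psi) \,\|\, p(\theta | D)) = \int q(\theta | \psi) \log \frac{q(\theta | \psi)}{p(\theta | D)}\, d\theta \tag{7.24} $$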
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4 p(z): true posterior Q 10 
x - q(x): variational posterior T 
a g § 
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2 6 
Figure 7.3: ADVI applied to the beta-Bernoulli model. (a) Approximate vs true posterior. (b) Negative ELBO over time. (c) Variational μ parameter over time. (d) Variational σ parameter over time. Generated by advi_beta_binom.ipynb.


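where p(θ|D) = p(D|θ)p(θ)/p(D). The inference problem then reduces to the following optimization problem:

$$ \psi^* = \arg\min_\psi D_{\mathrm{KL}}(q(\theta | \psi) \,\|\, p(\theta | D)) \tag{7.25} $$
$$ = \arg\min_\psi \mathbb{E}_{q(\theta | \psi)}\left[\log q(\theta | \psi) - \log \frac{p(D | \theta)\, p(\theta)}{p(D)}\right] \tag{7.26} $$
$$ = \arg\min_\psi \underbrace{\mathbb{E}_{q(\theta | \psi)}\left[-\log p(D | \theta) - \log p(\theta) + \log q(\theta | \psi)\right]}_{-\mathcal{L}(\psi)} + \log p(D) \tag{7.27} $$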


Note that log p(D) is independent of ψ, so we can ignore it when fitting the approximate posterior, and just focus on maximizing the term

$$ \mathcal{L}(\psi) = \mathbb{E}_{q(\theta | \psi)}\left[\log p(D | \theta) + \log p(\theta) - \log q(\theta | \psi)\right] \tag{7.28} $$

Since $D_{\mathrm{KL}}(q \| p) \geq 0$, we have $\mathcal{L}(\psi) \leq \log p(D)$. The quantity log p(D), which is the log marginal likelihood, is also called the evidence. Hence $\mathcal{L}(\psi)$ is known as the evidence lower bound or ELBO. By maximizing this bound, we are making the variational posterior closer to the true posterior. (See Section 10.1 for details.)
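To see the bound $\mathcal{L}(\psi) \leq \log p(D)$ numerically, the sketch below evaluates Equation (7.28) for a Beta(a, b) variational family on the beta-Bernoulli example (10 heads, 1 tail, uniform prior). It uses a midpoint-rule quadrature over θ. This is an illustrative choice, since practical VI methods usually rely on Monte Carlo estimates of the expectation. Because the exact posterior Beta(11, 2) lies in this family, the bound is tight there, and log p(D) = −log 132.

```python
import math


def log_beta_pdf(t, a, b):
    """log Beta(t | a, b)."""
    return ((a - 1) * math.log(t) + (b - 1) * math.log(1 - t)
            - (math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)))


def elbo_beta_q(a, b, n_heads=10, n_tails=1, M=20000):
    """L(psi) = E_q[log p(D|theta) + log p(theta) - log q(theta|psi)], q = Beta(a, b).

    The uniform prior gives log p(theta) = 0; the expectation is a midpoint-rule sum.
    """
    h = 1.0 / M
    total = 0.0
    for k in range(M):
        t = (k + 0.5) * h
        log_q = log_beta_pdf(t, a, b)
        log_lik = n_heads * math.log(t) + n_tails * math.log(1 - t)
        total += math.exp(log_q) * (log_lik - log_q) * h
    return total


print(elbo_beta_q(11, 2), -math.log(132))   # tight: q equals the true posterior
print(elbo_beta_q(5, 5))                    # a poorer q gives a lower bound
```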
Figure 7.4: Approximating the posterior of a beta-Bernoulli model using MCMC. (a) Kernel density estimate derived from samples from 4 independent chains. (b) Trace plot of the chains as they generate posterior samples. Generated by hmc_beta_binom.ipynb.


We can choose any kind of approximate posterior that we like. For example, we may use a Gaussian, $q(\theta | \psi) = \mathcal{N}(\theta | \mu, \Sigma)$. This is different from the Laplace approximation, since in VI we optimize Σ, rather than setting it equal to the inverse Hessian. If Σ is diagonal, we are assuming the posterior is fully factorized; this is called a mean field approximation.

A Gaussian approximation is not always suitable for all parameters. For example, in our 1d example we have the constraint that θ ∈ [0, 1]. We could use a variational approximation of the form $q(\theta | \psi) = \mathrm{Beta}(\theta | a, b)$, where ψ = (a, b). However, choosing a suitable form of variational distribution requires some level of expertise. To create a more easily applicable, or “turn-key”, method that works on a wide range of models, we can use a method called automatic differentiation variational inference or ADVI [Kuc+16]. This uses the change of variables method to convert the parameters to an unconstrained form, and then computes a Gaussian variational approximation. The method also uses automatic differentiation to derive the Jacobian term needed to compute the density of the transformed variables. See Section 10.3.5 for details.


We now apply ADVI to our 1d beta-Bernoulli model. Let θ = σ(z), where we replace p(θ|D) with $q(z | \psi) = \mathcal{N}(z | \mu, \sigma)$, where ψ = (μ, σ). We optimize a stochastic approximation to the ELBO using SGD. The results are shown in Figure 7.3 and seem reasonable.
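The reparameterized stochastic gradients behind this approach can be sketched for the 1d model in plain Python. This is a hand-derived sketch, not the ADVI implementation used for Figure 7.3. It assumes a uniform prior and parameterizes the standard deviation as s = exp(ω) to keep it positive. It writes the gradient of log p(z, D) by hand instead of using automatic differentiation, and uses hypothetical settings for the learning rate, number of Monte Carlo samples, and seed. At a stationary point the μ-gradient vanishes. For 10 heads and 1 tail, this forces $\mathbb{E}_q[\sigma(z)] = 11/13$.

```python
import math
import random


def sigmoid(z):
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def advi_beta_bernoulli(n_heads=10, n_tails=1, lr=0.02, steps=3000, n_samples=32, seed=0):
    """Gaussian VI on z = logit(theta) with reparameterized stochastic gradients.

    log p(z, D) = (n_heads + 1) log sigmoid(z) + (n_tails + 1) log(1 - sigmoid(z)) + const,
    i.e. a uniform prior on theta plus the log-Jacobian log[theta (1 - theta)].
    """
    rng = random.Random(seed)
    a, b = n_heads + 1, n_tails + 1
    mu, omega = 0.0, 0.0                      # q(z) = N(mu, s^2), s = exp(omega)
    for _ in range(steps):
        s = math.exp(omega)
        g_mu, g_omega = 0.0, 0.0
        for _ in range(n_samples):
            eps = rng.gauss(0.0, 1.0)
            z = mu + s * eps                  # reparameterization trick
            dlogp = a - (a + b) * sigmoid(z)  # d/dz log p(z, D)
            g_mu += dlogp
            g_omega += dlogp * eps * s        # chain rule through z = mu + exp(omega) eps
        g_mu /= n_samples
        g_omega = g_omega / n_samples + 1.0   # +1 from the entropy term log s
        mu += lr * g_mu                       # gradient ascent on the ELBO
        omega += lr * g_omega
    return mu, math.exp(omega)


mu, s = advi_beta_bernoulli()
print(mu, s)
```

Parameterizing s through its logarithm is a common design choice, since an unconstrained ω can be updated with plain gradient steps.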


7.4.5 Markov Chain Monte Carlo (MCMC)


Although VI is fast, it can give a biased approximation to the posterior, since it is restricted to a specific functional form q ∈ Q. A more flexible approach is to use a non-parametric approximation in terms of a set of samples, $q(\theta) \approx \frac{1}{S} \sum_{s=1}^S \delta(\theta - \theta^s)$. This is called a Monte Carlo approximation. The key issue is how to create the posterior samples $\theta^s \sim p(\theta | D)$ efficiently, without having to evaluate the normalization constant $p(D) = \int p(\theta, D)\, d\theta$.

For low dimensional problems, we can use methods such as importance sampling, which we discuss in Section 11.5. However, for high dimensional problems, it is more common to use Markov chain Monte Carlo or MCMC. We give the details in Chapter 12, but give a brief introduction here.


The most common kind of MCMC is known as the Metropolis-Hastings (MH) algorithm. The basic idea behind MH is as follows: we start at a random point in parameter space, and then perform a


Draft of “Probabilistic Machine Learning: Advanced Topics”. August 15, 2022 






random walk, by sampling new states (parameters) from a proposal distribution q(θ'|θ). If q is chosen carefully, the resulting Markov chain distribution will satisfy the property that the fraction of time we visit each point in space is proportional to the posterior probability. The key point is that to decide whether to move to a newly proposed point θ' or to stay in the current point θ, we only need to evaluate the unnormalized density ratio

$$ \frac{p(\theta | D)}{p(\theta' | D)} = \frac{p(D | \theta)\, p(\theta) / p(D)}{p(D | \theta')\, p(\theta') / p(D)} = \frac{p(D, \theta)}{p(D, \theta')} \tag{7.29} $$

This avoids the need to compute the normalization constant p(D). (In practice we usually work with log probabilities, instead of raw probabilities, to avoid numerical underflow.)

We see that the input to the algorithm is just a function that computes the log joint density, log p(θ, D), as well as a proposal distribution q(θ'|θ) for deciding which states to visit next. It is common to use a Gaussian distribution for the proposal, $q(\theta' | \theta) = \mathcal{N}(\theta' | \theta, c I)$; this is called the random walk Metropolis algorithm. However, this can be very inefficient, since it is blindly walking through the space, in the hopes of finding higher probability regions.
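Here is a minimal random walk Metropolis sketch for the beta-Bernoulli model, run on z = logit(θ). The only model-specific input is the log joint. It includes the log-Jacobian log[θ(1 − θ)] of the logit transform, so that σ(z) is distributed as the Beta(11, 2) posterior under a uniform prior. The step size, chain length, burn-in, and seed are illustrative assumptions.

```python
import math
import random


def log_sigmoid(z):
    """Numerically stable log(sigmoid(z))."""
    return -math.log1p(math.exp(-z)) if z >= 0 else z - math.log1p(math.exp(z))


def log_joint_logit(z, n_heads=10, n_tails=1):
    # log p(z, D) + const for z = logit(theta), uniform prior, log-Jacobian included.
    return (n_heads + 1) * log_sigmoid(z) + (n_tails + 1) * log_sigmoid(-z)


def rw_metropolis(log_target, z0=0.0, step=1.0, n_samples=50000, burn_in=1000, seed=0):
    rng = random.Random(seed)
    z, lp = z0, log_target(z0)
    samples = []
    for i in range(n_samples + burn_in):
        z_prop = z + step * rng.gauss(0.0, 1.0)    # symmetric Gaussian proposal
        lp_prop = log_target(z_prop)
        # Accept with probability min(1, p(z', D) / p(z, D)); p(D) never appears.
        if rng.random() < math.exp(min(0.0, lp_prop - lp)):
            z, lp = z_prop, lp_prop
        if i >= burn_in:
            samples.append(z)
    return samples


samples = rw_metropolis(log_joint_logit)
thetas = [1.0 / (1.0 + math.exp(-z)) for z in samples]
print(sum(thetas) / len(thetas))   # should be close to the posterior mean 11/13
```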

In models that have conditional independence structure, it is often easy to compute the full conditionals $p(\theta_d | \theta_{-d}, D)$ for each variable d, one at a time, and then sample from them. This is like a stochastic analog of coordinate ascent, and is called Gibbs sampling (see Section 12.3 for details).
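To keep the sketch self-contained, the Gibbs sampler below targets a hypothetical zero-mean bivariate Gaussian with unit variances and correlation ρ. Its full conditionals are available in closed form: $x_1 | x_2 \sim \mathcal{N}(\rho x_2, 1 - \rho^2)$, and symmetrically for $x_2$. The sampler updates one coordinate at a time, which is the coordinate-wise scheme described above.

```python
import math
import random


def gibbs_bivariate_normal(rho, n_samples=50000, seed=0):
    """Gibbs sampling from N(0, [[1, rho], [rho, 1]]) via its full conditionals."""
    rng = random.Random(seed)
    sd = math.sqrt(1.0 - rho ** 2)       # conditional standard deviation
    x1, x2 = 0.0, 0.0
    samples = []
    for _ in range(n_samples):
        x1 = rng.gauss(rho * x2, sd)     # sample x1 | x2
        x2 = rng.gauss(rho * x1, sd)     # sample x2 | x1, using the new x1
        samples.append((x1, x2))
    return samples


samples = gibbs_bivariate_normal(0.8)
print(sum(a * b for a, b in samples) / len(samples))   # estimates E[x1 x2] = 0.8
```

The stronger the correlation, the smaller each conditional move is relative to the spread of the target, so successive samples become more correlated and mixing slows down.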

For models where all unknown variables are continuous, we can often compute the gradient of the log joint, $\nabla_\theta \log p(\theta, D)$. We can use this gradient information to guide the proposals into regions of space with higher probability. This approach is called Hamiltonian Monte Carlo or HMC, and is one of the most widely used MCMC algorithms due to its speed. For details, see Section 12.5.

We apply HMC to our beta-Bernoulli model in Figure 7.4. (We use a logit transformation for 
the parameter.) In panel b, we show samples generated by the algorithm from 4 parallel Markov 
chains. We see that they oscillate around the true posterior, as desired. In panel a, we compute a 
kernel density estimate from the posterior samples from each chain; we see that the result is a good 
approximation to the true posterior in Figure 7.2. 
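The sketch below is a bare-bones HMC sampler for the same logit-transformed beta-Bernoulli posterior. With a uniform prior and 10 heads and 1 tail, it uses the gradient $\frac{d}{dz} \log p(z, D) = 11 - 13\,\sigma(z)$. It assumes a unit mass, a fixed leapfrog step size, and a fixed number of leapfrog steps. These are illustrative values, not the settings behind Figure 7.4. Each trajectory ends with a Metropolis correction on the total energy.

```python
import math
import random

N_HEADS, N_TAILS = 10, 1


def sigmoid(z):
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def log_sigmoid(z):
    return -math.log1p(math.exp(-z)) if z >= 0 else z - math.log1p(math.exp(z))


def log_joint_logit(z):
    # log p(z, D) + const for z = logit(theta), uniform prior, log-Jacobian included.
    return (N_HEADS + 1) * log_sigmoid(z) + (N_TAILS + 1) * log_sigmoid(-z)


def grad_log_joint_logit(z):
    return (N_HEADS + 1) - (N_HEADS + N_TAILS + 2) * sigmoid(z)


def hmc(log_target, grad_log_target, z0=0.0, step=0.2, n_leapfrog=10,
        n_samples=5000, burn_in=500, seed=0):
    rng = random.Random(seed)
    z = z0
    samples = []
    for i in range(n_samples + burn_in):
        p0 = rng.gauss(0.0, 1.0)                  # fresh momentum, unit mass
        z_new, p = z, p0
        # Leapfrog: half momentum step, alternating full steps, final half step.
        p += 0.5 * step * grad_log_target(z_new)
        for step_idx in range(n_leapfrog):
            z_new += step * p
            if step_idx < n_leapfrog - 1:
                p += step * grad_log_target(z_new)
        p += 0.5 * step * grad_log_target(z_new)
        # Metropolis correction on H(z, p) = -log p(z, D) + p**2 / 2.
        h_old = -log_target(z) + 0.5 * p0 ** 2
        h_new = -log_target(z_new) + 0.5 * p ** 2
        if rng.random() < math.exp(min(0.0, h_old - h_new)):
            z = z_new
        if i >= burn_in:
            samples.append(z)
    return samples


samples = hmc(log_joint_logit, grad_log_joint_logit)
thetas = [sigmoid(z) for z in samples]
print(sum(thetas) / len(thetas))   # should be close to the posterior mean 11/13
```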


7.4.6 Sequential Monte Carlo 


MCMC is like a stochastic local search algorithm, in that it makes moves through the state space of the posterior distribution, comparing the current value to proposed neighboring values. An alternative approach is to perform inference using a sequence of different distributions, from simpler to more complex, with the final distribution being equal to the target posterior. This is called sequential Monte Carlo or SMC. This approach, which is more similar to tree search than local search, has various advantages over MCMC, which we discuss in Chapter 13.

A common application of SMC is to sequential Bayesian inference, in which we recursively compute (i.e., in an online fashion) the posterior $p(\theta_t | D_{1:t})$, where $D_{1:t} = \{(x_n, y_n) : n = 1:t\}$ is all the data we have seen so far. This sequence of distributions converges to the full batch posterior p(θ|D) once all the data has been seen. However, the approach can also be used when the data is arriving in a continual, unending stream, as in state-space models (see Chapter 29). The application of SMC to such dynamical models is known as particle filtering. See Section 13.2 for details.
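The simplest instance of sequential Monte Carlo for a static parameter is sequential importance resampling. It draws particles from the prior, reweights them by the likelihood of each new observation, and resamples when the effective sample size drops. The sketch below applies this to the beta-Bernoulli data. It omits the MCMC rejuvenation moves that practical SMC samplers usually add. The particle count, resampling threshold, and seed are assumptions.

```python
import random


def sir_beta_bernoulli(ys, n_particles=5000, seed=0):
    """Sequential importance resampling for theta under a uniform prior."""
    rng = random.Random(seed)
    particles = [rng.random() for _ in range(n_particles)]   # draws from the prior
    weights = [1.0 / n_particles] * n_particles
    for y in ys:                                             # process one observation at a time
        weights = [w * (t if y == 1 else 1.0 - t) for w, t in zip(weights, particles)]
        total = sum(weights)
        weights = [w / total for w in weights]
        ess = 1.0 / sum(w * w for w in weights)              # effective sample size
        if ess < n_particles / 2:                            # resample when weights degenerate
            particles = rng.choices(particles, weights=weights, k=n_particles)
            weights = [1.0 / n_particles] * n_particles
    return particles, weights


particles, weights = sir_beta_bernoulli([1] * 10 + [0])
print(sum(w * t for w, t in zip(weights, particles)))   # close to 11/13
```

After each step the weighted particles approximate $p(\theta | D_{1:t})$, so the same loop also gives the intermediate posteriors of the online setting.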
Figure 7.5: Different approximations to a bimodal 2d distribution. (a) Local MAP estimate. (b) Parametric Gaussian approximation. (c) Correlated samples from near one mode. (d) Independent samples from the distribution. Adapted from Figure 2 of [PY14]. Used with kind permission of George Papandreou.


7.4.7 Challenging posteriors 


In many applications, the posterior can be high dimensional and multimodal. Approximating such distributions can be quite challenging. In Figure 7.5, we give a simple 2d example. We compare MAP estimation (which does not capture any uncertainty), a Gaussian parametric approximation such as the Laplace approximation or variational inference, and a nonparametric approximation in terms of samples. If the samples are generated from MCMC, they are serially correlated, and may only explore a local mode (see panel c). However, ideally we can draw independent samples from the entire support of the distribution, as shown in panel d. We may also be able to fit a local parametric approximation around each such sample (cf. Section 17.3.9.1), to get a semi-parametric approximation to the posterior.


7.5 Evaluating approximate inference algorithms


There are many different inference algorithms, each of which makes different tradeoffs between speed, accuracy, generality, simplicity, etc. This makes it hard to compare them on an equal footing. However, a common approach is to evaluate the accuracy of the approximation as a function of compute time.

We give an example of this in Figure 7.6, where we plot performance for several different inference algorithms applied to the following simple 1d model:

$$ p(z) = \mathcal{N}(z | 0, 100) \tag{7.30} $$
$$ p(x_i | z) = 0.5\, \mathcal{N}(x_i | z, 1) + 0.5\, \mathcal{N}(x_i | 0, 10) \tag{7.31} $$

That is, each measurement $x_i$ is either a noisy copy of the hidden quantity of interest, z, or is an outlier coming from a uniform background noise model, approximated by a Gaussian with a large variance, $\mathcal{N}(0, 10)$. The goal is to compute $p(z | x_{1:N})$. (Tom Minka, who created this example, calls this the clutter problem.)


Since this is a 1d problem, we can compute the exact answer using numerical integration. In Figure 7.6, we plot the accuracy of the posterior mean estimate using various approximate inference methods, namely: the Laplace approximation (Section 7.4.3), variational Bayes (Section 10.2.3), expectation propagation (Section 10.7), importance sampling (Section 11.5), and Gibbs sampling (Section 12.3). We see that the error smoothly decreases for the 3 deterministic methods (Laplace,
Figure 7.6: Accuracy vs compute time for different inference methods. Code source: https://github.com/tminka/ep-clutter-example. From Figure 1 of [Min01b]. Used with kind permission of Tom Minka.


VB and EP). For the 2 stochastic methods (IS and Gibbs), the error decreases on average, but the 
performance is quite noisy. 

In principle, the Monte Carlo methods will converge to zero error, since they are consistent estimators, but it is clear that their variance is quite high. The deterministic methods, by comparison, converge to a finite error, so they are biased, but their variance is much lower. We can trade off bias and variance by creating hybrid algorithms that combine different techniques. We will see examples of this in later chapters.
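The sketch below reproduces the two ends of this comparison on the clutter problem, in plain Python. It assumes that the second argument of $\mathcal{N}$ in Equations (7.30) and (7.31) is a variance, and uses a hypothetical synthetic dataset with z = 2 and N = 20. The "exact" posterior mean comes from a dense Riemann sum over z, which is feasible because the problem is 1d. The importance sampling estimate uses the prior as the proposal, so its error is random and shrinks only as the number of samples grows.

```python
import math
import random


def log_normal(x, mu, var):
    return -0.5 * (math.log(2 * math.pi * var) + (x - mu) ** 2 / var)


def log_lik(z, xs):
    """sum_i log[0.5 N(x_i | z, 1) + 0.5 N(x_i | 0, 10)], Eq. (7.31)."""
    total = 0.0
    for x in xs:
        a = math.log(0.5) + log_normal(x, z, 1.0)
        b = math.log(0.5) + log_normal(x, 0.0, 10.0)
        m = max(a, b)
        total += m + math.log(math.exp(a - m) + math.exp(b - m))
    return total


def exact_posterior_mean(xs, lo=-50.0, hi=50.0, M=10001):
    """Posterior mean by a Riemann sum on a dense grid over z."""
    h = (hi - lo) / (M - 1)
    zs = [lo + k * h for k in range(M)]
    logs = [log_normal(z, 0.0, 100.0) + log_lik(z, xs) for z in zs]
    m = max(logs)
    ws = [math.exp(v - m) for v in logs]
    return sum(z * w for z, w in zip(zs, ws)) / sum(ws)


def is_posterior_mean(xs, n_samples, seed=0):
    """Self-normalized importance sampling with the prior N(0, 100) as proposal."""
    rng = random.Random(seed)
    zs = [rng.gauss(0.0, 10.0) for _ in range(n_samples)]
    logw = [log_lik(z, xs) for z in zs]
    m = max(logw)
    ws = [math.exp(v - m) for v in logw]
    return sum(z * w for z, w in zip(zs, ws)) / sum(ws)


def sample_clutter_data(z_true, n, seed=0):
    rng = random.Random(seed)
    return [rng.gauss(z_true, 1.0) if rng.random() < 0.5 else rng.gauss(0.0, math.sqrt(10.0))
            for _ in range(n)]


xs = sample_clutter_data(2.0, 20)
exact = exact_posterior_mean(xs)
for n in (100, 1000, 20000):
    print(n, abs(is_posterior_mean(xs, n) - exact))   # error is noisy but shrinks with n
```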

When we cannot compute the true target posterior distribution to compare to, evaluation is harder, but there are some approaches that can be used in certain cases. For some methods for assessing variational inference, see e.g., [Yao+18b; Hug+20]. For some methods for assessing Monte Carlo methods, see [CGR06; CTM17; GAR16]. For assessing posterior predictive distributions, see Section 14.2.3.
8 Inference for state-space models


8.1 Introduction 


In this chapter, we consider the task of posterior inference in state-space models (SSMs). We discuss SSMs in more detail in Chapter 29, but we can think of them as latent variable sequence models with the conditional independencies shown by the chain-structured graphical model in Figure 8.1. The corresponding joint distribution has the form

$$ p(y_{1:T}, z_{1:T} | u_{1:T}) = \left[p(z_1 | u_1) \prod_{t=2}^T p(z_t | z_{t-1}, u_t)\right] \left[\prod_{t=1}^T p(y_t | z_t, u_t)\right] \tag{8.1} $$

where $z_t$ are the hidden variables at time t, $y_t$ are the observations (outputs), and $u_t$ are the optional inputs.

Given the sequence of observations, and a known model, one of the main tasks with SSMs is to perform posterior inference about the hidden states; this is also called state estimation. At each time step t, there are multiple forms of posterior we may be interested in computing, including the filtering distribution $p(z_t | y_{1:t})$, the smoothing distribution $p(z_t | y_{1:T})$ (note that this conditions on future data T > t), and the fixed-lag smoothing distribution $p(z_{t-\ell} | y_{1:t})$ (note that this infers ℓ steps in the past given data up to the present). We may also want to compute the predictive distribution h steps into the future:

$$ p(y_{t+h} | y_{1:t}) = \sum_{z_{t+h}} p(y_{t+h} | z_{t+h})\, p(z_{t+h} | y_{1:t}) \tag{8.2} $$

where the hidden state predictive distribution is

$$ p(z_{t+h} | y_{1:t}) = \sum_{z_{t:t+h-1}} p(z_t | y_{1:t})\, p(z_{t+1} | z_t)\, p(z_{t+2} | z_{t+1}) \cdots p(z_{t+h} | z_{t+h-1}) \tag{8.3} $$

See Figure 8.2 for a summary of these distributions. In addition to computing posterior marginals, we may want to compute the most probable hidden sequence, $\arg\max_{z_{1:T}} p(z_{1:T} | y_{1:T})$, or sample sequences from the posterior, $z_{1:T} \sim p(z_{1:T} | y_{1:T})$. We discuss algorithms for all these tasks later in this chapter. See also [Sar13; Tri21].
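For a discrete HMM, the sum over $z_{t:t+h-1}$ in Equation (8.3) does not need to be enumerated. Propagating the filtered belief through the transition matrix h times gives the same result in O(hK²) operations. The following minimal sketch assumes a row-stochastic transition matrix A[j][k] = p(z_{t+1} = k | z_t = j) and an emission matrix B[k][v]. The transition numbers are hypothetical. The emission rows are the casino probabilities of Equations (8.4) and (8.5).

```python
# Hypothetical 2-state transition matrix: A[j][k] = p(z_{t+1} = k | z_t = j).
A = [[0.95, 0.05],
     [0.10, 0.90]]
# Casino emission probabilities: B[k][v] = p(y = v + 1 | z = k + 1).
B = [[1 / 6] * 6,
     [1 / 10] * 5 + [5 / 10]]


def predict_hidden(belief, A, h):
    """p(z_{t+h} | y_{1:t}) from the filtered belief p(z_t | y_{1:t})."""
    K = len(belief)
    for _ in range(h):
        belief = [sum(belief[j] * A[j][k] for j in range(K)) for k in range(K)]
    return belief


def predict_obs(belief, A, B, h):
    """p(y_{t+h} | y_{1:t}) = sum_z p(y_{t+h} | z) p(z_{t+h} = z | y_{1:t}), Eq. (8.2)."""
    pz = predict_hidden(belief, A, h)
    return [sum(pz[k] * B[k][v] for k in range(len(pz))) for v in range(len(B[0]))]


print(predict_hidden([1.0, 0.0], A, 200))   # approaches the stationary distribution [2/3, 1/3]
print(predict_obs([0.5, 0.5], A, B, 3))
```

As h grows, the prediction forgets the current belief and converges to the stationary distribution of A, so long-horizon forecasts carry little information about the current regime.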


8.2 Inference based on the HMM filter 


In this section, we consider inference for SSMs where all the hidden variables are discrete. This 
is known as a hidden Markov model or HMM. This is discussed in detail in Section 29.2, but 
Figure 8.1: A state-space model represented as a graphical model. $z_t$ are the hidden variables at time t, $y_t$ are the observations (outputs), and $u_t$ are the optional inputs.


[Figure 8.2 panels: filtering; prediction (horizon h); fixed-lag smoothing; fixed-interval smoothing.]


Figure 8.2: The main kinds of inference for state-space models. The shaded region is the interval for which we have data. The arrow represents the time step at which we want to perform inference. t is the current time, T is the sequence length, ℓ is the lag and h is the prediction horizon. Used with kind permission of Peter Chang.


in brief, it corresponds to the model in Equation (8.1) where $p(z_1 = k) = \pi_k$ is the initial state distribution, $p(z_t = k | z_{t-1} = j) = A_{jk}$ is the state transition matrix (assumed stationary), and $p(y_t | z_t = k)$ is some conditional distribution over observations. For notational simplicity, we assume all conditional distributions are stationary (the same over time), and we ignore inputs. Note, however, that the algorithms we discuss can be generalized to the non-stationary case, and also to 1d undirected graphical models, such as linear-chain CRFs (Section 4.4).


8.2.1 Example: casino HMM


Before explaining the algorithms, we introduce a simple example of an HMM, which we will use to illustrate the methods.


Suppose we are in a casino and observe a series of die rolls, $y_t \in \{1, 2, \ldots, 6\}$. Being a keen-eyed statistician, we notice that the distribution of values is not what we expect from a fair die: it seems that there are occasional “streaks”, in which 6s seem to show up more often than other values. (This
Figure 8.3: The state transition matrix A and observation matrix B for the casino HMM. Adapted from [Dur+98, p54].


example is from [Dur+98], who call this setup the occasionally dishonest casino.) We would like to segment the time series into regimes corresponding to the use of a fair die and a loaded die. Let $A_{jk} = p(z_t = k | z_{t-1} = j)$ be the state transition matrix, and $B_{kl} = p(y_t = l | z_t = k)$ be the observation matrix corresponding to a categorical distribution over values of the die face. Most of the time the casino uses a fair die, z = 1, but occasionally it switches to a loaded die, z = 2, for a short period. If z = 1 the observation distribution is a uniform categorical distribution over the symbols {1, ..., 6}. If z = 2, the observation distribution is skewed towards face 6. That is,


$$ p(y_t | z_t = 1) = \mathrm{Cat}(y_t | [1/6, \ldots, 1/6]) \tag{8.4} $$
$$ p(y_t | z_t = 2) = \mathrm{Cat}(y_t | [1/10, 1/10, 1/10, 1/10, 1/10, 5/10]) \tag{8.5} $$


See Figure 8.3 for an illustration of the state transition diagram A and the emission distributions B. 
If we sample from this model, we may generate data such as the following: 


hid: 1111111111222211111111111111111111112222222221222211111111111111111111 
obs: 1355534526553366316351551526232112113462221263264265422344645323242361 


Here obs refers to the observation and hid refers to the hidden state (1 is fair and 2 is loaded). In 
the full sequence of length 300, we find the empirical fraction of times that we observe a 6 in hidden 
state 1 to be 0.149, and in state 2 to be 0.472, which are very close to the expected fractions. (See 
casino_hmm.ipynb for the code.) 
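Here is a minimal sampler for the casino HMM. The emission probabilities follow Equations (8.4) and (8.5). The transition probabilities and the uniform initial distribution are assumptions for illustration; the values actually used are those shown in Figure 8.3 and casino_hmm.ipynb. With a long sequence, the empirical fraction of 6s approaches 1/6 in the fair state and 1/2 in the loaded state.

```python
import random

A = [[0.95, 0.05], [0.10, 0.90]]        # hypothetical transition probabilities
B = [[1 / 6] * 6, [1 / 10] * 5 + [5 / 10]]   # emissions from Eqs. (8.4)-(8.5)
PI = [0.5, 0.5]                          # hypothetical initial distribution


def sample_casino(T, seed=0):
    """Sample hidden states (1 = fair, 2 = loaded) and die faces (1..6)."""
    rng = random.Random(seed)
    hid, obs = [], []
    z = rng.choices([0, 1], weights=PI)[0]
    for t in range(T):
        if t > 0:
            z = rng.choices([0, 1], weights=A[z])[0]    # Markov transition
        y = rng.choices(range(1, 7), weights=B[z])[0]   # emit a die face
        hid.append(z + 1)
        obs.append(y)
    return hid, obs


hid, obs = sample_casino(72)
print("hid:", "".join(map(str, hid)))
print("obs:", "".join(map(str, obs)))
```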

Of course, when using an HMM in practice, we just see the observed sequence, and need to infer 
the hidden sequence, as we discuss below. 


8.2.2 Forwards filtering 


The Bayes filter is an algorithm for recursively computing the belief state $p(z_t | y_{1:t})$ given the prior belief from the previous step, $p(z_{t-1} | y_{1:t-1})$, the new observation $y_t$, and the model. This can be done using sequential Bayesian updating. For a dynamical model, this reduces to the predict-update cycle described below.
Figure 8.4: Inference in the dishonest casino. Vertical gray bars denote times when the hidden state
corresponded to the loaded die. Blue lines represent the posterior probability of being in that state given 
different subsets of observed data. If we recover the true state exactly, the blue curve will transition at the 
same time as the gray bars. (a) Filtered estimates. (b) Smoothed estimates. (c) MAP trajectory. Generated 
by casino_hmm.ipynb. 


8.2.2.1 Prediction step 


The prediction step is just the Chapman-Kolmogorov equation: 


$$p(z_t \mid y_{1:t-1}) = \int p(z_t \mid z_{t-1}) \, p(z_{t-1} \mid y_{1:t-1}) \, dz_{t-1} \tag{8.6}$$


The prediction step computes the one-step-ahead predictive distribution for the latent state, which converts the posterior from the previous time step into the prior for the current step.¹

8.2.2.2 Update step

The update step is just Bayes rule:

$$p(z_t \mid y_{1:t}) = \frac{1}{Z_t} \, p(y_t \mid z_t) \, p(z_t \mid y_{1:t-1}) \tag{8.7}$$

where the normalization constant is

$$Z_t = \int p(y_t \mid z_t) \, p(z_t \mid y_{1:t-1}) \, dz_t = p(y_t \mid y_{1:t-1}) \tag{8.8}$$

We can use the normalization constants to compute the log likelihood of the sequence as follows:

$$\log p(y_{1:T}) = \sum_{t=1}^{T} \log p(y_t \mid y_{1:t-1}) = \sum_{t=1}^{T} \log Z_t \tag{8.9}$$

where we define $p(y_1 \mid y_0) = p(y_1)$. This is useful for computing the MLE of the parameters (see Section 29.4.2).


1. The prediction step is not needed at $t = 1$ if $p(z_1)$ is provided as input to the model. However, if we just provide $p(z_0)$, we need to compute $p(z_1 \mid y_{1:0}) = p(z_1)$ by applying the prediction step.
8.2. INFERENCE BASED ON THE HMM FILTER


8.2.2.3 Implementation 


When the latent states $z_t$ are discrete, the above integrals become sums that can be implemented as follows. First we define the belief state as $\alpha_t(j) = p(z_t = j \mid y_{1:t})$, the local evidence as $\lambda_t(j) = p(y_t \mid z_t = j)$, and the transition matrix as $A_{i,j} = p(z_t = j \mid z_{t-1} = i)$. Then the predict step becomes

$$\alpha_{t|t-1}(j) \triangleq p(z_t = j \mid y_{1:t-1}) = \sum_i \alpha_{t-1}(i) A_{i,j} \tag{8.10}$$

and the update step becomes

$$\alpha_t(j) = \frac{1}{Z_t} \lambda_t(j) \, \alpha_{t|t-1}(j) = \frac{1}{Z_t} \lambda_t(j) \sum_i \alpha_{t-1}(i) A_{i,j} \tag{8.11}$$

where the normalization constant for each time step is given by

$$Z_t \triangleq p(y_t \mid y_{1:t-1}) = \sum_{j=1}^{K} p(y_t \mid z_t = j) \, p(z_t = j \mid y_{1:t-1}) = \sum_{j=1}^{K} \lambda_t(j) \, \alpha_{t|t-1}(j) \tag{8.12}$$

We can write the update equation in matrix-vector notation as follows:

$$\alpha_t = \mathrm{normalize}\left(\lambda_t \odot (A^\top \alpha_{t-1})\right) \tag{8.13}$$

where $\odot$ represents elementwise vector multiplication, and the normalize function just ensures its argument sums to one. (See Section 8.2.5 for more discussion on normalization.)
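Equations (8.10) to (8.12) can be sketched as a normalized forwards filter in NumPy, with the initial distribution pi playing the role of $p(z_1)$. The function names hmm_filter and brute_force_log_lik are hypothetical. The brute-force enumeration over all $K^T$ paths is included only to check the recursion on a tiny random model.

```python
import itertools
import numpy as np

def hmm_filter(pi, A, lik):
    """Normalized forwards filtering for a discrete HMM.

    pi:  (K,) initial distribution p(z_1)
    A:   (K, K) transition matrix, A[i, j] = p(z_t = j | z_{t-1} = i)
    lik: (T, K) local evidence, lik[t, j] = p(y_t | z_t = j)
    Returns alphas (T, K) with alphas[t] = p(z_t | y_{1:t}), and log p(y_{1:T}).
    """
    T, K = lik.shape
    alphas = np.zeros((T, K))
    log_lik = 0.0
    pred = pi                             # p(z_1 | y_{1:0}) = p(z_1)
    for t in range(T):
        unnorm = lik[t] * pred            # lambda_t(j) * alpha_{t|t-1}(j)
        Z = unnorm.sum()                  # Z_t = p(y_t | y_{1:t-1}), Eq. (8.12)
        alphas[t] = unnorm / Z            # Eq. (8.11)
        log_lik += np.log(Z)              # Eq. (8.9)
        pred = A.T @ alphas[t]            # alpha_{t+1|t}, Eq. (8.10)
    return alphas, log_lik

def brute_force_log_lik(pi, A, lik):
    """log p(y_{1:T}) by summing p(z_{1:T}, y_{1:T}) over all K**T paths (tiny T only)."""
    T, K = lik.shape
    total = 0.0
    for path in itertools.product(range(K), repeat=T):
        p = pi[path[0]] * lik[0, path[0]]
        for t in range(1, T):
            p *= A[path[t - 1], path[t]] * lik[t, path[t]]
        total += p
    return np.log(total)

rng = np.random.default_rng(0)
K, T = 3, 5
pi = rng.dirichlet(np.ones(K))
A = rng.dirichlet(np.ones(K), size=K)      # each row is a distribution
lik = rng.uniform(0.1, 1.0, size=(T, K))   # arbitrary positive local evidence
alphas, ll = hmm_filter(pi, A, lik)
print(ll, brute_force_log_lik(pi, A, lik))  # the two values should agree
```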


8.2.2.4 Example 


Figure 8.4(a) illustrates filtering for the casino HMM, applied to a random sequence $y_{1:T}$ of length
T = 300. In blue, we plot the probability that the die is in the loaded (vs fair) state, based on 
the evidence seen so far. The gray bars indicate time intervals during which the generative process 
actually switched to the loaded die. We see that the probability generally increases in the right 
places. 


8.2.3 Backwards smoothing 


In the offline setting, we want to compute $p(z_t \mid y_{1:T})$, which is the belief about the hidden state at
time t given all the data, both past and future. This is called (fixed interval) smoothing. We first 
perform the forwards or filtering pass, and then compute the smoothed belief states by working 
backwards, from right (time t = T) to left (t = 1), as we explain below. Hence this method is also 
called forwards filtering backwards smoothing or FFBS. 


8.2.3.1 Backwards recursion 


Suppose, by induction, that we have already computed $p(z_{t+1} \mid y_{1:T})$. We can convert this into a joint smoothed distribution over two consecutive time steps using

$$p(z_t, z_{t+1} \mid y_{1:T}) = p(z_t \mid z_{t+1}, y_{1:T}) \, p(z_{t+1} \mid y_{1:T}) \tag{8.14}$$
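To derive the first term, note that from the Markov properties of the model, and Bayes rule, we have

\begin{align}
p(z_t = i \mid z_{t+1} = j, y_{1:T}) &= p(z_t = i \mid z_{t+1} = j, y_{1:t}) \tag{8.15} \\
&= \frac{p(z_t = i, z_{t+1} = j \mid y_{1:t})}{p(z_{t+1} = j \mid y_{1:t})} \tag{8.16} \\
&= \frac{p(z_{t+1} = j \mid z_t = i) \, p(z_t = i \mid y_{1:t})}{p(z_{t+1} = j \mid y_{1:t})} \tag{8.17}
\end{align}

where the first step drops $y_{t+1:T}$, which is conditionally independent of $z_t$ given $z_{t+1}$.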


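Thus the joint distribution over two consecutive time steps is given by

$$p(z_t, z_{t+1} \mid y_{1:T}) = p(z_t \mid z_{t+1}, y_{1:t}) \, p(z_{t+1} \mid y_{1:T}) = \frac{p(z_{t+1} \mid z_t) \, p(z_t \mid y_{1:t}) \, p(z_{t+1} \mid y_{1:T})}{p(z_{t+1} \mid y_{1:t})} \tag{8.18}$$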


from which we get the new smoothed marginal distribution: 


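\begin{align}
p(z_t \mid y_{1:T}) &= p(z_t \mid y_{1:t}) \int \left[ \frac{p(z_{t+1} \mid z_t) \, p(z_{t+1} \mid y_{1:T})}{p(z_{t+1} \mid y_{1:t})} \right] dz_{t+1} \tag{8.19} \\
&= \int p(z_t, z_{t+1} \mid y_{1:t}) \, \frac{p(z_{t+1} \mid y_{1:T})}{p(z_{t+1} \mid y_{1:t})} \, dz_{t+1} \tag{8.20}
\end{align}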


Intuitively we can interpret this as follows: we start with the two-slice filtered distribution, $p(z_t, z_{t+1} \mid y_{1:t})$, then we divide out the old $p(z_{t+1} \mid y_{1:t})$, multiply in the new $p(z_{t+1} \mid y_{1:T})$, and marginalize out $z_{t+1}$.


8.2.3.2 Implementation 


Let us define the two-slice marginal using the following notation:

$$\xi_{t,t+1}(i,j) = p(z_t = i, z_{t+1} = j \mid y_{1:T}) \tag{8.21}$$

We can then rewrite Equation (8.18) as follows:

$$\xi_{t,t+1}(i,j) = \frac{\alpha_t(i) \, A_{i,j} \, \gamma_{t+1}(j)}{\alpha_{t+1|t}(j)} \tag{8.22}$$

where

$$\alpha_{t+1|t}(j) = p(z_{t+1} = j \mid y_{1:t}) = \sum_{i'} A_{i',j} \, \alpha_t(i') \tag{8.23}$$

is the one-step-ahead predictive distribution. We can interpret the ratio in Equation (8.22) as dividing out the old estimate of $z_{t+1}$ given $y_{1:t}$, namely $\alpha_{t+1|t}$, and multiplying in the new estimate given $y_{1:T}$, namely $\gamma_{t+1}$.

Finally, we can recover the one-slice posterior marginal by marginalizing:

$$\gamma_t(i) \triangleq p(z_t = i \mid y_{1:T}) = \sum_j \xi_{t,t+1}(i,j) = \alpha_t(i) \sum_j \frac{A_{i,j} \, \gamma_{t+1}(j)}{\alpha_{t+1|t}(j)} \tag{8.24}$$

We initialize the recursion using $\gamma_T(j) = \alpha_T(j) = p(z_T = j \mid y_{1:T})$.
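The backwards recursion of Equations (8.22) to (8.24) can be sketched as follows, on top of a normalized forwards pass. This is an illustrative NumPy version. It assumes every predicted probability $\alpha_{t+1|t}(j)$ is strictly positive (otherwise the ratio is 0/0). The names hmm_smooth and brute_force_marginals are hypothetical.

```python
import itertools
import numpy as np

def hmm_smooth(pi, A, lik):
    """Forwards filtering, backwards smoothing (Eqs. 8.22-8.24); returns gammas (T, K)."""
    T, K = lik.shape
    alphas = np.zeros((T, K))
    pred = pi
    for t in range(T):                              # forwards (filtering) pass
        unnorm = lik[t] * pred
        alphas[t] = unnorm / unnorm.sum()
        pred = A.T @ alphas[t]
    gammas = np.zeros((T, K))
    gammas[-1] = alphas[-1]                         # gamma_T = alpha_T
    for t in range(T - 2, -1, -1):                  # backwards (smoothing) pass
        pred = A.T @ alphas[t]                      # alpha_{t+1|t}, Eq. (8.23)
        ratio = gammas[t + 1] / pred                # new estimate / old estimate of z_{t+1}
        gammas[t] = alphas[t] * (A @ ratio)         # Eq. (8.24)
    return gammas

def brute_force_marginals(pi, A, lik):
    """p(z_t = k | y_{1:T}) by enumerating all K**T paths (tiny T only)."""
    T, K = lik.shape
    gammas = np.zeros((T, K))
    for path in itertools.product(range(K), repeat=T):
        p = pi[path[0]] * lik[0, path[0]]
        for t in range(1, T):
            p *= A[path[t - 1], path[t]] * lik[t, path[t]]
        for t, k in enumerate(path):
            gammas[t, k] += p
    return gammas / gammas.sum(axis=1, keepdims=True)

rng = np.random.default_rng(1)
K, T = 3, 5
pi = rng.dirichlet(np.ones(K))
A = rng.dirichlet(np.ones(K), size=K)
lik = rng.uniform(0.1, 1.0, size=(T, K))
gammas = hmm_smooth(pi, A, lik)
print(np.max(np.abs(gammas - brute_force_marginals(pi, A, lik))))  # round-off only
```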


8.2.3.3 Example 


In Figure 8.4(a-b), we compare filtering and smoothing for the casino HMM. We see that the posterior 
distributions when conditioned on all the data (past and future) are indeed smoother than when just 
conditioned on the past (filtering). 

To understand this behavior intuitively, consider a detective trying to figure out who committed a crime. As they move through the crime scene, their uncertainty is high until they find the key clue; then they have an “aha” moment, their uncertainty is reduced, and all the previously confusing observations are, in hindsight, easy to explain. Thus we see that, given all the data (including finding the clue), it is much easier to infer the state of the world.


8.2.4 The forwards-backwards algorithm 


In this section, we present a more common approach to smoothing in HMMs known as the forwards-backwards or FB algorithm [Rab89]. In the forwards pass, we compute $\alpha_t(j) = p(z_t = j \mid y_{1:t})$ as before. In the backwards pass, we compute the conditional likelihood

$$\beta_t(j) \triangleq p(y_{t+1:T} \mid z_t = j) \tag{8.25}$$

We then combine these using

\begin{align}
\gamma_t(j) &= p(z_t = j \mid y_{t+1:T}, y_{1:t}) \propto p(z_t = j, y_{t+1:T} \mid y_{1:t}) \tag{8.26} \\
&= p(z_t = j \mid y_{1:t}) \, p(y_{t+1:T} \mid z_t = j) = \alpha_t(j) \, \beta_t(j) \tag{8.27}
\end{align}

In matrix notation, this becomes

$$\gamma_t = \mathrm{normalize}(\alpha_t \odot \beta_t) \tag{8.28}$$


Note that the forwards and backwards passes can be computed independently, but both need access
to the local evidence $p(y_t \mid z_t)$. The results are only combined at the end. This is therefore called
two-filter smoothing [Kit04]. 


8.2.4.1 Backwards recursion 


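We can recursively compute the $\beta$'s in a right-to-left fashion as follows:

\begin{align}
\beta_{t-1}(i) &= p(y_{t:T} \mid z_{t-1} = i) \tag{8.29} \\
&= \sum_j p(z_t = j, y_t, y_{t+1:T} \mid z_{t-1} = i) \tag{8.30} \\
&= \sum_j p(y_{t+1:T} \mid z_t = j) \, p(z_t = j, y_t \mid z_{t-1} = i) \tag{8.31} \\
&= \sum_j p(y_{t+1:T} \mid z_t = j) \, p(y_t \mid z_t = j) \, p(z_t = j \mid z_{t-1} = i) \tag{8.32} \\
&= \sum_j \beta_t(j) \, \lambda_t(j) \, A_{i,j} \tag{8.33}
\end{align}

In Equations (8.31) and (8.32) we used the conditional independence properties of the HMM to drop $y_t$ and $z_{t-1}$ from the conditioning sets.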
We can write the resulting equation in matrix-vector form as

$$\beta_{t-1} = A (\lambda_t \odot \beta_t) \tag{8.34}$$

The base case is

$$\beta_T(i) = p(y_{T+1:T} \mid z_T = i) = p(\emptyset \mid z_T = i) = 1 \tag{8.35}$$

which is the probability of a non-event.

Note that $\beta_t$ is not a probability distribution over states, since it does not need to satisfy $\sum_j \beta_t(j) = 1$. However, we usually normalize it to avoid numerical underflow (see Section 8.2.5).
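Here is a minimal sketch of the forwards-backwards algorithm. The beta messages follow Equation (8.34) with base case (8.35). Each $\beta_t$ is rescaled to sum to one, which leaves the normalized $\gamma_t$ unchanged. The two passes are then combined with Equation (8.28). The check uses a random toy HMM rather than the casino model, and the function names are hypothetical.

```python
import itertools
import numpy as np

def forwards_backwards(pi, A, lik):
    """Returns (alphas, betas, gammas), each of shape (T, K)."""
    T, K = lik.shape
    alphas = np.zeros((T, K))
    pred = pi
    for t in range(T):                              # forwards pass
        unnorm = lik[t] * pred
        alphas[t] = unnorm / unnorm.sum()
        pred = A.T @ alphas[t]
    betas = np.ones((T, K))                         # beta_T(i) = 1, Eq. (8.35)
    for t in range(T - 1, 0, -1):                   # backwards pass
        b = A @ (lik[t] * betas[t])                 # beta_{t-1} = A (lambda_t * beta_t), Eq. (8.34)
        betas[t - 1] = b / b.sum()                  # rescale only to avoid underflow
    gammas = alphas * betas                         # Eq. (8.28) before normalizing
    gammas /= gammas.sum(axis=1, keepdims=True)
    return alphas, betas, gammas

def brute_force_marginals(pi, A, lik):
    """p(z_t = k | y_{1:T}) by enumerating all K**T paths (tiny T only)."""
    T, K = lik.shape
    gammas = np.zeros((T, K))
    for path in itertools.product(range(K), repeat=T):
        p = pi[path[0]] * lik[0, path[0]]
        for t in range(1, T):
            p *= A[path[t - 1], path[t]] * lik[t, path[t]]
        for t, k in enumerate(path):
            gammas[t, k] += p
    return gammas / gammas.sum(axis=1, keepdims=True)

rng = np.random.default_rng(2)
K, T = 3, 6
pi = rng.dirichlet(np.ones(K))
A = rng.dirichlet(np.ones(K), size=K)
lik = rng.uniform(0.1, 1.0, size=(T, K))
_, betas, gammas = forwards_backwards(pi, A, lik)
print(np.max(np.abs(gammas - brute_force_marginals(pi, A, lik))))  # round-off only
```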


8.2.4.2 Two-slice smoothed marginals 


We can compute the two-slice marginals using the output of the forwards-backwards algorithm as 
follows: 


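\begin{align}
p(z_t, z_{t+1} \mid y_{1:T}) &= p(z_t, z_{t+1} \mid y_{1:t}, y_{t+1:T}) \tag{8.36} \\
&\propto p(y_{t+1:T} \mid z_t, z_{t+1}, y_{1:t}) \, p(z_t, z_{t+1} \mid y_{1:t}) \tag{8.37} \\
&= p(y_{t+1:T} \mid z_{t+1}) \, p(z_t, z_{t+1} \mid y_{1:t}) \tag{8.38} \\
&= p(y_{t+1:T} \mid z_{t+1}) \, p(z_t \mid y_{1:t}) \, p(z_{t+1} \mid z_t) \tag{8.39} \\
&= p(y_{t+1}, y_{t+2:T} \mid z_{t+1}) \, p(z_t \mid y_{1:t}) \, p(z_{t+1} \mid z_t) \tag{8.40} \\
&= p(y_{t+1} \mid z_{t+1}) \, p(y_{t+2:T} \mid z_{t+1}, y_{t+1}) \, p(z_t \mid y_{1:t}) \, p(z_{t+1} \mid z_t) \tag{8.41} \\
&= p(y_{t+1} \mid z_{t+1}) \, p(y_{t+2:T} \mid z_{t+1}) \, p(z_t \mid y_{1:t}) \, p(z_{t+1} \mid z_t) \tag{8.42}
\end{align}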


We can rewrite this in terms of the already computed quantities as follows: 


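$$\xi_{t,t+1}(i,j) \propto \alpha_t(i) \, \lambda_{t+1}(j) \, \beta_{t+1}(j) \, A_{i,j} \tag{8.43}$$

Or in matrix-vector form:

$$\xi_{t,t+1} \propto A \odot \left[ \alpha_t (\lambda_{t+1} \odot \beta_{t+1})^\top \right] \tag{8.44}$$

Since $\alpha_t \propto \lambda_t \odot \alpha_{t|t-1}$, we can also write the above equation as follows:

$$\xi_{t,t+1} \propto A \odot \left[ (\lambda_t \odot \alpha_{t|t-1}) (\lambda_{t+1} \odot \beta_{t+1})^\top \right] \tag{8.45}$$

This can be interpreted as a product of incoming messages and local factors, as shown in Figure 8.5. In particular, we combine the factors $\alpha_{t|t-1} = p(z_t \mid y_{1:t-1})$, $A = p(z_{t+1} \mid z_t)$, $\lambda_t \propto p(y_t \mid z_t)$, $\lambda_{t+1} \propto p(y_{t+1} \mid z_{t+1})$, and $\beta_{t+1} \propto p(y_{t+2:T} \mid z_{t+1})$ to get $p(z_t, z_{t+1}, y_t, y_{t+1}, y_{t+2:T} \mid y_{1:t-1})$, which we can then normalize.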
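Given the alpha and beta messages, the two-slice marginals of Equation (8.44) are an outer product masked elementwise by $A$. The sketch below recomputes both passes so that it is self-contained. It checks the result against an enumeration of all paths on a small random HMM. The function names are hypothetical.

```python
import itertools
import numpy as np

def two_slice_marginals(pi, A, lik):
    """xi[t, i, j] = p(z_t = i, z_{t+1} = j | y_{1:T}), shape (T - 1, K, K)."""
    T, K = lik.shape
    alphas = np.zeros((T, K))
    pred = pi
    for t in range(T):                                   # forwards pass
        unnorm = lik[t] * pred
        alphas[t] = unnorm / unnorm.sum()
        pred = A.T @ alphas[t]
    betas = np.ones((T, K))
    for t in range(T - 1, 0, -1):                        # backwards pass, Eq. (8.34)
        b = A @ (lik[t] * betas[t])
        betas[t - 1] = b / b.sum()
    xi = np.zeros((T - 1, K, K))
    for t in range(T - 1):
        # Eq. (8.44): A * outer(alpha_t, lambda_{t+1} * beta_{t+1}), then normalize
        m = A * np.outer(alphas[t], lik[t + 1] * betas[t + 1])
        xi[t] = m / m.sum()
    return xi

def brute_force_pairs(pi, A, lik):
    """Pairwise posteriors by enumerating all K**T paths (tiny T only)."""
    T, K = lik.shape
    xi = np.zeros((T - 1, K, K))
    for path in itertools.product(range(K), repeat=T):
        p = pi[path[0]] * lik[0, path[0]]
        for t in range(1, T):
            p *= A[path[t - 1], path[t]] * lik[t, path[t]]
        for t in range(T - 1):
            xi[t, path[t], path[t + 1]] += p
    return xi / xi.sum(axis=(1, 2), keepdims=True)

rng = np.random.default_rng(3)
K, T = 3, 5
pi = rng.dirichlet(np.ones(K))
A = rng.dirichlet(np.ones(K), size=K)
lik = rng.uniform(0.1, 1.0, size=(T, K))
xi = two_slice_marginals(pi, A, lik)
print(np.max(np.abs(xi - brute_force_pairs(pi, A, lik))))  # round-off only
```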


8.2.5 Numerically stable implementation

In most publications on HMMs, such as [Rab89], the forwards message is defined as the following unnormalized joint probability:

$$\alpha'_t(j) = p(z_t = j, y_{1:t}) = \lambda_t(j) \left[ \sum_i \alpha'_{t-1}(i) A_{i,j} \right] \tag{8.46}$$


Figure 8.5: Computing the two-slice joint distribution for an HMM from the forwards messages, backwards 
messages, and local evidence messages. 


We instead define the forwards message as the normalized conditional probability 


$$\alpha_t(j) = p(z_t = j \mid y_{1:t}) = \frac{1}{Z_t} \lambda_t(j) \left[ \sum_i \alpha_{t-1}(i) A_{i,j} \right] \tag{8.47}$$


The unnormalized (joint) form has several problems. First, it rapidly suffers from numerical underflow, since the probability of the joint event $(z_t = j, y_{1:t})$ is vanishingly small.² Second, it is less interpretable, since it is not a distribution over states. Third, it precludes the use of approximate inference methods that try to approximate posterior distributions (we will see such methods later). We therefore always use the normalized (conditional) form.

Of course, the two definitions only differ by a multiplicative constant, since $p(z_t = j \mid y_{1:t}) = p(z_t = j, y_{1:t}) / p(y_{1:t})$ [Dev85]. So the algorithmic difference is just one line of code (namely the presence or absence of a call to the normalize function). Nevertheless, we feel it is better to present the normalized version, since it will encourage readers to implement the method properly (i.e., normalizing after each step to avoid underflow).

In practice it is more numerically stable to compute the log probabilities $\log \lambda_t(j) = \log p(y_t \mid z_t = j)$ of the evidence, rather than the probabilities $\lambda_t(j) = p(y_t \mid z_t = j)$. We can combine the state-conditional log likelihoods $\log \lambda_t(j)$ with the state prior $p(z_t = j \mid y_{1:t-1})$ by using the log-sum-exp trick, as in Equation (28.30).
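To illustrate the log-domain variant, the sketch below runs the filter entirely on log probabilities with a hand-written log-sum-exp (scipy.special.logsumexp computes the same quantity). The log evidence values are deliberately so negative that exp() underflows to zero, which would break a probability-domain filter. The model values are arbitrary assumptions, and hmm_filter_log is a hypothetical name.

```python
import numpy as np

def logsumexp(a, axis=None):
    """Numerically stable log(sum(exp(a))) along the given axis."""
    m = np.max(a, axis=axis, keepdims=True)
    out = m + np.log(np.sum(np.exp(a - m), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)

def hmm_filter_log(log_pi, log_A, log_lik):
    """Forwards filtering with log evidence log_lik[t, j] = log p(y_t | z_t = j).

    Returns log alpha_t(j) for all t, and log p(y_{1:T}) = sum_t log Z_t.
    """
    T, K = log_lik.shape
    log_alphas = np.zeros((T, K))
    log_pred = log_pi                                     # log p(z_1)
    total = 0.0
    for t in range(T):
        log_unnorm = log_lik[t] + log_pred                # log(lambda_t * alpha_{t|t-1})
        log_Z = logsumexp(log_unnorm)                     # log Z_t
        log_alphas[t] = log_unnorm - log_Z
        total += float(log_Z)
        # log alpha_{t+1|t}(j) = log sum_i exp(log alpha_t(i) + log A[i, j])
        log_pred = logsumexp(log_alphas[t][:, None] + log_A, axis=0)
    return log_alphas, total

rng = np.random.default_rng(1)
K, T = 2, 2000
log_pi = np.log(np.full(K, 1.0 / K))
log_A = np.log(np.array([[0.95, 0.05], [0.10, 0.90]]))
log_lik = rng.uniform(-800.0, -790.0, size=(T, K))   # np.exp of these values is 0.0
log_alphas, ll = hmm_filter_log(log_pi, log_A, log_lik)
print(np.exp(log_lik).max(), ll)   # the evidence underflows, yet ll is finite
```

Adding a constant to every log evidence value should leave the filtered posteriors unchanged and shift the log-likelihood by that constant times $T$. The test below checks this.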


8.2.6 Time and space complexity 


It is clear that a straightforward implementation of the forwards-backwards algorithm takes $O(K^2 T)$ time, since at each step we must multiply a vector by a $K \times K$ matrix. For some applications, such as speech recognition, $K$ is very large, so the $O(K^2)$ term becomes prohibitive. Fortunately, if the transition matrix is sparse, we can reduce this substantially. For example, with a sparse left-to-right transition matrix (e.g., Figure 8.7(a)), the algorithm takes $O(TK)$ time.

2. For example, if the observations are independent of the states, we have $p(z_t = j, y_{1:t}) = p(z_t = j) \prod_{\tau=1}^{t} p(y_\tau)$, which becomes exponentially small with $t$.

Figure 8.6: The trellis of states vs time for a Markov chain. Adapted from [Rab89].

In some cases, we can exploit special properties of the state space, even if the transition matrix is not sparse. In particular, suppose the states represent a discretization of an underlying continuous state-space, and the transition matrix has the form $A_{i,j} \propto \rho(z_j - z_i)$, where $z_i$ is the continuous vector represented by state $i$ and $\rho(u)$ is some scalar cost function, such as Euclidean distance. Then one can implement the forwards-backwards algorithm in $O(TK \log K)$ time. The key is to rewrite Equation (8.10) as a convolution,

$$\alpha_{t|t-1}(j) = p(z_t = j \mid y_{1:t-1}) = \sum_i \alpha_{t-1}(i) A_{i,j} \propto \sum_i \alpha_{t-1}(i) \, \rho(z_j - z_i) \tag{8.48}$$


and then to apply the Fast Fourier Transform. (A similar transformation can be applied in the 
backwards pass.) This is very useful for models with large state spaces. See [FHK03] for details. 
We can also reduce inference to $O(\log T)$ time by using a parallel prefix scan operator that can be run efficiently on GPUs. For details, see [HSGF21].
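To make the convolution idea in Equation (8.48) concrete, here is a sketch of one prediction step computed with the FFT. For simplicity it assumes a periodic (wrap-around) grid of states, so that $A$ is circulant and every row sums to one, together with a Gaussian-shaped kernel $\rho$. A non-periodic grid would instead need zero padding and some care at the boundaries. This is an illustration only, not the method of [FHK03].

```python
import numpy as np

K = 256
grid = np.arange(K)
# ASSUMED periodic grid: the distance between states i and j wraps around,
# so A[i, j] = rho[(j - i) mod K] is circulant and each row sums to the same value.
offsets = np.minimum(grid, K - grid)            # circular distance of each offset from 0
rho = np.exp(-0.5 * (offsets / 3.0) ** 2)       # ASSUMED Gaussian-shaped kernel
rho /= rho.sum()                                # makes every row of A sum to one
A = np.array([[rho[(j - i) % K] for j in range(K)] for i in range(K)])

def predict_fft(alpha, rho):
    """alpha_{t|t-1}(j) = sum_i alpha_{t-1}(i) rho[(j - i) mod K], in O(K log K) time."""
    return np.real(np.fft.ifft(np.fft.fft(alpha) * np.fft.fft(rho)))

rng = np.random.default_rng(0)
alpha = rng.dirichlet(np.ones(K))
pred_dense = A.T @ alpha                        # O(K^2), Eq. (8.10)
pred_fft = predict_fft(alpha, rho)              # O(K log K), circular convolution
print(np.max(np.abs(pred_dense - pred_fft)))    # round-off only
```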


In some cases, the bottleneck is memory, not time. In particular, to compute the posteriors $\gamma_t$, we must store the filtered distributions $\alpha_t$ for $t = 1, \ldots, T$ until we do the backwards pass. It is possible to devise a simple divide-and-conquer algorithm that reduces the space complexity from $O(KT)$ to $O(K \log T)$ at the cost of increasing the running time from $O(K^2 T)$ to $O(K^2 T \log T)$. The basic idea is to store $\alpha_t$ and $\beta_t$ vectors at a logarithmic number of intermediate checkpoints, and then recompute the missing messages on demand from these checkpoints. See [BMR97; ZP00] for details.


8.2.7 The Viterbi algorithm


The MAP estimate is (one of) the sequences with maximum posterior probability:

\begin{align}
z^*_{1:T} &= \arg\max_{z_{1:T}} p(z_{1:T} \mid y_{1:T}) = \arg\max_{z_{1:T}} \log p(z_{1:T} \mid y_{1:T}) \tag{8.49} \\
&= \arg\max_{z_{1:T}} \left[ \log \pi(z_1) + \log \lambda_1(z_1) + \sum_{t=2}^{T} \left( \log A(z_{t-1}, z_t) + \log \lambda_t(z_t) \right) \right] \tag{8.50}
\end{align}

This is equivalent to computing a shortest path through the trellis diagram in Figure 8.6, where the nodes are possible states at each time step, and the node and edge weights are negative log probabilities. This can be computed in $O(TK^2)$ time using the Viterbi algorithm [Vit67], as we explain below.
8.2.7.1 Forwards pass


Recall the (unnormalized) forwards equation 


$$\alpha'_t(j) = p(z_t = j, y_{1:t}) = \sum_{z_1, \ldots, z_{t-1}} p(z_{1:t-1}, z_t = j, y_{1:t}) \tag{8.51}$$

Now suppose we replace sum with max to get

$$\delta_t(j) \triangleq \max_{z_1, \ldots, z_{t-1}} p(z_{1:t-1}, z_t = j, y_{1:t}) \tag{8.52}$$

This is the maximum probability we can assign to the data so far if we end up in state $j$. The key insight is that the most probable path to state $j$ at time $t$ must consist of the most probable path to some other state $i$ at time $t-1$, followed by a transition from $i$ to $j$. Hence

$$\delta_t(j) = \lambda_t(j) \left[ \max_i \delta_{t-1}(i) A_{i,j} \right] \tag{8.53}$$

We initialize by setting $\delta_1(j) = \pi_j \lambda_1(j)$.

We often work in the log domain to avoid numerical issues. Let $\delta'_t(j) = -\log \delta_t(j)$, $\lambda'_t(j) = -\log p(y_t \mid z_t = j)$, and $A'(i,j) = -\log p(z_t = j \mid z_{t-1} = i)$. Then we have

$$\delta'_t(j) = \lambda'_t(j) + \left[ \min_i \delta'_{t-1}(i) + A'(i,j) \right] \tag{8.54}$$


We also need to keep track of the most likely previous (ancestor) state, for each possible state 
that we end up in: 


$$a_t(j) \triangleq \arg\max_i \delta_{t-1}(i) A_{i,j} = \arg\min_i \left[ \delta'_{t-1}(i) + A'(i,j) \right] \tag{8.55}$$

That is, $a_t(j)$ stores the identity of the previous state on the most probable path to $z_t = j$. We will see why we need this in Section 8.2.7.2.


8.2.7.2 Backwards pass 


In the backwards pass, we compute the most probable sequence of states using a traceback procedure, as follows: $z^*_t = a_{t+1}(z^*_{t+1})$, where we initialize using $z^*_T = \arg\max_i \delta_T(i)$. This is just following the chain of ancestors along the MAP path.

If there is a unique MAP estimate, the above procedure will give the same result as picking $\hat{z}_t = \arg\max_j \gamma_t(j)$, computed by forwards-backwards, as shown in [WF01b]. However, if there are multiple posterior modes, the latter approach may not find any of them, since it chooses each state independently, and hence may break ties in a manner that is inconsistent with its neighbors. The traceback procedure avoids this problem, since once $z_T$ picks its most probable state, the previous nodes condition on this event, and therefore they will break ties consistently.
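A compact log-domain Viterbi sketch follows. It works with $\log \delta_t$ and max, which is Equation (8.54) with the signs flipped. It stores the ancestors $a_t(j)$ of Equation (8.55) and then performs the traceback described above. The brute-force search over all paths is there only to check the result on a tiny random model, and the function names are hypothetical.

```python
import itertools
import numpy as np

def viterbi(log_pi, log_A, log_lik):
    """MAP state sequence of a discrete HMM, computed with log probabilities."""
    T, K = log_lik.shape
    log_delta = log_pi + log_lik[0]                    # log delta_1(j)
    ancestors = np.zeros((T, K), dtype=int)
    for t in range(1, T):
        scores = log_delta[:, None] + log_A            # scores[i, j] = log delta_{t-1}(i) + log A[i, j]
        ancestors[t] = np.argmax(scores, axis=0)       # a_t(j), Eq. (8.55)
        log_delta = log_lik[t] + np.max(scores, axis=0)
    path = np.zeros(T, dtype=int)
    path[-1] = np.argmax(log_delta)                    # z_T^* = argmax_i delta_T(i)
    for t in range(T - 2, -1, -1):
        path[t] = ancestors[t + 1, path[t + 1]]        # z_t^* = a_{t+1}(z_{t+1}^*)
    return path

def brute_force_map(log_pi, log_A, log_lik):
    """Exhaustive search over all K**T paths (tiny T only)."""
    T, K = log_lik.shape
    best_score, best_path = -np.inf, None
    for path in itertools.product(range(K), repeat=T):
        s = log_pi[path[0]] + log_lik[0, path[0]]
        for t in range(1, T):
            s = s + log_A[path[t - 1], path[t]] + log_lik[t, path[t]]
        if s > best_score:
            best_score, best_path = s, np.array(path)
    return best_path

rng = np.random.default_rng(5)
K, T = 3, 6
log_pi = np.log(rng.dirichlet(np.ones(K)))
log_A = np.log(rng.dirichlet(np.ones(K), size=K))
log_lik = np.log(rng.uniform(0.05, 1.0, size=(T, K)))
print(viterbi(log_pi, log_A, log_lik) + 1)   # 1-based state labels
```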


8.2.7.3 Example 


In Figure 8.4(c), we show the Viterbi trace for the casino HMM. We see that, most of the time, the 
estimated state corresponds to the true state. 
Figure 8.7(a), emission distribution of each state over the symbols C1, ..., C7 (each column sums to one):

        S1    S2    S3
  C1    0.5   0     0
  C2    0.2   0     0
  C3    0.3   0.2   0
  C4    0     0.7   0.1
  C5    0     0.1   0
  C6    0     0     0.5
  C7    0     0     0.4


Figure 8.7: Illustration of Viterbi decoding in a simple HMM for speech recognition. (a) A 3-state HMM for a single phone. We are visualizing the state transition diagram. We assume the observations have been vector quantized into 7 possible symbols, C1, ..., C7. Each state S1, S2, S3 has a different distribution over these symbols. Adapted from Figure 15.20 of [RN02]. (b) Illustration of the Viterbi algorithm applied to this model, with data sequence C1, C3, C4, C6. The columns represent time, and the rows represent states. The numbers inside the circles represent the $\delta_t(j)$ value for that state. An arrow from state $i$ at $t-1$ to state $j$ at $t$ is annotated with two numbers: the first is the probability of the $i \to j$ transition, and the second is the probability of generating observation $y_t$ from state $j$. The red lines/circles represent the most probable sequence of states. Adapted from Figure 24.27 of [RN95].


In Figure 8.7, we give a detailed worked example of the Viterbi algorithm, based on [Rus+95]. Suppose we observe the sequence of discrete observations $y_{1:4} = (C1, C3, C4, C6)$, representing codebook entries in a vector-quantized version of a speech signal. The model starts in state $z_1 = S_1$. The probability of generating $y_1 = C1$ in $z_1$ is 0.5, so we have $\delta_1(1) = 0.5$, and $\delta_1(i) = 0$ for all other states. Next we can self-transition to $S_1$ with probability 0.3, or transition to $S_2$ with probability 0.7. If we end up in $S_1$, the probability of generating $y_2 = C3$ is 0.3; if we end up in $S_2$, the probability of generating $y_2 = C3$ is 0.2. Hence we have

\begin{align}
\delta_2(1) &= \delta_1(1) A(1,1) \lambda_2(1) = 0.5 \cdot 0.3 \cdot 0.3 = 0.045 \tag{8.56} \\
\delta_2(2) &= \delta_1(1) A(1,2) \lambda_2(2) = 0.5 \cdot 0.7 \cdot 0.2 = 0.07 \tag{8.57}
\end{align}

Thus state 2 is more probable at $t = 2$; see the second column of Figure 8.7(b). The algorithm continues in this way until it reaches the end of the sequence. Once we have reached the end, we can follow the red arrows back to recover the MAP path (which is 1, 2, 2, 3).
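The first steps of this example can be reproduced numerically. Only two things come from the text: the emission probabilities of Figure 8.7(a), and the S1 transition probabilities (0.3 and 0.7). The S2 and S3 rows of the transition matrix below are hypothetical values chosen for illustration. With these assumed values the traceback recovers the path 1, 2, 2, 3, but other transition values could change it.

```python
import numpy as np

# Transition matrix: the S1 row (0.3, 0.7, 0) is from the text;
# the S2 and S3 rows are HYPOTHETICAL values for illustration.
A = np.array([[0.3, 0.7, 0.0],
              [0.0, 0.9, 0.1],
              [0.0, 0.0, 1.0]])
# Emission probabilities p(C_k | S_j) from Figure 8.7(a); rows are states, columns C1..C7.
B = np.array([[0.5, 0.2, 0.3, 0.0, 0.0, 0.0, 0.0],
              [0.0, 0.0, 0.2, 0.7, 0.1, 0.0, 0.0],
              [0.0, 0.0, 0.0, 0.1, 0.0, 0.5, 0.4]])
pi = np.array([1.0, 0.0, 0.0])   # the model starts in S1
obs = [0, 2, 3, 5]               # C1, C3, C4, C6 as 0-based symbol indices

def viterbi_prob(pi, A, B, obs):
    """Probability-domain Viterbi (adequate for a sequence this short)."""
    T, K = len(obs), len(pi)
    delta = np.zeros((T, K))
    ancestor = np.zeros((T, K), dtype=int)
    delta[0] = pi * B[:, obs[0]]                 # delta_1(j) = pi_j lambda_1(j)
    for t in range(1, T):
        scores = delta[t - 1][:, None] * A       # scores[i, j] = delta_{t-1}(i) A_{i,j}
        ancestor[t] = scores.argmax(axis=0)
        delta[t] = B[:, obs[t]] * scores.max(axis=0)   # Eq. (8.53)
    path = [int(delta[-1].argmax())]
    for t in range(T - 1, 0, -1):                # traceback
        path.append(int(ancestor[t, path[-1]]))
    return delta, path[::-1]

delta, path = viterbi_prob(pi, A, B, obs)
print(delta[1])                  # delta_2, approximately (0.045, 0.07, 0)
print([s + 1 for s in path])     # 1-based MAP path under the assumed transitions
```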


For more details on HMMs for automatic speech recognition (ASR), see e.g., [JM08].


8.2.7.4 Time and space complexity


The time complexity of Viterbi is clearly $O(K^2 T)$ in general, and the space complexity is $O(KT)$, both the same as forwards-backwards. If the transition matrix has the form $A_{i,j} \propto \rho(z_j - z_i)$, where $z_i$ is the continuous vector represented by state $i$ and $\rho(u)$ is some scalar cost function, such as Euclidean distance, we can implement Viterbi in $O(TK)$ time by using the generalized distance transform to implement Equation (8.54). See [FHK03; FH12] for details.
8.2.7.5 N-best list


There are often multiple paths which have the same likelihood. The Viterbi algorithm returns one of 
them, but can be extended to return the top $N$ paths [SC90; NG01]. This is called the N-best list. Computing such a list can provide a better summary of the posterior uncertainty.

In addition, we can perform discriminative reranking [CK05] of all the sequences in the N-best list, based on global features derived from $(y_{1:T}, z_{1:T})$. This technique is widely used in speech recognition. For
example, consider the sentence “recognize speech”. It is possible that the most probable interpretation 
by the system of this acoustic signal is “wreck a nice speech”, or maybe “wreck a nice beach” (see 
Figure 34.1). Maybe the correct interpretation is much lower down on the list. However, by using 
a re-ranking system, we may be able to improve the score of the correct interpretation based on a 
more global context. 

One problem with the N-best list is that often the top N paths are very similar to each other, 
rather than representing qualitatively different interpretations of the data. Instead we might want to 
generate a more diverse set of paths to more accurately represent posterior uncertainty. One way to 
do this is to sample paths from the posterior, as we discuss in Section 8.2.8. Another way is to use a 
determinantal point process (Section 31.9.5) which encourages points to be diverse [Bat+12; ZA12]. 
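One simple exact way to obtain an N-best list, sometimes called parallel list Viterbi decoding, keeps the $N$ best partial paths ending in each state at every time step. The sketch below illustrates that idea. It is not necessarily the algorithm of [SC90; NG01], and it takes roughly $O(TK^2 N \log N)$ time. The function names are hypothetical.

```python
import heapq
import itertools
import numpy as np

def n_best_paths(log_pi, log_A, log_lik, N):
    """Return the N highest-scoring (log p(z_{1:T}, y_{1:T}), path) pairs, best first."""
    T, K = log_lik.shape
    # beams[j] holds up to N (score, path) pairs for partial paths that end in state j.
    beams = [[(log_pi[j] + log_lik[0, j], (j,))] for j in range(K)]
    for t in range(1, T):
        new_beams = []
        for j in range(K):
            # Each of the N best paths ending in j at time t extends one of the
            # N best paths ending in some state i at time t - 1.
            candidates = [(score + log_A[i, j] + log_lik[t, j], path + (j,))
                          for i in range(K) for score, path in beams[i]]
            new_beams.append(heapq.nlargest(N, candidates, key=lambda c: c[0]))
        beams = new_beams
    return heapq.nlargest(N, [c for beam in beams for c in beam], key=lambda c: c[0])

def brute_force_n_best(log_pi, log_A, log_lik, N):
    """Score all K**T paths and keep the N best (tiny T only)."""
    T, K = log_lik.shape
    scored = []
    for path in itertools.product(range(K), repeat=T):
        s = log_pi[path[0]] + log_lik[0, path[0]]
        for t in range(1, T):
            s = s + log_A[path[t - 1], path[t]] + log_lik[t, path[t]]
        scored.append((s, path))
    return heapq.nlargest(N, scored, key=lambda c: c[0])

rng = np.random.default_rng(4)
K, T, N = 3, 5, 5
log_pi = np.log(rng.dirichlet(np.ones(K)))
log_A = np.log(rng.dirichlet(np.ones(K), size=K))
log_lik = np.log(rng.uniform(0.05, 1.0, size=(T, K)))
best = n_best_paths(log_pi, log_A, log_lik, N)
for score, path in best:
    print(round(float(score), 3), path)
```

As the text notes, the paths in such a list often differ in only one or two positions.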


8.2.8 Forwards filtering, backwards sampling 


Rather than computing the single most probable path, it is often useful to sample multiple paths from 
the posterior: $z^s_{1:T} \sim p(z_{1:T} \mid y_{1:T})$. We can do this by modifying the forwards filtering backwards
smoothing algorithm from Section 8.2.3, so that we draw samples on the backwards pass, rather than 
computing marginals. This is called forwards filtering backwards sampling (also sometimes 
unfortunately abbreviated to FFBS). In particular, note that we can write the joint from right to left 
using 


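\begin{align}
p(z_{1:T} \mid y_{1:T}) &= p(z_T \mid y_{1:T}) \, p(z_{T-1} \mid z_T, y_{1:T}) \, p(z_{T-2} \mid z_{T-1:T}, y_{1:T}) \cdots p(z_1 \mid z_{2:T}, y_{1:T}) \tag{8.58} \\
&= p(z_T \mid y_{1:T}) \prod_{t=T-1}^{1} p(z_t \mid z_{t+1}, y_{1:t}) \tag{8.59}
\end{align}

where the second line uses the Markov properties of the model. Thus at step $t$ we sample $z^s_t$ from $p(z_t \mid z^s_{t+1}, y_{1:t})$, given in Equation (8.17).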
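Here is a minimal NumPy sketch of forwards filtering, backwards sampling, using Equation (8.17) at each backwards step. The toy model values are arbitrary assumptions, and ffbs_sample is a hypothetical name. The check compares the empirical frequency of each sampled path with the exact posterior, obtained by enumerating all $K^T$ paths.

```python
import itertools
from collections import Counter
import numpy as np

def ffbs_sample(pi, A, lik, rng):
    """Draw one path z_{1:T} ~ p(z_{1:T} | y_{1:T}) (0-based state indices)."""
    T, K = lik.shape
    alphas = np.zeros((T, K))
    pred = pi
    for t in range(T):                                  # forwards filtering
        unnorm = lik[t] * pred
        alphas[t] = unnorm / unnorm.sum()
        pred = A.T @ alphas[t]
    z = np.zeros(T, dtype=int)
    z[-1] = rng.choice(K, p=alphas[-1])                 # z_T ~ p(z_T | y_{1:T})
    for t in range(T - 2, -1, -1):                      # backwards sampling
        # p(z_t = i | z_{t+1}, y_{1:t}) is proportional to A[i, z_{t+1}] alpha_t(i), Eq. (8.17)
        w = A[:, z[t + 1]] * alphas[t]
        z[t] = rng.choice(K, p=w / w.sum())
    return z

K, T = 2, 3
pi = np.array([0.6, 0.4])
A = np.array([[0.8, 0.2], [0.3, 0.7]])
lik = np.array([[0.9, 0.2], [0.4, 0.6], [0.1, 0.8]])

# Exact posterior over all K**T paths, for comparison.
exact = {}
for path in itertools.product(range(K), repeat=T):
    p = pi[path[0]] * lik[0, path[0]]
    for t in range(1, T):
        p *= A[path[t - 1], path[t]] * lik[t, path[t]]
    exact[path] = p
norm = sum(exact.values())
exact = {k: v / norm for k, v in exact.items()}

rng = np.random.default_rng(0)
n = 20000
counts = Counter(tuple(int(s) for s in ffbs_sample(pi, A, lik, rng)) for _ in range(n))
for path in sorted(exact):
    print(path, round(float(exact[path]), 3), counts[path] / n)
```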


8.3 Inference based on the Kalman filter 


In this section, we discuss inference in SSMs where all the distributions are linear Gaussian. This is 
called a linear Gaussian state space model (LG-SSM) or a linear dynamical system (LDS). 
We discuss such models in detail in Section 29.6, but in brief they have the following form: 


$$p(z_t \mid z_{t-1}, u_t) = \mathcal{N}(z_t \mid F_t z_{t-1} + B_t u_t + b_t, Q_t) \tag{8.60}$$
$$p(y_t \mid z_t, u_t) = \mathcal{N}(y_t \mid H_t z_t + D_t u_t + d_t, R_t) \tag{8.61}$$

where $z_t \in \mathbb{R}^{N_z}$ is the hidden state, $y_t \in \mathbb{R}^{N_y}$ is the observation, and $u_t \in \mathbb{R}^{N_u}$ is the input. (We have allowed the parameters to be time-varying, for later extensions that we will consider.) We often assume the means of the process noise and observation noise (i.e., the bias or offset terms) are zero, so $b_t = 0$ and $d_t = 0$. In addition, we often have no inputs, so $B_t = D_t = 0$. In this case, the model simplifies to the following:³

$$p(z_t \mid z_{t-1}) = \mathcal{N}(z_t \mid F_t z_{t-1}, Q_t) \tag{8.62}$$

$$p(y_t \mid z_t) = \mathcal{N}(y_t \mid H_t z_t, R_t) \tag{8.63}$$

See Figure 8.1 for the graphical model. 

Note that an LG-SSM is just a special case of a Gaussian Bayes net (Section 4.2.3), so the 
entire joint distribution $p(y_{1:T}, z_{1:T} \mid u_{1:T})$ is a large multivariate Gaussian with $(N_y + N_z) T$ dimensions. However, it has a special structure that makes it computationally tractable to use, as we show below. In particular, we will discuss the Kalman filter and Kalman smoother, which can perform exact filtering and smoothing in $O(T N^3)$ time.


8.3.1 Examples 


Before diving into the theory, we give some motivating examples. 


8.3.1.1 Tracking and state estimation 


A common application of LG-SSMs is for tracking objects, such as airplanes or animals, from noisy 
measurements, such as radar or cameras. For example, suppose we want to track an object moving 
in 2d. (We discuss this example in more detail in Section 29.7.1.) The hidden state $z_t$ encodes the location, $(x_{t1}, x_{t2})$, and the velocity, $(\dot{x}_{t1}, \dot{x}_{t2})$, of the moving object. The observation $y_t$ is a noisy
version of the location. (The velocity is not observed but can be inferred from the change in location.) 


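We assume that we obtain measurements with a sampling period of $\Delta$. The new location is the old location plus $\Delta$ times the velocity, plus noise added to all terms:

$$z_t = \underbrace{\begin{pmatrix} 1 & 0 & \Delta & 0 \\ 0 & 1 & 0 & \Delta \\ 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \end{pmatrix}}_{F} z_{t-1} + q_t \tag{8.64}$$

where $q_t \sim \mathcal{N}(0, Q_t)$. The observation extracts the location and adds noise:

$$y_t = \underbrace{\begin{pmatrix} 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \end{pmatrix}}_{H} z_t + r_t \tag{8.65}$$

where $r_t \sim \mathcal{N}(0, R_t)$.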
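The constant-velocity model of Equations (8.64) and (8.65) can be written down and simulated as follows. The sampling period, noise covariances and initial state are illustrative assumptions, not the settings used to produce Figure 8.8. The helper names are hypothetical.

```python
import numpy as np

def constant_velocity_model(delta):
    """F and H of Equations (8.64)-(8.65); the state is (x1, x2, velocity1, velocity2)."""
    F = np.array([[1.0, 0.0, delta, 0.0],
                  [0.0, 1.0, 0.0, delta],
                  [0.0, 0.0, 1.0, 0.0],
                  [0.0, 0.0, 0.0, 1.0]])
    H = np.array([[1.0, 0.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0, 0.0]])
    return F, H

def simulate(F, H, Q, R, z0, T, rng):
    """Sample z_{1:T} and y_{1:T} from the linear Gaussian SSM, starting from z_0 = z0."""
    zs, ys = [], []
    z = z0
    for _ in range(T):
        z = F @ z + rng.multivariate_normal(np.zeros(len(z)), Q)         # dynamics + q_t
        y = H @ z + rng.multivariate_normal(np.zeros(H.shape[0]), R)     # observation + r_t
        zs.append(z)
        ys.append(y)
    return np.array(zs), np.array(ys)

F, H = constant_velocity_model(delta=1.0)     # ASSUMED sampling period
Q = 0.001 * np.eye(4)                         # ASSUMED process noise covariance
R = 1.0 * np.eye(2)                           # ASSUMED observation noise covariance
rng = np.random.default_rng(0)
zs, ys = simulate(F, H, Q, R, z0=np.array([8.0, 10.0, 1.0, 0.0]), T=15, rng=rng)
print(zs[:3, :2])   # true locations
print(ys[:3])       # noisy observations of those locations
```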


Our goal is to use this model to estimate the unknown location (and velocity) of the object given the noisy observations. In particular, in the filtering problem, we want to compute $p(z_t \mid y_{1:t})$ in a recursive fashion. Figure 8.8(b) illustrates filtering for the linear Gaussian SSM applied to the noisy tracking data in Figure 8.8(a) (shown by the green dots). The filtered estimates are computed using the Kalman filter algorithm described in Section 8.3.2. The red line shows the posterior mean estimate of the location, and the black circles show the posterior covariance. We see that the estimated trajectory is less noisy than the raw data, since it incorporates prior knowledge about how the data was generated.

3. Our notation is similar to [Sar13], except he writes $p(x_k \mid x_{k-1}) = \mathcal{N}(x_k \mid A_{k-1} x_{k-1}, Q_{k-1})$ instead of $p(z_t \mid z_{t-1}) = \mathcal{N}(z_t \mid F_t z_{t-1}, Q_t)$, and $p(y_k \mid x_k) = \mathcal{N}(y_k \mid H_k x_k, R_k)$ instead of $p(y_t \mid z_t) = \mathcal{N}(y_t \mid H_t z_t, R_t)$.

Figure 8.8: Illustration of Kalman filtering and smoothing for a linear dynamical system. (a) Observations (green circles) are generated by an object moving to the right (true location denoted by blue squares). (b) Results of online Kalman filtering. Red cross is the posterior mean, circles are 95% confidence ellipses derived from the posterior covariance. (c) Same as (b), but using offline Kalman smoothing. The MSE in the trajectory for filtering is 3.13, and for smoothing is 1.71. Generated by kf_tracking.ipynb.

Another task of interest is the smoothing problem, where we want to compute $p(z_t \mid y_{1:T})$ using an
offline dataset. Figure 8.8(c) illustrates smoothing for the LG-SSM, implemented using the Kalman 
smoothing algorithm described in Section 8.3.3. We see that the resulting estimate is smoother, and 
that the posterior uncertainty is reduced (as visualized by the smaller confidence ellipses). 

The disadvantage of the above smoothing method is that we have to wait until all the data has 
been observed before we start performing inference. Fixed lag smoothing is a useful compromise 
between online and offline estimation; it involves computing $p(z_{t-\ell} \mid y_{1:t})$, where $\ell > 0$ is called the
lag. This gives better performance than filtering, but incurs a slight delay. By changing the size of 
the lag, we can trade off accuracy vs delay. 


8.3.1.2 Online Bayesian linear regression (recursive least squares) 


In Section 29.7.2 we discuss how to use the Kalman filter to recursively compute the exact posterior $p(w \mid \mathcal{D}_{1:t})$ for a linear regression model in an online fashion. This is known as the recursive least squares algorithm. The basic idea is to treat the latent state as the parameter values, $z_t = w$, and to define the non-stationary observation model as $p(y_t \mid z_t) = \mathcal{N}(y_t \mid x_t^\top z_t, \sigma^2)$, and the dynamics model as $p(z_t \mid z_{t-1}) = \mathcal{N}(z_t \mid z_{t-1}, 0 I)$.
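Here is a minimal sketch of recursive least squares viewed as a Kalman filter with $F_t = I$, $Q_t = 0$ and $H_t = x_t^\top$. The zero-mean isotropic prior $\mathcal{N}(0, \tau I)$ on $w$ is an assumption made for this example. After all the data has been processed, the result should match the batch Bayesian linear regression posterior. The function name is hypothetical.

```python
import numpy as np

def recursive_least_squares(X, y, sigma2, prior_var):
    """Online posterior over w for y_t ~ N(x_t^T w, sigma2), with prior w ~ N(0, prior_var I)."""
    D = X.shape[1]
    mu = np.zeros(D)
    Sigma = prior_var * np.eye(D)
    for x_t, y_t in zip(X, y):
        # The predict step is the identity because F = I and Q = 0.
        s = x_t @ Sigma @ x_t + sigma2          # innovation variance S_t (a scalar)
        k = Sigma @ x_t / s                     # Kalman gain K_t
        mu = mu + k * (y_t - x_t @ mu)          # mean update with innovation e_t
        Sigma = Sigma - np.outer(k, x_t @ Sigma)   # Sigma - K_t H_t Sigma
    return mu, Sigma

rng = np.random.default_rng(0)
N, D, sigma2, prior_var = 50, 3, 0.25, 10.0
X = rng.normal(size=(N, D))
w_true = np.array([1.0, -2.0, 0.5])
y = X @ w_true + np.sqrt(sigma2) * rng.normal(size=N)
mu, Sigma = recursive_least_squares(X, y, sigma2, prior_var)

# Batch posterior for comparison: Sigma_N = (I / prior_var + X^T X / sigma2)^{-1}.
Sigma_batch = np.linalg.inv(np.eye(D) / prior_var + X.T @ X / sigma2)
mu_batch = Sigma_batch @ X.T @ y / sigma2
print(mu, mu_batch)   # should agree up to round-off
```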


8.3.1.3 Time series forecasting 


In Section 29.12, we discuss how to use Kalman filtering to perform time series forecasting. 
8.3.2 The Kalman filter


The Kalman Filter (KF) is an algorithm for exact Bayesian filtering for linear Gaussian state space 
models. The resulting algorithm is the Gaussian analog of the HMM filter in Section 8.2.2. The belief 
state at time $t$ is now given by $p(z_t \mid y_{1:t}) = \mathcal{N}(z_t \mid \mu_{t|t}, \Sigma_{t|t})$, where we use the notation $\mu_{t|t}$ and $\Sigma_{t|t}$
to represent the posterior mean and covariance given $y_{1:t}$.⁴ Since everything is Gaussian, we can
perform the prediction and update steps in closed form, as we explain below (see Section 8.3.2.4 for 
the derivation). 


8.3.2.1 Predict step 


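The one-step-ahead prediction for the hidden state, also called the time update step, is given by the following:

\begin{align}
p(z_t \mid y_{1:t-1}, u_{1:t}) &= \mathcal{N}(z_t \mid \mu_{t|t-1}, \Sigma_{t|t-1}) \tag{8.66} \\
\mu_{t|t-1} &= F_t \mu_{t-1|t-1} + B_t u_t + b_t \tag{8.67} \\
\Sigma_{t|t-1} &= F_t \Sigma_{t-1|t-1} F_t^\top + Q_t \tag{8.68}
\end{align}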


8.3.2.2 Update step 


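The update step (also called the measurement step) can be computed using Bayes rule, as follows:

\begin{align}
p(z_t \mid y_{1:t}, u_{1:t}) &= \mathcal{N}(z_t \mid \mu_{t|t}, \Sigma_{t|t}) \tag{8.69} \\
\hat{y}_t &= H_t \mu_{t|t-1} + D_t u_t + d_t \tag{8.70} \\
S_t &= H_t \Sigma_{t|t-1} H_t^\top + R_t \tag{8.71} \\
K_t &= \Sigma_{t|t-1} H_t^\top S_t^{-1} \tag{8.72} \\
e_t &= y_t - \hat{y}_t \tag{8.73} \\
\mu_{t|t} &= \mu_{t|t-1} + K_t e_t \tag{8.74} \\
\Sigma_{t|t} &= \Sigma_{t|t-1} - K_t H_t \Sigma_{t|t-1} \tag{8.75} \\
&= \Sigma_{t|t-1} - K_t S_t K_t^\top \tag{8.76}
\end{align}

where $\hat{y}_t$ is the expected observation, $e_t$ is the residual error or innovation term, and $K_t$ is the Kalman gain matrix.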


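Note that, by using the matrix inversion lemma, the Kalman gain matrix can also be written as

$$K_t = \Sigma_{t|t-1} H_t^\top (H_t \Sigma_{t|t-1} H_t^\top + R_t)^{-1} = (\Sigma_{t|t-1}^{-1} + H_t^\top R_t^{-1} H_t)^{-1} H_t^\top R_t^{-1} \tag{8.77}$$

This is useful if $R_t^{-1}$ is precomputed (e.g., if it is constant over time) and $N_y > N_z$, since the second form only requires inverting $N_z \times N_z$ matrices rather than the $N_y \times N_y$ matrix $S_t$.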
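The predict and update steps can be sketched directly from Equations (8.67)-(8.68) and (8.70)-(8.76), dropping the input and bias terms for brevity. The helper gain_information_form (a hypothetical name) evaluates the second expression in Equation (8.77). The check compares the update with the information-form posterior, whose precision is $\Sigma_{t|t}^{-1} = \Sigma_{t|t-1}^{-1} + H_t^\top R_t^{-1} H_t$. All model values are arbitrary.

```python
import numpy as np

def kf_predict(mu, Sigma, F, Q):
    """Eqs. (8.67)-(8.68), with no inputs or bias terms."""
    return F @ mu, F @ Sigma @ F.T + Q

def kf_update(mu_pred, Sigma_pred, H, R, y):
    """Eqs. (8.70)-(8.76), with no inputs or bias terms."""
    y_hat = H @ mu_pred                          # expected observation
    S = H @ Sigma_pred @ H.T + R                 # innovation covariance S_t
    K = Sigma_pred @ H.T @ np.linalg.inv(S)      # Kalman gain K_t
    e = y - y_hat                                # innovation e_t
    mu = mu_pred + K @ e
    Sigma = Sigma_pred - K @ S @ K.T             # Eq. (8.76)
    return mu, Sigma, K

def gain_information_form(Sigma_pred, H, R):
    """Second expression in Eq. (8.77); in practice R^{-1} would be precomputed."""
    R_inv = np.linalg.inv(R)
    return np.linalg.inv(np.linalg.inv(Sigma_pred) + H.T @ R_inv @ H) @ H.T @ R_inv

rng = np.random.default_rng(0)
Nz, Ny = 2, 5
F = np.array([[1.0, 1.0], [0.0, 1.0]])
Q = 0.1 * np.eye(Nz)
H = rng.normal(size=(Ny, Nz))
R = np.diag(rng.uniform(0.5, 2.0, size=Ny))
y = rng.normal(size=Ny)

mu_pred, Sigma_pred = kf_predict(np.zeros(Nz), np.eye(Nz), F, Q)
mu, Sigma, K = kf_update(mu_pred, Sigma_pred, H, R, y)

# Independent check via the information (precision) form of the same posterior.
R_inv = np.linalg.inv(R)
Sigma_info = np.linalg.inv(np.linalg.inv(Sigma_pred) + H.T @ R_inv @ H)
mu_info = Sigma_info @ (np.linalg.solve(Sigma_pred, mu_pred) + H.T @ R_inv @ y)
print(np.max(np.abs(K - gain_information_form(Sigma_pred, H, R))))  # round-off only
```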


4. We represent the mean and covariance of the filtered belief state by $\mu_{t|t}$ and $\Sigma_{t|t}$, but some authors use the notation $m_t$ and $P_t$ instead. We represent the mean and covariance of the smoothed belief state by $\mu_{t|T}$ and $\Sigma_{t|T}$, but some authors use the notation $m_t^s$ and $P_t^s$ instead. Finally, we represent the mean and covariance of the one-step-ahead posterior predictive distribution, $p(z_t \mid y_{1:t-1})$, by $\mu_{t|t-1}$ and $\Sigma_{t|t-1}$, whereas some authors use $m_t^-$ and $P_t^-$ instead.
8.3.2.3 Posterior predictive


The one-step-ahead posterior predictive density for the observations can be computed as follows. (We 
ignore inputs and bias terms, for notational brevity.) First we compute the one-step-ahead predictive 
density for latent states: 


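\begin{align}
p(z_t \mid y_{1:t-1}) &= \int p(z_t \mid z_{t-1}) \, p(z_{t-1} \mid y_{1:t-1}) \, dz_{t-1} \tag{8.78} \\
&= \mathcal{N}(z_t \mid F_t \mu_{t-1|t-1}, F_t \Sigma_{t-1|t-1} F_t^\top + Q_t) = \mathcal{N}(z_t \mid \mu_{t|t-1}, \Sigma_{t|t-1}) \tag{8.79}
\end{align}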


Then we convert this to a prediction about observations by marginalizing out $z_t$:

$$p(y_t \mid y_{1:t-1}) = \int p(y_t, z_t \mid y_{1:t-1}) \, dz_t = \int p(y_t \mid z_t) \, p(z_t \mid y_{1:t-1}) \, dz_t = \mathcal{N}(y_t \mid \hat{y}_t, S_t) \tag{8.80}$$

This can also be used to compute the log-likelihood of the observations, since $Z_t = p(y_t \mid y_{1:t-1})$ is the normalization constant of the new posterior:

$$\log p(y_{1:T}) = \sum_{t=1}^{T} \log p(y_t \mid y_{1:t-1}) = \sum_{t=1}^{T} \log Z_t \tag{8.81}$$

where we define $p(y_1 \mid y_0) = p(y_1)$. This is just a sum of the log probabilities of the one-step-ahead measurement predictions, and is a measure of how “surprised” the model is at each step.

We can generalize the prediction step to predict observations K steps into the future by first 
forecasting K steps in latent space, and then “grounding” the final state into predicted observations. 
(This is in contrast to an RNN (Section 16.3.4), which requires generating observations at each step, 
in order to update future hidden states.) 
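To check that Equation (8.81) is exact, the sketch below accumulates $\log \mathcal{N}(y_t \mid \hat{y}_t, S_t)$ along a Kalman filter. It compares the total with the log density of the full $T N_y$-dimensional Gaussian $p(y_{1:T})$, built by brute force. The sketch assumes a known prior $\mathcal{N}(\mu_0, \Sigma_0)$ on $z_0$, no inputs, and time-invariant parameters. The model values and function names are illustrative assumptions.

```python
import numpy as np

def gauss_logpdf(x, mean, cov):
    """log N(x | mean, cov) via a Cholesky factorization."""
    L = np.linalg.cholesky(cov)
    sol = np.linalg.solve(L, x - mean)
    return -0.5 * sol @ sol - np.sum(np.log(np.diag(L))) - 0.5 * len(x) * np.log(2 * np.pi)

def kf_log_lik(F, Q, H, R, mu0, Sigma0, ys):
    """Sum over t of log N(y_t | y_hat_t, S_t), Eq. (8.81); (mu0, Sigma0) describe z_0."""
    mu, Sigma = mu0, Sigma0
    total = 0.0
    for y in ys:
        mu, Sigma = F @ mu, F @ Sigma @ F.T + Q          # predict
        y_hat, S = H @ mu, H @ Sigma @ H.T + R
        total += gauss_logpdf(y, y_hat, S)               # log Z_t = log p(y_t | y_{1:t-1})
        K = Sigma @ H.T @ np.linalg.inv(S)               # update
        mu, Sigma = mu + K @ (y - y_hat), Sigma - K @ S @ K.T
    return total

def joint_log_lik(F, Q, H, R, mu0, Sigma0, ys):
    """Brute force: log p(y_{1:T}) under the full (T * N_y)-dimensional joint Gaussian."""
    T, Ny = ys.shape
    state_means, state_covs = [], []
    mu, P = mu0, Sigma0
    for _ in range(T):                                   # prior moments of z_1, ..., z_T
        mu, P = F @ mu, F @ P @ F.T + Q
        state_means.append(mu)
        state_covs.append(P)
    big_mean = np.concatenate([H @ m for m in state_means])
    big_cov = np.zeros((T * Ny, T * Ny))
    for s in range(T):
        for t in range(s, T):
            # Cov(z_s, z_t) = Cov(z_s, z_s) (F^{t-s})^T for s <= t
            cross = state_covs[s] @ np.linalg.matrix_power(F, t - s).T
            block = H @ cross @ H.T + (R if s == t else 0.0)
            big_cov[s * Ny:(s + 1) * Ny, t * Ny:(t + 1) * Ny] = block
            big_cov[t * Ny:(t + 1) * Ny, s * Ny:(s + 1) * Ny] = block.T
    return gauss_logpdf(ys.reshape(-1), big_mean, big_cov)

rng = np.random.default_rng(0)
F = np.array([[1.0, 1.0], [0.0, 1.0]])
Q = 0.05 * np.eye(2)
H = np.array([[1.0, 0.0]])
R = np.array([[0.5]])
mu0, Sigma0 = np.zeros(2), np.eye(2)
ys = rng.normal(size=(6, 1))
print(kf_log_lik(F, Q, H, R, mu0, Sigma0, ys), joint_log_lik(F, Q, H, R, mu0, Sigma0, ys))
```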


8.3.2.4 Derivation 


In this section we derive the Kalman filter equations, following [Sar13, p57]. The results are a 
straightforward application of the rules for manipulating linear Gaussian systems, discussed in 
Section 2.2.6. 

First we derive the prediction step. From Equation (2.81), the joint predictive distribution for 
states is given by 


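\begin{align}
p(z_{t-1}, z_t \mid y_{1:t-1}) &= p(z_t \mid z_{t-1}) \, p(z_{t-1} \mid y_{1:t-1}) \tag{8.82} \\
&= \mathcal{N}(z_t \mid F_t z_{t-1}, Q_t) \, \mathcal{N}(z_{t-1} \mid \mu_{t-1|t-1}, \Sigma_{t-1|t-1}) \tag{8.83} \\
&= \mathcal{N}\left( \begin{pmatrix} z_{t-1} \\ z_t \end{pmatrix} \,\middle|\, \mu', \Sigma' \right) \tag{8.84}
\end{align}

where

$$\mu' = \begin{pmatrix} \mu_{t-1|t-1} \\ F_t \mu_{t-1|t-1} \end{pmatrix}, \quad \Sigma' = \begin{pmatrix} \Sigma_{t-1|t-1} & \Sigma_{t-1|t-1} F_t^\top \\ F_t \Sigma_{t-1|t-1} & F_t \Sigma_{t-1|t-1} F_t^\top + Q_t \end{pmatrix} \tag{8.85}$$

Hence the marginal predictive distribution for states is given by

$$p(z_t \mid y_{1:t-1}) = \mathcal{N}(z_t \mid F_t \mu_{t-1|t-1}, F_t \Sigma_{t-1|t-1} F_t^\top + Q_t) = \mathcal{N}(z_t \mid \mu_{t|t-1}, \Sigma_{t|t-1}) \tag{8.86}$$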
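Now we derive the measurement update step. The joint distribution for state and observation is given by

\begin{align}
p(z_t, y_t \mid y_{1:t-1}) &= p(y_t \mid z_t) \, p(z_t \mid y_{1:t-1}) \tag{8.87} \\
&= \mathcal{N}(y_t \mid H_t z_t, R_t) \, \mathcal{N}(z_t \mid \mu_{t|t-1}, \Sigma_{t|t-1}) \tag{8.88} \\
&= \mathcal{N}\left( \begin{pmatrix} z_t \\ y_t \end{pmatrix} \,\middle|\, \mu'', \Sigma'' \right) \tag{8.89}
\end{align}

where

$$\mu'' = \begin{pmatrix} \mu_{t|t-1} \\ H_t \mu_{t|t-1} \end{pmatrix}, \quad \Sigma'' = \begin{pmatrix} \Sigma_{t|t-1} & \Sigma_{t|t-1} H_t^\top \\ H_t \Sigma_{t|t-1} & H_t \Sigma_{t|t-1} H_t^\top + R_t \end{pmatrix} \tag{8.90}$$

Finally, we convert this joint into a conditional using Equation (2.39) as follows:

\begin{align}
p(z_t \mid y_t, y_{1:t-1}) &= \mathcal{N}(z_t \mid \mu_{t|t}, \Sigma_{t|t}) \tag{8.91} \\
\mu_{t|t} &= \mu_{t|t-1} + \Sigma_{t|t-1} H_t^\top (H_t \Sigma_{t|t-1} H_t^\top + R_t)^{-1} [y_t - H_t \mu_{t|t-1}] \tag{8.92} \\
&= \mu_{t|t-1} + K_t [y_t - H_t \mu_{t|t-1}] \tag{8.93} \\
\Sigma_{t|t} &= \Sigma_{t|t-1} - \Sigma_{t|t-1} H_t^\top (H_t \Sigma_{t|t-1} H_t^\top + R_t)^{-1} H_t \Sigma_{t|t-1} \tag{8.94} \\
&= \Sigma_{t|t-1} - K_t H_t \Sigma_{t|t-1} \tag{8.95}
\end{align}

where

\begin{align}
S_t &= H_t \Sigma_{t|t-1} H_t^\top + R_t \tag{8.96} \\
K_t &= \Sigma_{t|t-1} H_t^\top S_t^{-1} \tag{8.97}
\end{align}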


8.3.2.5 Abstract formulation


We can represent the Kalman filter equations much more compactly by defining various functions that create and manipulate jointly Gaussian systems, as in Section 2.2.6. In particular, suppose $p(z) = \mathcal{N}(z \mid \mu, \Sigma)$ and $p(y \mid z) = \mathcal{N}(y \mid Az + b, Q)$. Then the joint is given by $p(z, y) = \mathcal{N}(z, y \mid \tilde{\mu}, \tilde{\Sigma})$, where $\tilde{\mu} = (\mu, m)$ and $\tilde{\Sigma} = (\Sigma, C; C^\top, S)$, where the terms $m$, $S$ and $C$ are given by LinGaussPredict in Algorithm 7. From this, the posterior is given by $p(z \mid y) = \mathcal{N}(z \mid \hat{\mu}, \hat{\Sigma})$, where the posterior parameters are given by GaussCondition in Algorithm 7.


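We can now apply these functions to derive Kalman filtering as follows. In the prediction step, we compute

$$
p(z_{t-1}, z_t \mid y_{1:t-1}) = \mathcal{N}\left(\begin{pmatrix} z_{t-1} \\ z_t \end{pmatrix} \,\middle|\, \begin{pmatrix} \mu_{t-1|t-1} \\ \mu_{t|t-1} \end{pmatrix}, \begin{pmatrix} \Sigma_{t-1|t-1} & \Sigma_{t-1,t|t-1} \\ \Sigma_{t-1,t|t-1}^\top & \Sigma_{t|t-1} \end{pmatrix}\right) \tag{8.98}
$$

$$
(\mu_{t|t-1}, \Sigma_{t|t-1}, \Sigma_{t-1,t|t-1}) = \text{LinGaussPredict}(\mu_{t-1|t-1}, \Sigma_{t-1|t-1}, F_t, B_t u_t + b_t, Q_t) \tag{8.99}
$$

from which we get the marginal distribution

$$
p(z_t \mid y_{1:t-1}) = \mathcal{N}(z_t \mid \mu_{t|t-1}, \Sigma_{t|t-1}) \tag{8.100}
$$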
Algorithm 7: Functions for a linear Gaussian system.
1  def LinGaussPredict(μ, Σ, A, b, Q):
2      m = Aμ + b
3      S = Q + AΣA^T
4      C = ΣA^T
5      Return (m, S, C)
6  def GaussCondition(μ, Σ, m, S, C, y):
7      K = CS^{-1}
8      μ̂ = μ + K(y - m)
9      Σ̂ = Σ - KSK^T
10     ℓ = log N(y | m, S)
11     Return (μ̂, Σ̂, ℓ)
12 def LinGaussUpdate(μ, Σ, A, b, Q, y):
13     (m, S, C) = LinGaussPredict(μ, Σ, A, b, Q)
14     (μ̂, Σ̂, ℓ) = GaussCondition(μ, Σ, m, S, C, y)
15     Return (μ̂, Σ̂, ℓ)
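The three functions in Algorithm 7 translate almost line for line into NumPy. The sketch below is our own illustration rather than the book's JAX code, and the snake_case names are hypothetical. Two small choices differ from a literal transcription: the gain $K = CS^{-1}$ is computed with a linear solve instead of an explicit inverse, and $\log \mathcal{N}(y \mid m, S)$ is evaluated through a Cholesky factor of $S$.

```python
import numpy as np

def lin_gauss_predict(mu, Sigma, A, b, Q):
    """Moments of y = A z + b + noise, noise ~ N(0, Q), when z ~ N(mu, Sigma)."""
    m = A @ mu + b              # E[y]
    S = Q + A @ Sigma @ A.T     # Cov[y]
    C = Sigma @ A.T             # Cov[z, y]
    return m, S, C

def gauss_log_pdf(y, m, S):
    """log N(y | m, S) computed from a Cholesky factor of S."""
    L = np.linalg.cholesky(S)
    r = np.linalg.solve(L, y - m)                  # whitened residual
    log_det = 2.0 * np.sum(np.log(np.diag(L)))
    return -0.5 * (r @ r + log_det + y.size * np.log(2.0 * np.pi))

def gauss_condition(mu, Sigma, m, S, C, y):
    """Condition the joint Gaussian over (z, y) on the observed y."""
    K = np.linalg.solve(S, C.T).T                  # K = C S^{-1} (S is symmetric)
    mu_post = mu + K @ (y - m)
    Sigma_post = Sigma - K @ S @ K.T
    return mu_post, Sigma_post, gauss_log_pdf(y, m, S)

def lin_gauss_update(mu, Sigma, A, b, Q, y):
    """Predict the observation, then condition on it."""
    m, S, C = lin_gauss_predict(mu, Sigma, A, b, Q)
    return gauss_condition(mu, Sigma, m, S, C, y)

# Usage: prior N(0, 1), observation y = z + N(0, 1), observed value 2.0
mu_post, Sigma_post, ll = lin_gauss_update(np.zeros(1), np.eye(1), np.eye(1),
                                           np.zeros(1), np.eye(1), np.array([2.0]))
# Here the posterior is N(1.0, 0.5) and ll = log N(2 | 0, 2).
```

The final comment can be checked by hand: $S = 2$ and $C = 1$, so $K = 1/2$, the posterior mean is $0 + \tfrac{1}{2}(2 - 0) = 1$, and the posterior variance is $1 - \tfrac{1}{2} \cdot 2 \cdot \tfrac{1}{2} = \tfrac{1}{2}$.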


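In the update step, we compute the joint distribution

$$
p(z_t, y_t \mid y_{1:t-1}) = \mathcal{N}\left(\begin{pmatrix} z_t \\ y_t \end{pmatrix} \,\middle|\, \begin{pmatrix} \mu_{t|t-1} \\ m_t \end{pmatrix}, \begin{pmatrix} \Sigma_{t|t-1} & C_t \\ C_t^\top & S_t \end{pmatrix}\right) \tag{8.101}
$$

$$
(m_t, S_t, C_t) = \text{LinGaussPredict}(\mu_{t|t-1}, \Sigma_{t|t-1}, H_t, D_t u_t + d_t, R_t) \tag{8.102}
$$

We then condition this on the observations to get the posterior distribution

$$
p(z_t \mid y_t, y_{1:t-1}) = p(z_t \mid y_{1:t}) = \mathcal{N}(z_t \mid \mu_{t|t}, \Sigma_{t|t}) \tag{8.103}
$$

$$
(\mu_{t|t}, \Sigma_{t|t}, \ell_t) = \text{GaussCondition}(\mu_{t|t-1}, \Sigma_{t|t-1}, m_t, S_t, C_t, y_t) \tag{8.104}
$$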


We can combine the two parts of the update operation by defining the helper function LinGaussUpdate 
in Algorithm 7. The overall KF algorithm is shown in Algorithm 8. 


Algorithm 8: Kalman filter.
1 def KF(F_{1:T}, B_{1:T}, b_{1:T}, Q_{1:T}, H_{1:T}, D_{1:T}, d_{1:T}, R_{1:T}, u_{1:T}, y_{1:T}, μ_{0|0}, Σ_{0|0}):
2     foreach t = 1 : T do
3         (μ_{t|t-1}, Σ_{t|t-1}, -) = LinGaussPredict(μ_{t-1|t-1}, Σ_{t-1|t-1}, F_t, B_t u_t + b_t, Q_t)
4         (μ_{t|t}, Σ_{t|t}, ℓ_t) = LinGaussUpdate(μ_{t|t-1}, Σ_{t|t-1}, H_t, D_t u_t + d_t, R_t, y_t)
5     Return ((μ_{t|t}, Σ_{t|t})_{t=1}^T, ∑_t ℓ_t)
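As a sanity check on Algorithm 8, the NumPy sketch below runs the filter for a time-invariant model without inputs (we assume $B_t u_t + b_t = 0$ and $D_t u_t + d_t = 0$ to keep it short) and compares it with brute-force conditioning of the full joint Gaussian over $z_{1:T}$ and $y_{1:T}$. Both routes compute the same Gaussian, so the last filtered mean and covariance and the total log marginal likelihood should agree up to rounding. The function names are hypothetical, and the brute-force check costs $O((T N_z)^3)$, so it is only meant for small $T$.

```python
import numpy as np

def kalman_filter(F, Q, H, R, ys, mu0, Sigma0):
    """Algorithm 8 for a time-invariant LG-SSM with no inputs.

    Returns filtered means, filtered covariances and sum_t log p(y_t | y_{1:t-1}).
    """
    mu, Sigma = mu0, Sigma0
    means, covs, total_ll = [], [], 0.0
    for y in ys:
        # Predict: p(z_t | y_{1:t-1})
        mu = F @ mu
        Sigma = F @ Sigma @ F.T + Q
        # Update: condition the joint p(z_t, y_t | y_{1:t-1}) on y_t
        m = H @ mu
        S = H @ Sigma @ H.T + R
        K = np.linalg.solve(S, H @ Sigma).T          # Sigma H^T S^{-1}
        resid = y - m
        total_ll += -0.5 * (resid @ np.linalg.solve(S, resid)
                            + np.linalg.slogdet(2 * np.pi * S)[1])
        mu = mu + K @ resid
        Sigma = Sigma - K @ S @ K.T
        means.append(mu)
        covs.append(Sigma)
    return np.array(means), np.array(covs), total_ll

def brute_force_filter_check(F, Q, H, R, ys, mu0, Sigma0):
    """p(z_T | y_{1:T}) and log p(y_{1:T}) from one big joint Gaussian."""
    T, nz = len(ys), len(mu0)
    # z_t = F^t z_0 + sum_{k=1}^t F^(t-k) w_k, i.e. z_{1:T} = M @ (z_0, w_1, ..., w_T)
    M = np.zeros((T * nz, (T + 1) * nz))
    for t in range(1, T + 1):
        for k in range(t + 1):
            M[(t - 1) * nz:t * nz, k * nz:(k + 1) * nz] = np.linalg.matrix_power(F, t - k)
    cov_in = np.kron(np.eye(T + 1), Q)
    cov_in[:nz, :nz] = Sigma0
    mean_in = np.concatenate([mu0, np.zeros(T * nz)])
    mz, Czz = M @ mean_in, M @ cov_in @ M.T
    Hb = np.kron(np.eye(T), H)
    my = Hb @ mz
    Cyy = Hb @ Czz @ Hb.T + np.kron(np.eye(T), R)
    Czy = Czz @ Hb.T
    r = ys.reshape(-1) - my
    post_mean = mz + Czy @ np.linalg.solve(Cyy, r)
    post_cov = Czz - Czy @ np.linalg.solve(Cyy, Czy.T)
    ll = -0.5 * (r @ np.linalg.solve(Cyy, r) + np.linalg.slogdet(2 * np.pi * Cyy)[1])
    return post_mean[-nz:], post_cov[-nz:, -nz:], ll
```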
8.3.2.6 Numerical issues

In practice, the Kalman filter can encounter numerical issues. One solution is to use the information filter, which recursively updates the natural parameters of the Gaussian, $\Lambda_{t|t} = \Sigma_{t|t}^{-1}$ and $\eta_{t|t} = \Lambda_{t|t} \mu_{t|t}$, instead of the mean and covariance (see Section 8.3.4). Another solution is the square root filter, which works with the Cholesky or QR decomposition of $\Sigma_{t|t}$; this is much more numerically stable than directly updating $\Sigma_{t|t}$. These techniques can be combined to create the square root information filter (SRIF) [May79]. (According to [Bic06], the SRIF was developed in 1969 for use in JPL's Mariner 10 mission to Venus.) In [Tol22] they present an approach which uses QR decompositions instead of matrix inversions, which can also be more stable.
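To make the square root idea concrete, here is a minimal NumPy sketch of one common QR-based ("array") form of the measurement update. It is an illustration under our own assumptions, not necessarily the scheme of [May79] or [Tol22]: $\Sigma_{t|t-1}$ and $R_t$ are represented by factors $L$ with $LL^\top$ equal to the matrix, and a single QR factorization of a stacked pre-array yields a factor of the innovation covariance $S_t$, the gain $K_t$, and a factor of $\Sigma_{t|t}$, without ever forming $S_t$ or $\Sigma_{t|t}$ explicitly. The function name is hypothetical.

```python
import numpy as np

def sqrt_measurement_update(mu_pred, L_pred, H, L_R, y):
    """QR-based square-root Kalman measurement update.

    L_pred @ L_pred.T = Sigma_{t|t-1} and L_R @ L_R.T = R_t.
    Returns mu_{t|t} and a factor L_post with L_post @ L_post.T = Sigma_{t|t}.
    """
    ny, nz = H.shape
    # Pre-array M chosen so that M @ M.T = [[S, H P], [P H^T, P]] with P = Sigma_{t|t-1}
    M = np.block([[L_R, H @ L_pred],
                  [np.zeros((nz, ny)), L_pred]])
    # M.T = Q_ R_  =>  M @ M.T = R_.T @ R_, and R_.T is lower triangular: [[X, 0], [Y, Z]]
    _, R_ = np.linalg.qr(M.T)
    lower = R_.T
    X = lower[:ny, :ny]        # X X^T = S_t
    Y = lower[ny:, :ny]        # Y X^T = Sigma_{t|t-1} H^T
    Z = lower[ny:, ny:]        # Z Z^T = Sigma_{t|t}
    K = np.linalg.solve(X.T, Y.T).T       # K_t = Y X^{-1} = Sigma_{t|t-1} H^T S_t^{-1}
    mu_post = mu_pred + K @ (y - H @ mu_pred)
    return mu_post, Z
```

The block identities in the comments follow from equating $MM^\top$ with $\text{lower}\cdot\text{lower}^\top$; the diagonal of the QR factor may contain negative entries, which does not affect $ZZ^\top$.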


8.3.3 The Kalman (RTS) smoother 


In Section 8.3.2, we described the Kalman filter, which sequentially computes $p(z_t \mid y_{1:t})$ for each $t$.
This is useful for online inference problems, such as tracking. However, in an offline setting, we can
wait until all the data has arrived, and then compute $p(z_t \mid y_{1:T})$. By conditioning on past and future
data, our uncertainty will be significantly reduced. This is illustrated in Figure 8.8(c), where we see 
that the posterior covariance ellipsoids are smaller for the smoothed trajectory than for the filtered 
trajectory. 

We now explain how to compute the smoothed estimates, using an algorithm called the RTS 
Smoother or RTSS, named after its inventors, Rauch, Tung and Striebel [RTS65]. It is also 
known as the Kalman smoothing algorithm. The algorithm is the linear-Gaussian analog to the 
forwards-filtering backwards-smoothing algorithm for HMMs in Section 8.2.3. 


8.3.3.1 Algorithm

In this section, we derive the RTS smoother, following [Sar13, p136]. As in the derivation of the Kalman filter in Section 8.3.2.4, we make heavy use of the rules for manipulating linear Gaussian systems, discussed in Section 2.2.6.


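The joint filtered distribution for two consecutive time slices is

$$
\begin{align}
p(z_t, z_{t+1} \mid y_{1:t}) &= p(z_{t+1} \mid z_t)\, p(z_t \mid y_{1:t}) = \mathcal{N}(z_{t+1} \mid F_{t+1} z_t, Q_{t+1})\, \mathcal{N}(z_t \mid \mu_{t|t}, \Sigma_{t|t}) \tag{8.105}\\
&= \mathcal{N}\left(\begin{pmatrix} z_t \\ z_{t+1} \end{pmatrix} \,\middle|\, m_1, V_1\right) \tag{8.106}
\end{align}
$$

where

$$
m_1 = \begin{pmatrix} \mu_{t|t} \\ F_{t+1} \mu_{t|t} \end{pmatrix}, \quad
V_1 = \begin{pmatrix} \Sigma_{t|t} & \Sigma_{t|t} F_{t+1}^\top \\ F_{t+1} \Sigma_{t|t} & F_{t+1} \Sigma_{t|t} F_{t+1}^\top + Q_{t+1} \end{pmatrix} \tag{8.107}
$$

By the Markov property for the hidden states we have

$$
p(z_t \mid z_{t+1}, y_{1:T}) = p(z_t \mid z_{t+1}, y_{1:t}, y_{t+1:T}) = p(z_t \mid z_{t+1}, y_{1:t}) \tag{8.108}
$$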
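and hence by conditioning the joint distribution $p(z_t, z_{t+1} \mid y_{1:t})$ on the future state we get

$$
\begin{align}
p(z_t \mid z_{t+1}, y_{1:T}) &= \mathcal{N}(z_t \mid m_2(z_{t+1}), V_2) \tag{8.109}\\
\Sigma_{t+1|t} &= F_{t+1} \Sigma_{t|t} F_{t+1}^\top + Q_{t+1} \tag{8.110}\\
G_t &= \Sigma_{t|t} F_{t+1}^\top \Sigma_{t+1|t}^{-1} \tag{8.111}\\
m_2(z_{t+1}) &= \mu_{t|t} + G_t (z_{t+1} - F_{t+1} \mu_{t|t}) \tag{8.112}\\
V_2 &= \Sigma_{t|t} - G_t \Sigma_{t+1|t} G_t^\top \tag{8.113}
\end{align}
$$

where $G_t$ is the backwards Kalman gain matrix.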
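The joint smoothed distribution of two consecutive time slices is

$$
\begin{align}
p(z_{t+1}, z_t \mid y_{1:T}) &= p(z_{t+1} \mid y_{1:T})\, p(z_t \mid z_{t+1}, y_{1:T}) \tag{8.114}\\
&= \mathcal{N}(z_{t+1} \mid \mu_{t+1|T}, \Sigma_{t+1|T})\, \mathcal{N}(z_t \mid m_2(z_{t+1}), V_2) \tag{8.115}\\
&= \mathcal{N}\left(\begin{pmatrix} z_{t+1} \\ z_t \end{pmatrix} \,\middle|\, m_3, V_3\right) \tag{8.116}
\end{align}
$$

where

$$
m_3 = \begin{pmatrix} \mu_{t+1|T} \\ \mu_{t|t} + G_t(\mu_{t+1|T} - F_{t+1}\mu_{t|t}) \end{pmatrix}, \quad
V_3 = \begin{pmatrix} \Sigma_{t+1|T} & \Sigma_{t+1|T} G_t^\top \\ G_t \Sigma_{t+1|T} & G_t \Sigma_{t+1|T} G_t^\top + V_2 \end{pmatrix} \tag{8.117}
$$

From this, we can extract the smoothed marginal

$$
\begin{align}
p(z_t \mid y_{1:T}) &= \mathcal{N}(z_t \mid \mu_{t|T}, \Sigma_{t|T}) \tag{8.118}\\
\mu_{t|T} &= \mu_{t|t} + G_t(\mu_{t+1|T} - F_{t+1}\mu_{t|t}) \tag{8.119}\\
\Sigma_{t|T} &= G_t \Sigma_{t+1|T} G_t^\top + V_2 = \Sigma_{t|t} + G_t(\Sigma_{t+1|T} - F_{t+1}\Sigma_{t|t}F_{t+1}^\top - Q_{t+1}) G_t^\top \tag{8.120}
\end{align}
$$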


Algorithm 9: Kalman smoother.
1 def KS(F_{1:T}, B_{1:T}, b_{1:T}, Q_{1:T}, H_{1:T}, D_{1:T}, d_{1:T}, R_{1:T}, u_{1:T}, y_{1:T}, μ_{0|0}, Σ_{0|0}):
2     (μ_{t|t}, Σ_{t|t})_{t=1}^T = KF(F_{1:T}, B_{1:T}, b_{1:T}, Q_{1:T}, H_{1:T}, D_{1:T}, d_{1:T}, R_{1:T}, u_{1:T}, y_{1:T}, μ_{0|0}, Σ_{0|0})
3     foreach t = T-1 : 1 do
4         (μ_{t+1|t}, Σ_{t+1|t}, Σ_{t,t+1|t}) = LinGaussPredict(μ_{t|t}, Σ_{t|t}, F_{t+1}, B_{t+1} u_{t+1} + b_{t+1}, Q_{t+1})
5         bel_{t,t+1|t} = (μ_{t|t}, μ_{t+1|t}, Σ_{t|t}, Σ_{t+1|t}, Σ_{t,t+1|t});  bel_{t+1|T} = (μ_{t+1|T}, Σ_{t+1|T})
6         (μ_{t|T}, Σ_{t|T}) = GaussSoftCondition(bel_{t,t+1|t}, bel_{t+1|T})
7     Return (μ_{t|T}, Σ_{t|T})_{t=1}^T


The overall algorithm is shown in Algorithm 9, where the GaussSoftCondition function is defined in Algorithm 10.
Algorithm 10: Updating a joint Gaussian p(z_t, z_{t+1} | y_{1:t}) by conditioning on p(z_{t+1} | y_{1:T}).
1 def GaussSoftCondition(bel_{t,t+1|t}, bel_{t+1|T}):
2     (μ_{t|t}, μ_{t+1|t}, Σ_{t|t}, Σ_{t+1|t}, Σ_{t,t+1|t}) = bel_{t,t+1|t}
3     (μ_{t+1|T}, Σ_{t+1|T}) = bel_{t+1|T}
4     G_t = Σ_{t,t+1|t} Σ_{t+1|t}^{-1}
5     μ_{t|T} = μ_{t|t} + G_t(μ_{t+1|T} - μ_{t+1|t})
6     Σ_{t|T} = Σ_{t|t} + G_t(Σ_{t+1|T} - Σ_{t+1|t})G_t^T
7     Return (μ_{t|T}, Σ_{t|T})
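The backwards step of Algorithm 10 is short enough to transcribe directly. The NumPy sketch below uses a hypothetical function name, passes the beliefs as plain tuples in the order used by Algorithm 10, and computes $G_t$ with a linear solve rather than an explicit inverse. The usage lines build the joint belief from a filtered belief with $F_{t+1}$ and $Q_{t+1}$ (inputs assumed zero), and illustrate a simple property: if the future belief equals the predicted one, so that no information arrived after step $t$, smoothing leaves the filtered belief unchanged.

```python
import numpy as np

def gauss_soft_condition(bel_joint, bel_future):
    """One RTS backwards step in the form of Algorithm 10.

    bel_joint  = (mu_{t|t}, mu_{t+1|t}, Sigma_{t|t}, Sigma_{t+1|t}, Sigma_{t,t+1|t})
    bel_future = (mu_{t+1|T}, Sigma_{t+1|T})
    """
    mu_f, mu_pred, Sigma_f, Sigma_pred, Sigma_cross = bel_joint
    mu_next_s, Sigma_next_s = bel_future
    # G_t = Sigma_{t,t+1|t} Sigma_{t+1|t}^{-1}, via a solve (Sigma_pred is symmetric)
    G = np.linalg.solve(Sigma_pred, Sigma_cross.T).T
    mu_s = mu_f + G @ (mu_next_s - mu_pred)
    Sigma_s = Sigma_f + G @ (Sigma_next_s - Sigma_pred) @ G.T
    return mu_s, Sigma_s

# Usage: build bel_joint from a filtered belief and F_{t+1}, Q_{t+1} (no inputs).
F = np.array([[1.0, 0.1], [0.0, 1.0]])
Q = 0.05 * np.eye(2)
mu_f, Sigma_f = np.array([0.3, -0.2]), np.array([[0.5, 0.1], [0.1, 0.4]])
mu_pred, Sigma_pred = F @ mu_f, F @ Sigma_f @ F.T + Q
bel_joint = (mu_f, mu_pred, Sigma_f, Sigma_pred, Sigma_f @ F.T)
# No new information after step t: the smoothed belief equals the filtered one.
mu_s, Sigma_s = gauss_soft_condition(bel_joint, (mu_pred, Sigma_pred))
```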


8.3.3.2 Two-filter smoothing

Note that the backwards pass of the Kalman smoother does not need access to the observations, $y_{1:T}$, but does need access to the filtered belief states from the forwards pass, $p(z_t \mid y_{1:t}) = \mathcal{N}(z_t \mid \mu_{t|t}, \Sigma_{t|t})$. There is an alternative version of the algorithm, known as two-filter smoothing [FP69; Kit04], in which we compute the forwards pass as usual, and then separately compute backwards messages $p(y_{t+1:T} \mid z_t) \propto \mathcal{N}(z_t \mid \mu^b_t, \Sigma^b_t)$, similar to the backwards filtering algorithm in HMMs (Section 8.2.4). However, these backwards messages are conditional likelihoods, not posteriors, which can cause numerical problems. For example, consider $t = T$; in this case, we need to set the initial covariance matrix to be $\Sigma^b_T = \infty I$, so that the backwards message has no effect on the filtered posterior (since there is no evidence beyond step $T$). This problem can be resolved by working in information form. An alternative approach is to generalize the two-filter smoothing equations to ensure the likelihoods are normalizable by multiplying them by artificial distributions [BDM10].

In general, the RTS smoother is preferred to the two-filter smoother, since it is more numerically stable, and it is easier to generalize it to the nonlinear case.


8.3.3.3 Time and space complexity

In general, the Kalman smoothing algorithm takes $O(N_y^3 + N_z^2 + N_y N_z)$ time per step, where there are $T$ steps. This can be slow when applied to long sequences. In [SGF21], they describe how to reduce this to $O(\log T)$ steps using a parallel prefix scan operator that can be run efficiently on GPUs. In addition, we can reduce the space from $O(T)$ to $O(\log T)$ using the same algorithm as in Section 8.2.6.


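8.3.3.4 Forwards filtering backwards sampling

To draw posterior samples from the LG-SSM, we can leverage the following result (derived in [Sar13, p137]):

$$
\begin{align}
p(z_t \mid z_{t+1}, y_{1:T}) &= \mathcal{N}(z_t \mid \tilde{\mu}_t, \tilde{\Sigma}_t) \tag{8.121}\\
\tilde{\mu}_t &= \mu_{t|t} + G_t(z_{t+1} - F_{t+1}\mu_{t|t}) \tag{8.122}\\
\tilde{\Sigma}_t &= \Sigma_{t|t} - G_t \Sigma_{t+1|t} G_t^\top = \Sigma_{t|t} - \Sigma_{t|t} F_{t+1}^\top \Sigma_{t+1|t}^{-1} \Sigma_{t+1|t} G_t^\top \tag{8.123}\\
&= \Sigma_{t|t}(I - F_{t+1}^\top G_t^\top) \tag{8.124}
\end{align}
$$

where $G_t$ is the backwards Kalman gain defined in Equation (8.111).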
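The sampler implied by Equations (8.121)-(8.124) is: run the Kalman filter, draw $z_T$ from $p(z_T \mid y_{1:T})$, then draw each $z_t$ from $p(z_t \mid z_{t+1}, y_{1:T})$ going backwards. The NumPy sketch below does this for a scalar model $z_t = a z_{t-1} + \mathcal{N}(0, q)$, $y_t = z_t + \mathcal{N}(0, r)$, with $z_0 \sim \mathcal{N}(m_0, p_0)$; this is a deliberately simple special case chosen for brevity, and the function name is hypothetical. All trajectories are sampled at once, and as the number of samples grows, the per-step sample means and variances should approach the RTS smoothed marginals of Equations (8.119)-(8.120).

```python
import numpy as np

def ffbs_scalar(a, q, r, ys, m0, p0, num_samples, rng):
    """FFBS for z_t = a z_{t-1} + N(0, q), y_t = z_t + N(0, r), with z_0 ~ N(m0, p0)."""
    T = len(ys)
    mf, pf = np.zeros(T), np.zeros(T)           # filtered means and variances
    m, p = m0, p0
    for t in range(T):                           # forwards Kalman filter
        m, p = a * m, a * a * p + q              # predict
        k = p / (p + r)                          # Kalman gain
        m, p = m + k * (ys[t] - m), (1 - k) * p  # update
        mf[t], pf[t] = m, p
    z = np.empty((num_samples, T))
    z[:, -1] = mf[-1] + np.sqrt(pf[-1]) * rng.standard_normal(num_samples)  # z_T | y_{1:T}
    for t in range(T - 2, -1, -1):               # backwards sampling, Eqs. (8.121)-(8.124)
        p_pred = a * a * pf[t] + q               # Sigma_{t+1|t}
        g = pf[t] * a / p_pred                   # backwards gain G_t
        mean = mf[t] + g * (z[:, t + 1] - a * mf[t])
        var = pf[t] * (1 - a * g)                # Sigma_{t|t} (1 - F_{t+1} G_t)
        z[:, t] = mean + np.sqrt(var) * rng.standard_normal(num_samples)
    return z, mf, pf

# Usage: draw many trajectories and summarize them per time step.
rng = np.random.default_rng(0)
a, q, r = 0.9, 0.5, 1.0
ys = np.array([0.5, 1.2, -0.3, 0.8, 2.0])
samples, mf, pf = ffbs_scalar(a, q, r, ys, 0.0, 1.0, 200_000, rng)
sample_means, sample_vars = samples.mean(axis=0), samples.var(axis=0)
```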

8.3.4 Information form filtering and smoothing

This section was written by Giles Harper-Donnelly.

In this section, we derive the Kalman filter and smoother algorithms in information form. We will see that this is the "dual" of Kalman filtering/smoothing in moment form. In particular, while computing marginals in moment form is easy, computing conditionals is hard (requires a matrix inverse). Conversely, for information form, computing marginals is hard, but computing conditionals is easy.

8.3.4.1 Filtering: algorithm

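The predict step has a similar structure to the update step in moment form. We start with the prior $p(z_{t-1} \mid y_{1:t-1}, u_{1:t-1}) = \mathcal{N}_c(z_{t-1} \mid \eta_{t-1|t-1}, \Lambda_{t-1|t-1})$ and then compute

$$
\begin{align}
p(z_t \mid y_{1:t-1}, u_{1:t}) &= \mathcal{N}_c(z_t \mid \eta_{t|t-1}, \Lambda_{t|t-1}) \tag{8.125}\\
M_t &= \Lambda_{t-1|t-1} + F_t^\top Q_t^{-1} F_t \tag{8.126}\\
J_t &= Q_t^{-1} F_t M_t^{-1} \tag{8.127}\\
\Lambda_{t|t-1} &= Q_t^{-1} - Q_t^{-1} F_t (\Lambda_{t-1|t-1} + F_t^\top Q_t^{-1} F_t)^{-1} F_t^\top Q_t^{-1} \tag{8.128}\\
&= Q_t^{-1} - J_t F_t^\top Q_t^{-1} \tag{8.129}\\
&= Q_t^{-1} - J_t M_t J_t^\top \tag{8.130}\\
\eta_{t|t-1} &= J_t \eta_{t-1|t-1} + \Lambda_{t|t-1}(B_t u_t + b_t) \tag{8.131}
\end{align}
$$

where $J_t$ is analogous to the Kalman gain matrix in moment form, Equation (8.72). From the matrix inversion lemma, Equation (2.54), we see that Equation (8.128) is the inverse of the predicted covariance $\Sigma_{t|t-1}$ given in Equation (8.68).

The update step in information form is as follows:

$$
\begin{align}
p(z_t \mid y_{1:t}, u_{1:t}) &= \mathcal{N}_c(z_t \mid \eta_{t|t}, \Lambda_{t|t}) \tag{8.132}\\
\Lambda_{t|t} &= \Lambda_{t|t-1} + H_t^\top R_t^{-1} H_t \tag{8.133}\\
\eta_{t|t} &= \eta_{t|t-1} + H_t^\top R_t^{-1}(y_t - D_t u_t - d_t) \tag{8.134}
\end{align}
$$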
37 

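8.3.4.2 Filtering: derivation

For the predict step, we first derive the joint distribution over hidden states at $t-1$ and $t$:

$$
\begin{align}
p(z_{t-1}, z_t \mid y_{1:t-1}, u_{1:t}) &= p(z_t \mid z_{t-1}, u_t)\, p(z_{t-1} \mid y_{1:t-1}, u_{1:t-1}) \tag{8.135}\\
&= \mathcal{N}_c(z_t \mid Q_t^{-1}(F_t z_{t-1} + B_t u_t + b_t), Q_t^{-1}) \tag{8.136}\\
&\quad \times \mathcal{N}_c(z_{t-1} \mid \eta_{t-1|t-1}, \Lambda_{t-1|t-1}) \tag{8.137}\\
&= \mathcal{N}_c(z_{t-1}, z_t \mid \eta_{t-1,t|t-1}, \Lambda_{t-1,t|t-1}) \tag{8.138}
\end{align}
$$

where

$$
\begin{align}
\eta_{t-1,t|t-1} &= \begin{pmatrix} \eta_{t-1|t-1} - F_t^\top Q_t^{-1}(B_t u_t + b_t) \\ Q_t^{-1}(B_t u_t + b_t) \end{pmatrix} \tag{8.139}\\
\Lambda_{t-1,t|t-1} &= \begin{pmatrix} \Lambda_{t-1|t-1} + F_t^\top Q_t^{-1} F_t & -F_t^\top Q_t^{-1} \\ -Q_t^{-1} F_t & Q_t^{-1} \end{pmatrix} \tag{8.140}
\end{align}
$$

The information form predicted parameters $\eta_{t|t-1}$, $\Lambda_{t|t-1}$ can then be derived using the marginalisation formulae in Section 2.2.5.4.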

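For the update step, we start with the joint distribution over the hidden state and the observation at $t$:

$$
\begin{align}
p(z_t, y_t \mid y_{1:t-1}, u_{1:t}) &= p(y_t \mid z_t, u_t)\, p(z_t \mid y_{1:t-1}, u_{1:t}) \tag{8.141}\\
&= \mathcal{N}_c(y_t \mid R_t^{-1}(H_t z_t + D_t u_t + d_t), R_t^{-1})\, \mathcal{N}_c(z_t \mid \eta_{t|t-1}, \Lambda_{t|t-1}) \tag{8.142}\\
&= \mathcal{N}_c(z_t, y_t \mid \eta_{z,y|t}, \Lambda_{z,y|t}) \tag{8.143}
\end{align}
$$

where

$$
\begin{align}
\eta_{z,y|t} &= \begin{pmatrix} \eta_{t|t-1} - H_t^\top R_t^{-1}(D_t u_t + d_t) \\ R_t^{-1}(D_t u_t + d_t) \end{pmatrix} \tag{8.144}\\
\Lambda_{z,y|t} &= \begin{pmatrix} \Lambda_{t|t-1} + H_t^\top R_t^{-1} H_t & -H_t^\top R_t^{-1} \\ -R_t^{-1} H_t & R_t^{-1} \end{pmatrix} \tag{8.145}
\end{align}
$$

The information form filtered parameters $\eta_{t|t}$, $\Lambda_{t|t}$ are then derived using the conditional formulae in Section 2.2.5.4.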


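8.3.4.3 Smoothing: algorithm

The smoothing equations are as follows:

$$
\begin{align}
p(z_t \mid y_{1:T}) &= \mathcal{N}_c(z_t \mid \eta_{t|T}, \Lambda_{t|T}) \tag{8.146}\\
U_t &= Q_{t+1}^{-1} + \Lambda_{t+1|T} - \Lambda_{t+1|t} \tag{8.147}\\
L_t &= F_{t+1}^\top Q_{t+1}^{-1} U_t^{-1} \tag{8.148}\\
\Lambda_{t|T} &= \Lambda_{t|t} + F_{t+1}^\top Q_{t+1}^{-1} F_{t+1} - L_t Q_{t+1}^{-1} F_{t+1} \tag{8.149}\\
&= \Lambda_{t|t} + F_{t+1}^\top Q_{t+1}^{-1} F_{t+1} - L_t U_t L_t^\top \tag{8.150}\\
\eta_{t|T} &= \eta_{t|t} + L_t(\eta_{t+1|T} - \eta_{t+1|t}) \tag{8.151}
\end{align}
$$

The parameters $\eta_{t|t}$ and $\Lambda_{t|t}$ are the filtered values from Equations (8.134) and (8.133) respectively. Similarly, $\eta_{t+1|t}$ and $\Lambda_{t+1|t}$ are the predicted parameters from Equations (8.131) and (8.128). The matrix $L_t$ is the information form analog to the backward Kalman gain matrix in Equation (8.111).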
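Equations (8.147)-(8.151) can be checked numerically against the moment-form RTS step of Equations (8.119)-(8.120). The NumPy sketch below performs one backwards step in information form; the function name is hypothetical, inputs are dropped (as in the derivation of Section 8.3.4.4), and `F`, `Q` stand for $F_{t+1}$, $Q_{t+1}$. In the usage lines the future smoothed belief is simply set to an arbitrary valid Gaussian, which is enough to compare the two parameterizations.

```python
import numpy as np

def info_smooth_step(eta_f, Lam_f, eta_pred, Lam_pred, eta_next_s, Lam_next_s, F, Q):
    """One backwards step of the information-form smoother, Equations (8.147)-(8.151)."""
    Q_inv = np.linalg.inv(Q)
    U = Q_inv + Lam_next_s - Lam_pred                 # (8.147)
    L = F.T @ Q_inv @ np.linalg.inv(U)                # (8.148)
    Lam_s = Lam_f + F.T @ Q_inv @ F - L @ U @ L.T     # (8.150)
    eta_s = eta_f + L @ (eta_next_s - eta_pred)       # (8.151)
    return eta_s, Lam_s

def to_info(m, P):
    """Moment parameters (m, P) to information parameters (eta, Lambda)."""
    return np.linalg.solve(P, m), np.linalg.inv(P)

# Usage: moment-form RTS step and information-form step on the same beliefs.
rng = np.random.default_rng(0)
n = 3
A = rng.normal(size=(n, n))
Sigma_f, mu_f = A @ A.T + np.eye(n), rng.normal(size=n)          # filtered belief at t
F, Q = rng.normal(size=(n, n)), 0.5 * np.eye(n)
mu_pred, Sigma_pred = F @ mu_f, F @ Sigma_f @ F.T + Q            # predicted belief at t+1
mu_next_s, Sigma_next_s = rng.normal(size=n), 0.5 * Sigma_pred   # some smoothed belief at t+1

G = Sigma_f @ F.T @ np.linalg.inv(Sigma_pred)                    # (8.111)
mu_s = mu_f + G @ (mu_next_s - mu_pred)                          # (8.119)
Sigma_s = Sigma_f + G @ (Sigma_next_s - Sigma_pred) @ G.T        # (8.120)

eta_s, Lam_s = info_smooth_step(*to_info(mu_f, Sigma_f), *to_info(mu_pred, Sigma_pred),
                                *to_info(mu_next_s, Sigma_next_s), F, Q)
```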


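8.3.4.4 Smoothing: derivation

From the generic forwards-filtering backwards-smoothing equation, Equation (8.20), we have

$$
\begin{align}
p(z_t \mid y_{1:T}) &= p(z_t \mid y_{1:t}) \int \frac{p(z_{t+1} \mid z_t)\, p(z_{t+1} \mid y_{1:T})}{p(z_{t+1} \mid y_{1:t})}\, dz_{t+1} \tag{8.152}\\
&= \int p(z_t, z_{t+1} \mid y_{1:t})\, \frac{p(z_{t+1} \mid y_{1:T})}{p(z_{t+1} \mid y_{1:t})}\, dz_{t+1} \tag{8.153}\\
&= \int \mathcal{N}_c(z_t, z_{t+1} \mid \eta_{t,t+1|t}, \Lambda_{t,t+1|t})\, \frac{\mathcal{N}_c(z_{t+1} \mid \eta_{t+1|T}, \Lambda_{t+1|T})}{\mathcal{N}_c(z_{t+1} \mid \eta_{t+1|t}, \Lambda_{t+1|t})}\, dz_{t+1} \tag{8.154}\\
&= \int \mathcal{N}_c(z_t, z_{t+1} \mid \eta_{t,t+1|T}, \Lambda_{t,t+1|T})\, dz_{t+1} \tag{8.155}
\end{align}
$$

The parameters of the joint filtering predictive distribution, $p(z_t, z_{t+1} \mid y_{1:t})$, take precisely the same form as those in the filtering derivation described in Section 8.3.4.2:

$$
\eta_{t,t+1|t} = \begin{pmatrix} \eta_{t|t} \\ 0 \end{pmatrix}, \quad
\Lambda_{t,t+1|t} = \begin{pmatrix} \Lambda_{t|t} + F_{t+1}^\top Q_{t+1}^{-1} F_{t+1} & -F_{t+1}^\top Q_{t+1}^{-1} \\ -Q_{t+1}^{-1} F_{t+1} & Q_{t+1}^{-1} \end{pmatrix} \tag{8.156}
$$

We can now update this potential function by subtracting out the filtered information and adding in the smoothing information, using the rules for manipulating Gaussian potentials described in Section 2.2.7:

$$
\eta_{t,t+1|T} = \eta_{t,t+1|t} + \begin{pmatrix} 0 \\ \eta_{t+1|T} \end{pmatrix} - \begin{pmatrix} 0 \\ \eta_{t+1|t} \end{pmatrix} = \begin{pmatrix} \eta_{t|t} \\ \eta_{t+1|T} - \eta_{t+1|t} \end{pmatrix} \tag{8.157}
$$

and

$$
\begin{align}
\Lambda_{t,t+1|T} &= \Lambda_{t,t+1|t} + \begin{pmatrix} 0 & 0 \\ 0 & \Lambda_{t+1|T} \end{pmatrix} - \begin{pmatrix} 0 & 0 \\ 0 & \Lambda_{t+1|t} \end{pmatrix} \tag{8.158}\\
&= \begin{pmatrix} \Lambda_{t|t} + F_{t+1}^\top Q_{t+1}^{-1} F_{t+1} & -F_{t+1}^\top Q_{t+1}^{-1} \\ -Q_{t+1}^{-1} F_{t+1} & Q_{t+1}^{-1} + \Lambda_{t+1|T} - \Lambda_{t+1|t} \end{pmatrix} \tag{8.159}
\end{align}
$$

Applying the information form marginalisation formula, Equation (2.46), leads to Equation (8.151) and Equation (8.149).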


8.4 Inference for non-linear and/or non-Gaussian SSMs 


In general state space models (Section 29.1), the transition and/or emission probability distributions 
can be arbitrary conditional distributions, and are not restricted to being discrete or Gaussian 
distributions. This makes inference much harder, because the likelihood is not conjugate to the prior, 
and the posterior may not have any convenient analytical form. 


8.4.1 Inference based on discretization 


To illustrate the problem, consider a simple 1d SSM with linear dynamics corrupted by additive 
Student noise: 


$$
z_t = z_{t-1} + \mathcal{T}_2(0, 1) \tag{8.160}
$$


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 
Figure 8.9: (a) Observations and true and estimated state. (b) Marginal distributions for time step t = 20. Generated by discretized_ssm_student.ipynb.


The observations are also linear, and are also corrupted by additive Student noise: 
$$
y_t = z_t + \mathcal{T}_2(0, 1) \tag{8.161}
$$


This robust observation model is useful when there are potential outliers in the observed data, such as at time $t = 20$ in Figure 8.9a. (See also Section 8.8.2 for discussion of robust Kalman filters.) Unfortunately, the use of a non-Gaussian likelihood means that the resulting posterior can become multimodal. In models with low-dimensional latent state spaces (e.g., 1-2d), we can always discretize the state space and apply the exact HMM filter and smoother from Section 8.2.4, as proposed in [RG17]. We show the results for filtering and smoothing in Figure 8.10a and in Figure 8.10b. We see that at $t = 20$, the filtering distribution, $p(z_t \mid y_{1:20})$, is bimodal, with a mean that is quite far from the true state (see Figure 8.9b for a detailed plot). Such a multimodal distribution can be approximated by a suitably fine discretization. Unfortunately, tackling the inference problem by discretizing the state space will not scale to higher dimensional problems, due to the curse of dimensionality. In particular, we know that the HMM filter takes $O(K^2)$ operations per time step, if there are $K$ states. If we have $N_z$ dimensions, each discretized into $B$ bins, then we have $K = B^{N_z}$, so the approach becomes intractable beyond 1d.
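The discretization approach is easy to sketch. The NumPy code below is a minimal grid (point-mass) filter for the scalar random-walk model of Equations (8.160)-(8.161): it builds a $K \times K$ transition matrix on a uniform grid, so each step costs $O(K^2)$ as stated above. The grid range, the number of grid points, the outlier in the usage data, and the function names are illustrative assumptions; this is not the code behind Figure 8.10. Swapping the Student noise for Gaussian noise turns the model into an LG-SSM, which gives a convenient comparison with an exact Kalman filter.

```python
import numpy as np
from math import gamma, pi, sqrt

def student_t_pdf(x, nu=2.0):
    """Density of a standard Student-t distribution with nu degrees of freedom."""
    c = gamma((nu + 1) / 2) / (gamma(nu / 2) * sqrt(nu * pi))
    return c * (1 + x ** 2 / nu) ** (-(nu + 1) / 2)

def grid_filter(ys, grid, trans_pdf, obs_pdf, prior):
    """Exact HMM filtering on a fixed 1d grid; returns p(z_t = grid[k] | y_{1:t})."""
    # A[i, j] is proportional to p(z_t = grid[j] | z_{t-1} = grid[i]) for random-walk dynamics
    A = trans_pdf(grid[None, :] - grid[:, None])
    A = A / A.sum(axis=1, keepdims=True)
    alpha = prior / prior.sum()
    posts = []
    for y in ys:
        alpha = alpha @ A                    # predict: O(K^2)
        alpha = alpha * obs_pdf(y - grid)    # multiply by the likelihood p(y_t | z_t)
        alpha = alpha / alpha.sum()          # normalize
        posts.append(alpha)
    return np.array(posts)

# Usage with Student-t noise; the last observation is an outlier.
grid = np.linspace(-15.0, 15.0, 1201)
ys = np.array([0.2, 0.5, 0.1, 8.0])
posts = grid_filter(ys, grid, student_t_pdf, student_t_pdf, np.exp(-0.5 * grid ** 2))
filtered_means = posts @ grid
```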


8.4.2 Inference based on Gaussian approximations 


In the sections below, we discuss various deterministic approximate inference schemes which make a 
Gaussian approximation to the posterior. Most of these algorithms assume the SSM has the following 
form: 


$$
\begin{aligned}
z_t &= f(z_{t-1}, u_t) + \mathcal{N}(0, Q_t) \\
y_t &= h(z_t, u_t) + \mathcal{N}(0, R_t)
\end{aligned} \tag{8.162}
$$

where $z_t \in \mathbb{R}^{N_z}$ is the hidden state, $y_t \in \mathbb{R}^{N_y}$ is the observation, $u_t \in \mathbb{R}^{N_u}$ are the optional inputs, $f: \mathbb{R}^{N_z + N_u} \to \mathbb{R}^{N_z}$ is the dynamics model, and $h: \mathbb{R}^{N_z + N_u} \to \mathbb{R}^{N_y}$ is the observation model. The dynamics and/or observation models can be nonlinear, but we assume the noise terms are Gaussian and additive. Our presentation is based on [Sar13, Ch. 5]; for more details, see e.g., [Sar13; Fan+17; Li+17e; Koy+10].


Figure 8.10: Discretized posterior of the latent state at each time step. Red cross is the true latent state. Red circle is observation. (a) Filtering. (b) Smoothing. Generated by discretized_ssm_student.ipynb.


8.5 Inference based on local linearization 


In this section, we extend the Kalman filter and smoother to the case where the system dynamics 
and/or the observation model are nonlinear. (We continue to assume that the noise is additive 
Gaussian, as in Equation (8.162).) The basic idea is to linearize the dynamics and observation 
models about the previous state estimate using a first-order Taylor series expansion, and then to
apply the standard Kalman filter equations from Section 8.3.2. Intuitively we can think of this as 
approximating a stationary non-linear dynamical system with a non-stationary linear dynamical 
system. This approach is called the extended Kalman filter or EKF. 


8.5.1 Taylor series expansion 


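Suppose $x \sim \mathcal{N}(\mu, \Sigma)$ and $y = g(x)$, where $g: \mathbb{R}^n \to \mathbb{R}^n$ is a differentiable and invertible function. The pdf for $y$ is given by

$$
p(y) = |\det \text{Jac}(g^{-1})(y)|\, \mathcal{N}(g^{-1}(y) \mid \mu, \Sigma) \tag{8.163}
$$

In general this is intractable to compute, so we seek an approximation. Suppose $x = \mu + \delta$, where $\delta \sim \mathcal{N}(0, \Sigma)$. Then we can form a first-order Taylor series expansion of the function $g$ as follows:

$$
g(x) = g(\mu + \delta) \approx g(\mu) + G(\mu)\delta \tag{8.164}
$$

where $G(\mu)$ is the Jacobian of $g$ at $\mu$:

$$
[G(\mu)]_{ij} = \left.\frac{\partial g_i(x)}{\partial x_j}\right|_{x = \mu} \tag{8.165}
$$

We now derive the induced Gaussian approximation to $y = g(x)$. The mean is given by

$$
\mathbb{E}[y] \approx \mathbb{E}[g(\mu) + G(\mu)\delta] = g(\mu) + G(\mu)\mathbb{E}[\delta] = g(\mu) \tag{8.166}
$$

The covariance is given by

$$
\begin{align}
\text{Cov}[y] &= \mathbb{E}\left[(g(x) - \mathbb{E}[g(x)])(g(x) - \mathbb{E}[g(x)])^\top\right] \tag{8.167}\\
&\approx \mathbb{E}\left[(g(x) - g(\mu))(g(x) - g(\mu))^\top\right] \tag{8.168}\\
&\approx \mathbb{E}\left[(g(\mu) + G(\mu)\delta - g(\mu))(g(\mu) + G(\mu)\delta - g(\mu))^\top\right] \tag{8.169}\\
&= \mathbb{E}\left[(G(\mu)\delta)(G(\mu)\delta)^\top\right] \tag{8.170}\\
&= G(\mu)\, \mathbb{E}[\delta\delta^\top]\, G(\mu)^\top \tag{8.171}\\
&= G(\mu)\, \Sigma\, G(\mu)^\top \tag{8.172}
\end{align}
$$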
Algorithm 11: Linearized approximation to a joint Gaussian distribution.
1 def LinearizedPredict(μ, Σ, g, Q):
2     m = g(μ)
3     G = Jac(g)(μ)
4     S = GΣG^T + Q
5     C = ΣG^T
6     Return (m, S, C)
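Algorithm 11 needs the Jacobian $\text{Jac}(g)(\mu)$. In a JAX implementation this would typically come from automatic differentiation (e.g. `jax.jacfwd`); to keep the sketch below NumPy-only, it accepts an optional analytic Jacobian and otherwise falls back to central finite differences, a substitution of our own with an illustrative step size `eps`. The function names are hypothetical. The usage lines apply it to the polar-to-Cartesian map, a standard example of a nonlinear $g$.

```python
import numpy as np

def finite_diff_jacobian(g, x, eps=1e-6):
    """Central finite-difference approximation of Jac(g)(x)."""
    x = np.asarray(x, dtype=float)
    cols = []
    for j in range(x.size):
        dx = np.zeros_like(x)
        dx[j] = eps
        cols.append((np.atleast_1d(g(x + dx)) - np.atleast_1d(g(x - dx))) / (2 * eps))
    return np.column_stack(cols)

def linearized_predict(mu, Sigma, g, Q, jac=None):
    """First-order (Taylor) approximation of the joint of x ~ N(mu, Sigma) and y = g(x) + q."""
    m = np.atleast_1d(g(mu))                                         # line 2
    G = jac(mu) if jac is not None else finite_diff_jacobian(g, mu)  # line 3
    S = G @ Sigma @ G.T + Q                                          # line 4
    C = Sigma @ G.T                                                  # line 5
    return m, S, C

# Usage: polar-to-Cartesian map.
g = lambda x: np.array([x[0] * np.cos(x[1]), x[0] * np.sin(x[1])])
mu, Sigma = np.array([1.0, 0.5]), np.diag([0.01, 0.04])
m, S, C = linearized_predict(mu, Sigma, g, Q=np.zeros((2, 2)))
```

For a linear map $g(x) = Ax + b$ the sketch returns $m = A\mu + b$, $S = A\Sigma A^\top + Q$ and $C = \Sigma A^\top$, which matches LinGaussPredict in Algorithm 7.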


When deriving the EKF, we need to compute the joint distribution $p(x, y)$ where

$$
x \sim \mathcal{N}(\mu, \Sigma), \quad y = g(x) + q, \quad q \sim \mathcal{N}(0, Q) \tag{8.173}
$$

where $q$ is independent of $x$. We can compute this by defining the augmented function $\tilde{g}(x) = [x, g(x)]$ and following the procedure above. The resulting linear approximation to the joint is

$$
\begin{pmatrix} x \\ y \end{pmatrix} \sim \mathcal{N}\left(\begin{pmatrix} \mu \\ m \end{pmatrix}, \begin{pmatrix} \Sigma & C \\ C^\top & S \end{pmatrix}\right) \tag{8.174}
$$

where the parameters are computed using Algorithm 11. We can then condition this joint Gaussian on the observed value $y$ to get the posterior.

It is also possible to derive an approximation for the case of non-additive Gaussian noise, where $y = g(x, q)$. See [Sar13, Sec 5.1] for details.


8.5.2 The extended Kalman filter (EKF)

We now derive the extended Kalman filter for performing approximate inference in the model given by Equation (8.162). We first linearize the dynamics model around $\mu_{t-1|t-1}$ to get an approximation to the one-step-ahead predictive distribution $p(z_t \mid y_{1:t-1}, u_{1:t}) = \mathcal{N}(z_t \mid \mu_{t|t-1}, \Sigma_{t|t-1})$. We then linearize the observation model around $\mu_{t|t-1}$, and then perform a Gaussian update. (In Section 8.5.2.2, we consider linearizing around a different point that gives better accuracy.) We can write one step of the EKF algorithm as follows, where we use the notation from Section 8.3.2.5:

$$
\begin{align}
(\mu_{t|t-1}, \Sigma_{t|t-1}, -) &= \text{LinearizedPredict}(\mu_{t-1|t-1}, \Sigma_{t-1|t-1}, f(\cdot, u_t), Q_t) \tag{8.175}\\
(m_t, S_t, C_t) &= \text{LinearizedPredict}(\mu_{t|t-1}, \Sigma_{t|t-1}, h(\cdot, u_t), R_t) \tag{8.176}\\
(\mu_{t|t}, \Sigma_{t|t}, \ell_t) &= \text{GaussCondition}(\mu_{t|t-1}, \Sigma_{t|t-1}, m_t, S_t, C_t, y_t) \tag{8.177}
\end{align}
$$


See Supplementary Section 8.2.1 for the details of the derivation. 


8.5.2.1 Accuracy 


The EKF is widely used because it is simple and relatively efficient. However, there are two cases 
when the EKF works poorly. The first is when the prior covariance is large. In this case, the prior 
distribution is broad, so we end up sending a lot of probability mass through different parts of 
the function that are far from $\mu_{t-1|t-1}$, where the function has been linearized. The other setting
where the EKF works poorly is when the function is highly nonlinear near the current mean (see 
Figure 8.12a). 

A more accurate approach is to use a second-order Taylor series approximation, known as the 
second-order EKF. The resulting updates can still be computed in closed form (see [Sar13, Sec 5.2] for details).


8.5.2.2 Iterated EKF 


Algorithm 12: Iterated extended Kalman filter.
1  def IEKF(f, Q, h, R, y_{1:T}, μ_{0|0}, Σ_{0|0}, J):
2      foreach t = 1 : T do
3          Predict step:
4          (μ_{t|t-1}, Σ_{t|t-1}, -) = LinearizedPredict(μ_{t-1|t-1}, Σ_{t-1|t-1}, f(·, u_t), Q_t)
5          Update step:
6          μ_{t|t} = μ_{t|t-1},  Σ_{t|t} = Σ_{t|t-1}
7          foreach j = 1 : J do
8              (m_t, S_t, C_t) = LinearizedPredict(μ_{t|t}, Σ_{t|t}, h(·, u_t), R_t)
9              (μ_{t|t}, Σ_{t|t}, ℓ_t) = GaussCondition(μ_{t|t-1}, Σ_{t|t-1}, m_t, S_t, C_t, y_t)
10     Return (μ_{t|t}, Σ_{t|t})_{t=1}^T


Another way to improve the accuracy of the EKF is by repeatedly re-linearizing the measurement model around the current posterior, $\mu_{t|t}$, instead of $\mu_{t|t-1}$; this is called the iterated EKF [BC93]. See Algorithm 12 for the pseudocode. (If we set the number of iterations to $J = 1$, we recover the standard EKF.) Unfortunately, the IEKF can diverge. A robust IEKF method, which uses line search to perform damped (partial) updates of the state estimate, is presented in [SHA15].
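The NumPy sketch below shows the measurement update of an iterated EKF in the commonly used Gauss-Newton form: each iteration re-linearizes $h$ around the current iterate but keeps the predicted covariance $\Sigma_{t|t-1}$, and the predicted observation is $h$ at the iterate plus a first-order correction back to $\mu_{t|t-1}$. This is our reading of the method rather than a line-by-line transcription of Algorithm 12, and the function name is hypothetical. With `num_iters=1` it reduces to the standard EKF update, and for a linear $h$ it gives the Kalman update regardless of `num_iters`.

```python
import numpy as np

def iekf_update(mu_pred, Sigma_pred, h, jac_h, R, y, num_iters):
    """Iterated EKF measurement update (Gauss-Newton form); requires num_iters >= 1."""
    mu = mu_pred.copy()
    for _ in range(num_iters):
        H = jac_h(mu)                                 # re-linearize around the current iterate
        S = H @ Sigma_pred @ H.T + R
        K = np.linalg.solve(S, H @ Sigma_pred).T      # Sigma_pred H^T S^{-1}
        # First-order prediction of y at mu_pred, expanded around the current iterate mu
        y_hat = h(mu) + H @ (mu_pred - mu)
        mu = mu_pred + K @ (y - y_hat)
    Sigma = Sigma_pred - K @ S @ K.T
    return mu, Sigma

# Usage: a mildly nonlinear scalar observation model.
h = lambda z: z + 0.1 * z ** 2
jac_h = lambda z: np.array([[1.0 + 0.2 * z[0]]])
mu_ekf, _ = iekf_update(np.zeros(1), np.eye(1), h, jac_h, np.eye(1), np.array([1.0]), 1)
mu_iekf, _ = iekf_update(np.zeros(1), np.eye(1), h, jac_h, np.eye(1), np.array([1.0]), 30)
```

When the iterations converge, the result is a stationary point of $-\log[\mathcal{N}(z \mid \mu_{t|t-1}, \Sigma_{t|t-1})\, \mathcal{N}(y_t \mid h(z), R_t)]$, i.e., a MAP estimate of $z_t$, which is the sense in which the IEKF is a Gauss-Newton method.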


8.5.3 The extended Kalman smoother (EKS) 


We can extend the EKF to the offline smoothing case, resulting in the extended Kalman smoother, 
also called the extended RTS smoother. We just need to linearize around the filtered mean in 
step 4 of the Kalman smoothing algorithm in Algorithm 9. 

For improved accuracy, we can use the iterated EKS, which relinearizes the model at the previous 
MAP estimate. See Algorithm 13 for the pseudocode. (If we use $J = 1$ iteration, we recover the standard EKS.)


(Figure 8.11 panels: (a) noisy observations from hidden trajectory; (b) EKF-filtered estimate of trajectory; (c) UKF-filtered estimate of trajectory.)
Figure 8.11: Illustration of filtering applied to a 2d nonlinear dynamical system. (a) True underlying state and observed data. (b) Extended Kalman filter estimate. Generated by ekf_spiral.ipynb. (c) Unscented Kalman filter estimate. Generated by ukf_spiral.ipynb.


In [Bel94], they show that IEKS is equivalent to a Gauss-Newton method for computing the MAP 
estimate of the smoothing posterior. Unfortunately the IEKS can diverge. A robust IEKS method, 
that uses line search and Levenberg-Marquardt to update the parameters, is presented in [SS20a].


Algorithm 13: Iterated extended Kalman smoother.
1 def IEKS(f, Q, h, R, y_{1:T}, μ_{0|0}, Σ_{0|0}, J):
2     (μ_{t|t}, Σ_{t|t})_{t=1}^T = EKF(f, Q, h, R, y_{1:T}, μ_{0|0}, Σ_{0|0})
4     foreach j = 1 : J do
5         foreach t = 1 : T do
6             (μ_{t+1|t}, Σ_{t+1|t}, Σ_{t,t+1|t}) = LinearizedPredict(μ_{t|T}, Σ_{t|T}, f(·, u_{t+1}), Q_{t+1})
7             bel_{t,t+1|t} = (μ_{t|t}, μ_{t+1|t}, Σ_{t|t}, Σ_{t+1|t}, Σ_{t,t+1|t});  bel_{t+1|T} = (μ_{t+1|T}, Σ_{t+1|T})
8             (μ_{t|T}, Σ_{t|T}) = GaussSoftCondition(bel_{t,t+1|t}, bel_{t+1|T})
9     Return (μ_{t|T}, Σ_{t|T})_{t=1}^T


8.5.4 Examples

In this section, we give some examples of EKF/EKS.

8.5.4.1 Tracking a point spiraling in 2d

In Section 8.3.1.1, we considered an example of state estimation and tracking of an object moving in 2d under a linear dynamics model with a linear observation model. However, motion and observation models are often nonlinear. For example, consider an object that is moving along a curved trajectory, such as this:

$$
f(z) = (z_1 + \Delta \sin(z_2),\; z_2 + \Delta \cos(z_1)) \tag{8.178}
$$

where $\Delta$ is the step size. For simplicity, we assume full visibility of the state vector (modulo observation noise), so $h(z) = z$.

Despite the simplicity of this model, exact inference is intractable. However, we can easily apply 
the EKF. The results are shown in Figure 8.11b. 
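For concreteness, here is a minimal NumPy sketch of an EKF for the spiral model of Equation (8.178). Because $h(z) = z$, only the dynamics need linearizing, and the analytic Jacobian of $f$ is used. The step size `DELTA`, the noise covariances and the initial belief in the usage lines are illustrative values we picked, not the settings behind Figure 8.11b, and the function names are hypothetical.

```python
import numpy as np

DELTA = 0.4  # illustrative step size (an assumption, not the value used for Figure 8.11)

def f(z):
    """Spiral dynamics of Equation (8.178)."""
    return np.array([z[0] + DELTA * np.sin(z[1]), z[1] + DELTA * np.cos(z[0])])

def jac_f(z):
    """Analytic Jacobian of f."""
    return np.array([[1.0, DELTA * np.cos(z[1])],
                     [-DELTA * np.sin(z[0]), 1.0]])

def ekf_spiral(ys, Q, R, mu0, Sigma0):
    """EKF with nonlinear dynamics f and identity observation model h(z) = z."""
    mu, Sigma = mu0, Sigma0
    means, covs = [], []
    for y in ys:
        F = jac_f(mu)                           # linearize f around mu_{t-1|t-1}
        mu, Sigma = f(mu), F @ Sigma @ F.T + Q  # predict
        S = Sigma + R                           # H = I since h(z) = z
        K = np.linalg.solve(S, Sigma).T         # Sigma S^{-1}
        mu = mu + K @ (y - mu)                  # update
        Sigma = Sigma - K @ S @ K.T
        means.append(mu)
        covs.append(Sigma)
    return np.array(means), np.array(covs)

# Usage: simulate a noisy trajectory, then filter it.
rng = np.random.default_rng(0)
Q, R = 0.001 * np.eye(2), 0.05 * np.eye(2)
z, ys = np.array([1.5, 0.0]), []
for _ in range(50):
    z = f(z) + rng.multivariate_normal(np.zeros(2), Q)
    ys.append(z + rng.multivariate_normal(np.zeros(2), R))
means, covs = ekf_spiral(np.array(ys), Q, R, np.array([1.5, 0.0]), 0.1 * np.eye(2))
```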


8.5.4.2 Neural network training 


In Section 17.5.2, we show how to use the EKF to perform online parameter inference for an MLP 
regression model. 


8.6 Inference based on the unscented transform 


In this section, we replace the local linearization of the model with a different approximation known 
as the unscented transform. When applied to Bayesian filtering, we get the unscented Kalman 
filter (UKF), since it is a version of the EKF that “doesn’t stink” [JU97; JUDW00]. The key idea 
is this: instead of computing a linear approximation to the function and then passing a Gaussian 
through it, we instead pass a deterministically chosen set of points, known as sigma points, through 
the function, and fit a Gaussian to the resulting transformed points; this is known as the unscented 
transform (see Section 8.6.1). Using the unscented transform for the transition and observation 
models gives the overall method, which is also called a sigma point filter [VDMW03].

The main advantage of the UKF over the EKF is that it can be more accurate, and more stable. (The EKF can sometimes lead to large errors and divergence of the filter [IX00; VDMW03].) In addition, the UKF does not need to compute Jacobians of the observation and dynamics models, so it can be applied to non-differentiable models, or ones with hard constraints. However, the UKF can be slower, since it requires $2N_z + 1$ evaluations of the dynamics and observation models (one per sigma point). In addition, it has 3 hyper-parameters that need to be set.


8.6.1 The unscented transform 


Algorithm 14: Computing sigma points using unscented transform.
1 def SigmaPoints(μ, Σ; α, β, κ):
2     n = dimensionality of μ
3     λ = α^2(n + κ) - n
4     Compute a set of 2n + 1 sigma points:
          X_0 = μ,  X_i = μ + √(n + λ) [√Σ]_{:i},  X_{i+n} = μ - √(n + λ) [√Σ]_{:i},  i = 1 : n
5     Compute a set of 2n + 1 weights for the mean and covariance:
          w_0^m = λ/(n + λ),  w_0^c = λ/(n + λ) + (1 - α^2 + β),  w_i^m = w_i^c = 1/(2(n + λ)),  i = 1 : 2n
6     Return (X_{0:2n}, w^m_{0:2n}, w^c_{0:2n})


Figure 8.12: Illustration of different ways to approximate the distribution induced by a nonlinear transformation $f: \mathbb{R}^2 \to \mathbb{R}^2$. (a) Data from the source distribution, $\mathcal{D} = \{x_i \sim p(x)\}$, with Gaussian approximation superimposed. (b) The dots show a Monte Carlo approximation to $p(f(x))$ derived from $\mathcal{D}' = \{f(x_i)\}$. The dotted ellipse is a Gaussian approximation to this target distribution, computed from the empirical moments. The solid ellipse is the Taylor transform. (c) Unscented sigma points. (d) Unscented transform. (e) Gauss-Hermite points (order 5). (f) GH transform. Adapted from Figures 5.3-5.4 of [Sar13]. Generated by gaussian_transforms.ipynb.


Algorithm 15: Unscented approximation to a joint Gaussian distribution.
1 def UnscentedPredict(μ, Σ, g, Q; α, β, κ):
2     (X_{0:2n}, w^m_{0:2n}, w^c_{0:2n}) = SigmaPoints(μ, Σ; α, β, κ)
3     Y_i = g(X_i),  i = 0 : 2n
4     m = ∑_{i=0}^{2n} w_i^m Y_i
5     S = ∑_{i=0}^{2n} w_i^c (Y_i - m)(Y_i - m)^T + Q
6     C = ∑_{i=0}^{2n} w_i^c (X_i - μ)(Y_i - m)^T
7     Return (m, S, C)


Suppose we have two random variables $x \sim \mathcal{N}(\mu, \Sigma)$ and $y = g(x)$, where $g: \mathbb{R}^n \to \mathbb{R}^m$. The unscented transform forms a Gaussian approximation to $p(y)$ using the following process. First we compute a set of $2n + 1$ sigma points, $\mathcal{X}_i$, and corresponding weights, $w_i^m$ and $w_i^c$, using Algorithm 14, for $i = 0 : 2n$. (The notation $M_{:i}$ means the $i$'th column of matrix $M$, and $\sqrt{\Sigma}$ is the matrix square root, so $\sqrt{\Sigma}\sqrt{\Sigma}^\top = \Sigma$.) Next we propagate the sigma points through the nonlinear function to get the following $2n + 1$ outputs:

$$
\mathcal{Y}_i = g(\mathcal{X}_i), \quad i = 0 : 2n \tag{8.179}
$$

Finally we estimate the mean and covariance of the resulting set of points:

$$
\begin{align}
\mathbb{E}[g(x)] &\approx m = \sum_{i=0}^{2n} w_i^m \mathcal{Y}_i \tag{8.180}\\
\text{Cov}[g(x)] &\approx S = \sum_{i=0}^{2n} w_i^c (\mathcal{Y}_i - m)(\mathcal{Y}_i - m)^\top \tag{8.181}
\end{align}
$$

Now suppose we want to approximate the joint distribution $p(x, y)$, where $y = g(x) + e$, and $e \sim \mathcal{N}(0, Q)$. By defining the augmented function $\tilde{g}(x) = (x, g(x))$, and applying the above procedure (and adding extra noise), we get

$$
\begin{pmatrix} x \\ y \end{pmatrix} \sim \mathcal{N}\left(\begin{pmatrix} \mu \\ m \end{pmatrix}, \begin{pmatrix} \Sigma & C \\ C^\top & S \end{pmatrix}\right) \tag{8.182}
$$

where the parameters are computed using Algorithm 15.

The sigma points and their weights depend on three hyper-parameters, $\alpha$, $\beta$ and $\kappa$, which determine the spread of the sigma points around the mean. A typical recommended setting for these is $\alpha = 10^{-3}$, $\kappa = 1$, $\beta = 2$ [Bit16].
8.6.1.1 Accuracy

In Figure 8.12(a-b), we show the linearized Taylor transform discussed in Section 8.5.1 applied to a nonlinear function. In Figure 8.12(c-d), we show the corresponding unscented transform, which we can see is more accurate. In fact, the unscented transform is a third-order method in the sense that the mean of $y$ is exact for polynomials up to order 3. However, the covariance is only exact for linear functions (first order polynomials).
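Below is a minimal NumPy sketch of Algorithms 14 and 15. It uses the Cholesky factor as the matrix square root (any factor with $\sqrt{\Sigma}\sqrt{\Sigma}^\top = \Sigma$ is valid) and defaults to $\alpha = 1$, $\beta = 0$, $\kappa = 2$, the setting quoted for the spiral example in Section 8.6.4.1; the function names are hypothetical. The usage lines illustrate the accuracy claim above in 1d: for $g(x) = x^3$ with $x \sim \mathcal{N}(\mu, \sigma^2)$, the exact mean is $\mu^3 + 3\mu\sigma^2$, and the unscented estimate reproduces it.

```python
import numpy as np

def sigma_points(mu, Sigma, alpha, beta, kappa):
    """Algorithm 14: 2n+1 sigma points (rows of X) and mean/covariance weights."""
    n = mu.size
    lam = alpha ** 2 * (n + kappa) - n
    L = np.linalg.cholesky(Sigma)                 # L @ L.T = Sigma
    offsets = np.sqrt(n + lam) * L.T              # row i = sqrt(n + lam) * column i of L
    X = np.vstack([mu, mu + offsets, mu - offsets])
    wm = np.full(2 * n + 1, 1.0 / (2 * (n + lam)))
    wc = wm.copy()
    wm[0] = lam / (n + lam)
    wc[0] = lam / (n + lam) + (1 - alpha ** 2 + beta)
    return X, wm, wc

def unscented_predict(mu, Sigma, g, Q, alpha=1.0, beta=0.0, kappa=2.0):
    """Algorithm 15: unscented approximation to the joint of x and y = g(x) + noise."""
    X, wm, wc = sigma_points(mu, Sigma, alpha, beta, kappa)
    Y = np.array([np.atleast_1d(g(x)) for x in X])   # propagate each sigma point
    m = wm @ Y
    dY, dX = Y - m, X - mu
    S = (wc[:, None] * dY).T @ dY + Q
    C = (wc[:, None] * dX).T @ dY
    return m, S, C

# Usage: the mean of a cubic is reproduced exactly (1d, mu = 0.7, sigma^2 = 0.3).
m, S, C = unscented_predict(np.array([0.7]), np.array([[0.3]]), lambda x: x ** 3, np.zeros((1, 1)))
exact_mean = 0.7 ** 3 + 3 * 0.7 * 0.3
```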
8.6.1.2 Gauss-Hermite integration

We can compute even more accurate approximations using other forms of numerical integration. For example, we can use Gauss-Hermite integration of order $p$, which will be exact for polynomials of order up to $2p - 1$. See [Sar13, Sec 6.3] for details, and Figure 8.12(e-f) for an illustration. However, this comes at a price: the number of sigma points is now $p^n$, whereas it is only $2n + 1$ for the unscented transform.
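In one dimension the exactness claim is easy to check with NumPy's probabilists' Gauss-Hermite rule, `numpy.polynomial.hermite_e.hermegauss`, whose weight function is $\exp(-x^2/2)$. The sketch below (function name hypothetical) rescales those nodes and weights to compute $\mathbb{E}[g(x)]$ for $x \sim \mathcal{N}(\mu, \sigma^2)$. In $n$ dimensions one would take the tensor product of such rules, which is where the $p^n$ point count comes from.

```python
import numpy as np

def gauss_hermite_expectation(g, mu, sigma, order):
    """E[g(x)] for scalar x ~ N(mu, sigma^2) with an order-p Gauss-Hermite rule."""
    nodes, weights = np.polynomial.hermite_e.hermegauss(order)  # weight function exp(-x^2 / 2)
    weights = weights / np.sqrt(2 * np.pi)                       # normalize so the weights sum to 1
    return np.sum(weights * g(mu + sigma * nodes))

mu, sigma = 0.5, 1.3
# Order 3 is exact up to degree 2*3 - 1 = 5: E[x^5] = mu^5 + 10 mu^3 sigma^2 + 15 mu sigma^4
approx5 = gauss_hermite_expectation(lambda x: x ** 5, mu, sigma, order=3)
exact5 = mu ** 5 + 10 * mu ** 3 * sigma ** 2 + 15 * mu * sigma ** 4
# Order 2 is exact only up to degree 3, so it misses E[x^4] = 3 sigma^4 when mu = 0
approx4 = gauss_hermite_expectation(lambda x: x ** 4, 0.0, sigma, order=2)
```

With order 2 the nodes are $\pm 1$ with equal weights, so `approx4` equals $\sigma^4$ rather than the true $3\sigma^4$.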


8.6.2 The unscented Kalman filter (UKF) 


The UKF applies the unscented transform twice to Equation (8.162), once to approximate passing 
through the system model f, and once to approximate passing through the measurement model h. 
By analogy to Section 8.3.2.5, we can derive the UKF algorithm as follows: 


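$$
\begin{align}
(\mu_{t|t-1}, \Sigma_{t|t-1}, -) &= \text{UnscentedPredict}(\mu_{t-1|t-1}, \Sigma_{t-1|t-1}, f(\cdot, u_t), Q_t) \tag{8.183}\\
(m_t, S_t, C_t) &= \text{UnscentedPredict}(\mu_{t|t-1}, \Sigma_{t|t-1}, h(\cdot, u_t), R_t) \tag{8.184}\\
(\mu_{t|t}, \Sigma_{t|t}, \ell_t) &= \text{GaussCondition}(\mu_{t|t-1}, \Sigma_{t|t-1}, m_t, S_t, C_t, y_t) \tag{8.185}
\end{align}
$$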


See [Sar13, p86] for more details. 
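One UKF step, Equations (8.183)-(8.185), can be sketched in NumPy as below. To keep the block self-contained it carries its own compact unscented transform (same weights as Algorithm 14, Cholesky factor as the matrix square root, defaults $\alpha = 1$, $\beta = 0$, $\kappa = 2$); the function names are hypothetical and the log-likelihood term $\ell_t$ is omitted. Because the unscented transform is exact for linear functions, the step coincides with the Kalman filter on a linear-Gaussian model, which the usage lines set up as a sanity check.

```python
import numpy as np

def unscented_moments(mu, Sigma, g, noise_cov, alpha=1.0, beta=0.0, kappa=2.0):
    """Unscented estimates of E[g(x)], Cov[g(x)] + noise_cov and Cov[x, g(x)]."""
    n = mu.size
    lam = alpha ** 2 * (n + kappa) - n
    offsets = np.sqrt(n + lam) * np.linalg.cholesky(Sigma).T
    X = np.vstack([mu, mu + offsets, mu - offsets])          # 2n + 1 sigma points
    wm = np.full(2 * n + 1, 0.5 / (n + lam))
    wm[0] = lam / (n + lam)
    wc = wm.copy()
    wc[0] += 1 - alpha ** 2 + beta
    Y = np.array([g(x) for x in X])
    m = wm @ Y
    S = (wc[:, None] * (Y - m)).T @ (Y - m) + noise_cov
    C = (wc[:, None] * (X - mu)).T @ (Y - m)
    return m, S, C

def ukf_step(mu, Sigma, f, Q, h, R, y):
    """Predict with f (8.183), predict the observation with h (8.184), then condition (8.185)."""
    mu_pred, Sigma_pred, _ = unscented_moments(mu, Sigma, f, Q)
    m, S, C = unscented_moments(mu_pred, Sigma_pred, h, R)
    K = np.linalg.solve(S, C.T).T                              # K = C S^{-1}
    return mu_pred + K @ (y - m), Sigma_pred - K @ S @ K.T

# Usage on a linear-Gaussian model, where the UKF step should match the Kalman filter.
F = np.array([[1.0, 0.1], [0.0, 1.0]])
Q = 0.01 * np.eye(2)
H = np.array([[1.0, 0.0]])
R = np.array([[0.25]])
mu0, Sigma0, y = np.array([0.0, 1.0]), np.eye(2), np.array([0.3])
mu_u, Sigma_u = ukf_step(mu0, Sigma0, lambda z: F @ z, Q, lambda z: H @ z, R, y)
```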


8.6.3 The unscented Kalman smoother (UKS) 


The unscented Kalman smoother, also called the unscented RTS smoother [Sar08], is a 
simple modification of the usual Kalman smoothing method, where we approximate the nonlinearity 
by the unscented transform. By analogy to Section 8.3.3, the update is as follows: 

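$$
\begin{align}
(\mu_{t+1|t}, \Sigma_{t+1|t}, \Sigma_{t,t+1|t}) &= \text{UnscentedPredict}(\mu_{t|t}, \Sigma_{t|t}, f(\cdot, u_{t+1}), Q_{t+1}) \tag{8.186}\\
\text{bel}_{t,t+1|t} &= (\mu_{t|t}, \mu_{t+1|t}, \Sigma_{t|t}, \Sigma_{t+1|t}, \Sigma_{t,t+1|t}) \tag{8.187}\\
\text{bel}_{t+1|T} &= (\mu_{t+1|T}, \Sigma_{t+1|T}) \tag{8.188}\\
(\mu_{t|T}, \Sigma_{t|T}) &= \text{GaussSoftCondition}(\text{bel}_{t,t+1|t}, \text{bel}_{t+1|T}) \tag{8.189}
\end{align}
$$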


See [Sar13, p148] for more details. 


8.6.4 Examples

In this section we give some examples of the UKF.

8.6.4.1 Tracking a point spiraling in 2d

Let us revisit the 2d nonlinear tracking problem from Section 8.5.4.1. In Figure 8.11c, we see that the UKF algorithm (with $\alpha = 1$, $\beta = 0$, $\kappa = 2$) works well on this problem.

8.6.4.2 COVID-19 risk score estimation

During the COVID-19 pandemic in 2020, many governments created contact tracing apps, in which the distance between people was estimated (anonymously) by measuring Bluetooth signal strength between mobile phones. This can be done using the UKF, as explained in [Lov+20]. Indeed, this approach formed the basis of the UK COVID-19 risk score estimation app [BCH20].⁵


8.7 General Gaussian filtering 


This section was co-authored with Peter Chang. 


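In this section, we discuss a simple unified framework, known as the general Gaussian filter or GGF [IX00; Wu+06], which includes EKF, UKF, and various other algorithms. Our presentation is based on [Sar13, Ch. 6].

8.7.1 Gaussian moment matching

In this section, we consider a single time step of inference, and temporarily drop the time indices, and the conditioning on past information, for notational brevity. Furthermore, we will use the shorthand

$$
\int g(x)\, dx = \int \cdots \int g(x_1, \ldots, x_n)\, dx_1 \cdots dx_n \tag{8.190}
$$

Let $p(x) = \mathcal{N}(x \mid \mu, \Sigma)$ and $p(y \mid x) = \mathcal{N}(y \mid g(x), Q)$ for some function $g$. Let $p(x, y) = p(x)p(y \mid x)$ be the exact joint distribution. The best Gaussian approximation to the joint is given by

$$
q(x, y) = \operatorname*{argmin}_{q \in \text{Gaussian}} \text{KL}\big(p(x, y) \,\|\, q(x, y)\big) \tag{8.191}
$$

As we explain in Section 5.1.3.4, this can be obtained by moment matching, i.e.,

$$
q(x, y) = \mathcal{N}\left(\begin{pmatrix} x \\ y \end{pmatrix} \,\middle|\, \begin{pmatrix} \mu \\ m \end{pmatrix}, \begin{pmatrix} \Sigma & C \\ C^\top & S \end{pmatrix}\right) \tag{8.192}
$$

where

$$
\begin{align}
m &= \int g(x)\, \mathcal{N}(x \mid \mu, \Sigma)\, dx \tag{8.193}\\
S &= \int (g(x) - m)(g(x) - m)^\top \mathcal{N}(x \mid \mu, \Sigma)\, dx + Q = \int g(x) g(x)^\top \mathcal{N}(x \mid \mu, \Sigma)\, dx - m m^\top + Q \tag{8.194}\\
C &= \int (x - \mu)(g(x) - m)^\top \mathcal{N}(x \mid \mu, \Sigma)\, dx = \int x\, g(x)^\top \mathcal{N}(x \mid \mu, \Sigma)\, dx - \mu m^\top \tag{8.195}
\end{align}
$$

We can compute these integrals using the methods we discuss in Section 8.7.3.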


5. Once the distance has been estimated, it needs to be combined with other signals, such as contact duration and 
infectiousness level of the index case, to estimate the risk of transmission. For an ML based approach to this problem, 
see [MKS21]. 


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 




8.7.2 Application to filtering 


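Let us now apply this method in the filtering context. We assume the prior has the form
$p(z_{t-1}|y_{1:t-1}) = \mathcal{N}(z_{t-1}|\mu_{t-1|t-1}, \Sigma_{t-1|t-1})$. In the prediction step, we compute a Gaussian approximation
to $p(z_{t-1}, z_t|y_{1:t-1})$, from which we can compute $p(z_t|y_{1:t-1})$ by computing the following
moments:

$$\begin{aligned}
\mu_{t|t-1} &= \int f(z_{t-1})\, \mathcal{N}(z_{t-1}|\mu_{t-1|t-1}, \Sigma_{t-1|t-1})\, dz_{t-1} \\
\Sigma_{t|t-1} &= \int (f(z_{t-1}) - \mu_{t|t-1})(f(z_{t-1}) - \mu_{t|t-1})^\top \mathcal{N}(z_{t-1}|\mu_{t-1|t-1}, \Sigma_{t-1|t-1})\, dz_{t-1} + Q_t
\end{aligned} \tag{8.196}$$

In the update step, we compute a Gaussian approximation to $p(z_t, y_t|y_{1:t-1})$, by computing the
following moments:

$$\hat{y}_t = \int h(z_t)\, \mathcal{N}(z_t|\mu_{t|t-1}, \Sigma_{t|t-1})\, dz_t \tag{8.197}$$
$$S_t = \int (h(z_t) - \hat{y}_t)(h(z_t) - \hat{y}_t)^\top \mathcal{N}(z_t|\mu_{t|t-1}, \Sigma_{t|t-1})\, dz_t + R_t \tag{8.198}$$
$$C_t = \int (z_t - \mu_{t|t-1})(h(z_t) - \hat{y}_t)^\top \mathcal{N}(z_t|\mu_{t|t-1}, \Sigma_{t|t-1})\, dz_t \tag{8.199}$$

We then condition this Gaussian on seeing $y_t$ using Bayes' rule, which gives the usual update step:

$$K_t = C_t S_t^{-1} \tag{8.200}$$
$$\mu_{t|t} = \mu_{t|t-1} + K_t (y_t - \hat{y}_t) \tag{8.201}$$
$$\Sigma_{t|t} = \Sigma_{t|t-1} - K_t S_t K_t^\top \tag{8.202}$$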
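The update step above only needs the three moments and the observation. Below is a minimal sketch of it (we use NumPy rather than JAX to keep it dependency-light; the name `gauss_condition` is our own and is not a library API):

```python
import numpy as np

def gauss_condition(mu_pred, Sigma_pred, y_hat, S, C, y):
    """Condition the joint Gaussian approximation on the observed y (Eqs. 8.200-8.202).

    mu_pred, Sigma_pred: predicted state moments mu_{t|t-1}, Sigma_{t|t-1}.
    y_hat, S, C: predicted observation mean, its covariance (including R_t),
                 and the state/observation cross-covariance.
    """
    # K = C S^{-1}, computed with a linear solve instead of an explicit inverse
    K = np.linalg.solve(S.T, C.T).T
    mu = mu_pred + K @ (y - y_hat)
    Sigma = Sigma_pred - K @ S @ K.T
    return mu, Sigma
```

Solving a linear system instead of forming $S_t^{-1}$ explicitly is the usual numerically safer choice. For a linear model $h(z) = Hz$, plugging in $\hat{y}_t = H\mu_{t|t-1}$, $S_t = H\Sigma_{t|t-1}H^\top + R_t$ and $C_t = \Sigma_{t|t-1}H^\top$ recovers the standard Kalman filter update.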


8.7.3 Computing the moments


One way to compute the integrals in Section 8.7.1 is to linearize $g$ around the predicted mean, as in
the Taylor approximation of Section 8.5.1. We denote this by

$$(m, S, C) = \text{LinearizedPredict}(\mu, \Sigma, g, Q) \tag{8.203}$$


Using this inside the GGF is equivalent to the EKF in Section 8.5.2. However, this approach can
lead to large errors and sometimes divergence of the filter [IX00; VDMW03].
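For the linearized variant, here is a sketch that assumes the caller supplies the Jacobian of $g$ as a separate function. `jac_g` is a hypothetical argument; with JAX one could obtain it from `jax.jacfwd` instead.

```python
import numpy as np

def linearized_predict(mu, Sigma, g, jac_g, Q):
    """First-order (Taylor) approximation to the moments m, S, C of Section 8.7.1."""
    m = g(mu)                  # mean: g evaluated at the input mean
    J = jac_g(mu)              # Jacobian of g at the input mean
    S = J @ Sigma @ J.T + Q    # approximate output covariance
    C = Sigma @ J.T            # approximate input/output cross-covariance
    return m, S, C
```

Only a single function evaluation and a single Jacobian are needed. This is what makes the approach cheap, and it is also why it ignores the curvature of $g$ across the spread of $\mathcal{N}(x|\mu, \Sigma)$.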


A more stable method is to use numerical integration⁶, which can be written as

$$\int g(x)\, \mathcal{N}(x|\mu, \Sigma)\, dx \approx \sum_{k=1}^{K} w^k g(x^k) \tag{8.204}$$

for a suitable set of evaluation points $x^k$ (sometimes called sigma points) and weights $w^k$.
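Given any set of evaluation points and weights, the moments of Section 8.7.1 reduce to weighted sums. The following sketch assumes the points $x^k$ and weights $w^k$ have already been chosen (by whatever rule). For simplicity it uses the same weights for the mean and the covariances. The function name is ours.

```python
import numpy as np

def sigma_point_moments(points, weights, mu, g, Q):
    """Approximate (m, S, C) by weighted sums over evaluation points.

    points: array (K, n) of evaluation points x^k; weights: array (K,) summing to one.
    """
    gx = np.array([np.atleast_1d(g(x)) for x in points])   # (K, n_y) function values
    m = weights @ gx                                         # Eq. (8.204) applied to g
    dg = gx - m
    dx = points - mu
    S = (weights[:, None] * dg).T @ dg + Q                   # output covariance plus noise
    C = (weights[:, None] * dx).T @ dg                       # cross-covariance
    return m, S, C
```

Different choices of points and weights (unscented, cubature, Gauss-Hermite) give the different filters discussed next.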


6. One-dimensional integrals are called quadratures, and multi-dimensional integrals are called cubatures.
One way to compute the sigma points is to use the unscented transform, as in Algorithm 15. We
denote this by

$$(m, S, C) = \text{UnscentedPredict}(\mu, \Sigma, g, Q) \tag{8.205}$$


Using this inside the GGF is equivalent to the UKF in Section 8.6.2. 
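Algorithm 15 is not reproduced in this excerpt. The sketch below therefore assumes the common $(\alpha, \beta, \kappa)$ parameterization of the unscented transform, with $2n + 1$ sigma points and separate weights for the mean and the covariance. The defaults follow the 2d tracking example ($\alpha = 1$, $\beta = 0$, $\kappa = 2$). The function name is hypothetical.

```python
import numpy as np

def unscented_predict(mu, Sigma, g, Q, alpha=1.0, beta=0.0, kappa=2.0):
    """Unscented-transform approximation to the moments (m, S, C)."""
    n = mu.shape[0]
    lam = alpha**2 * (n + kappa) - n
    L = np.linalg.cholesky((n + lam) * Sigma)        # scaled matrix square root
    # 2n + 1 sigma points: the mean, and the mean +/- each column of L
    points = np.vstack([mu, mu + L.T, mu - L.T])
    wm = np.full(2 * n + 1, 1.0 / (2 * (n + lam)))   # mean weights
    wc = wm.copy()                                   # covariance weights
    wm[0] = lam / (n + lam)
    wc[0] = lam / (n + lam) + (1 - alpha**2 + beta)
    gx = np.array([np.atleast_1d(g(x)) for x in points])
    m = wm @ gx
    dg, dx = gx - m, points - mu
    S = (wc[:, None] * dg).T @ dg + Q
    C = (wc[:, None] * dx).T @ dg
    return m, S, C
```

The sigma points are symmetric and the mean weights sum to one, so the transform is exact when $g$ is linear. This gives a simple sanity check.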

Alternatively, we can use spherical cubature integration, which gives rise to the cubature
Kalman filter or CKF [AH09]. This turns out (see [Sar13, p110]) to be a special case of the UKF,
with $2n_x + 1$ sigma points (where $n_x$ is the dimension of $x$), and hyper-parameter values of $\alpha = 1$ and $\beta = 0$ (with $\kappa$ left free).

A more accurate approximation uses Gauss-Hermite integration, which allows the user to
select more sigma points, as we discussed in Section 8.6.1.2. This gives rise to the Gauss-Hermite
Kalman filter or GHKF [IX00], also known as the quadrature Kalman filter or QKF [AHE07].
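In one dimension, a Gauss-Hermite rule with $p$ nodes integrates polynomials of degree up to $2p - 1$ exactly against a Gaussian weight. The sketch below uses NumPy's probabilists' Hermite rule (`hermegauss`, whose weight function is $e^{-x^2/2}$). A GHKF in $n$ dimensions would use a tensor product of such rules, with $p^n$ points. The function name is ours.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

def gauss_hermite_expectation(g, mu, sigma, order=10):
    """Approximate E[g(x)] for scalar x ~ N(mu, sigma^2) by Gauss-Hermite quadrature."""
    nodes, weights = hermegauss(order)   # rule for the weight exp(-x^2 / 2)
    x = mu + sigma * nodes               # map standard-normal nodes to N(mu, sigma^2)
    # the weights sum to sqrt(2 pi), so divide by it to normalize
    return weights @ g(x) / np.sqrt(2 * np.pi)
```

For example, `gauss_hermite_expectation(lambda x: x**2, 1.0, 2.0)` returns $\mu^2 + \sigma^2 = 5$ up to rounding, because a quadratic lies well within the exactness range of a 10-node rule.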

We can also approximate the integrals with Monte Carlo. Note, however, that this is not the same
as particle filtering (Section 13.2), which approximates the conditional $p(z_t|y_{1:t})$ rather than the
joint $p(z_t, y_t|y_{1:t-1})$ (see Section 8.9.1 for discussion of this difference).
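A Monte Carlo version simply replaces the sigma points with equally weighted random draws from $\mathcal{N}(\mu, \Sigma)$. The sketch below does this; the function name and the default sample size are our own choices.

```python
import numpy as np

def monte_carlo_moments(mu, Sigma, g, Q, num_samples=20_000, seed=0):
    """Monte Carlo estimates of the moments (m, S, C) of Section 8.7.1."""
    rng = np.random.default_rng(seed)
    x = rng.multivariate_normal(mu, Sigma, size=num_samples)   # (N, n) draws
    gx = np.array([np.atleast_1d(g(xi)) for xi in x])          # (N, n_y)
    m = gx.mean(axis=0)
    dg, dx = gx - m, x - mu
    S = dg.T @ dg / num_samples + Q
    C = dx.T @ dg / num_samples
    return m, S, C
```

The estimates are unbiased only in the limit. Their error shrinks as $1/\sqrt{N}$, which is typically far slower than the deterministic rules above in low dimensions.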


8.7.4 Statistical linear regression 


In this section, we provide an alternative (but equivalent) perspective on moment matching, known as
statistical linear regression or SLR [LBS01; AHE07]. The key observation is that approximating
the true distributions $p(z_{t-1}, z_t|y_{1:t-1})$ and $p(z_t, y_t|y_{1:t-1})$ as joint Gaussians is equivalent to
linearizing the corresponding transition and measurement functions. This follows from the following
lemma [Kam+22]:


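Lemma 8.7.1. A random variable $(x, y)$ is jointly Gaussian iff there exist matrices $A \in \mathbb{R}^{N_y \times N_x}$,
$Q \in \mathbb{R}^{N_y \times N_y}$ and a vector $b \in \mathbb{R}^{N_y}$ such that $y = Ax + b + \epsilon$, where $\epsilon \sim \mathcal{N}(0, Q)$. Furthermore,
these parameters are given by

$$A = C^\top \Sigma^{-1} \tag{8.206a}$$
$$b = m - A\mu \tag{8.206b}$$
$$Q = S - A \Sigma A^\top \tag{8.206c}$$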


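Proof. Since $(x, y)$ is jointly Gaussian, we have

$$\begin{pmatrix} x \\ y \end{pmatrix} \sim \mathcal{N}\left(\begin{pmatrix} \mu \\ m \end{pmatrix}, \begin{pmatrix} \Sigma & C \\ C^\top & S \end{pmatrix}\right) \tag{8.207}$$

On the other hand, if $y = Ax + b + \mathcal{N}(0, Q)$, we have

$$\begin{pmatrix} x \\ y \end{pmatrix} \sim \mathcal{N}\left(\begin{pmatrix} \mu \\ A\mu + b \end{pmatrix}, \begin{pmatrix} \Sigma & \Sigma A^\top \\ A\Sigma & A\Sigma A^\top + Q \end{pmatrix}\right) \tag{8.208}$$

These distributions only match if we satisfy the constraints in Equation (8.206).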
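Equation (8.206) turns the joint moments into the parameters of the equivalent linear-Gaussian model. A direct sketch follows (the function name is ours):

```python
import numpy as np

def statistical_linear_regression(mu, Sigma, m, S, C):
    """Return (A, b, Q) of Equation (8.206) from the joint moments of (x, y)."""
    A = np.linalg.solve(Sigma.T, C).T   # A = C^T Sigma^{-1}
    b = m - A @ mu                      # match the means
    Q = S - A @ Sigma @ A.T             # leftover covariance of y given x
    return A, b, Q
```

If the moments come from an exactly linear-Gaussian model, the function recovers that model's $A$, $b$ and $Q$. This mirrors the matching argument of the proof.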
Another way to derive this result is as follows. Suppose we choose the parameters of the linear-Gaussian approximation $\theta = (A, b, Q)$ so that the resulting joint is as close as possible to the true
joint:

$$\hat{\theta} = \operatorname*{argmin}_{\theta} D_{\mathbb{KL}}\big(p(x, y) \,\|\, q(x, y|\theta)\big) = \operatorname*{argmax}_{\theta} \int\!\!\int p(x, y) \log q(x, y|\theta)\, dx\, dy \tag{8.209}$$

This can be solved by moment matching, as we discussed in Section 8.7.1. However, this is also
equivalent to minimizing the mean squared error of the approximation:

$$(\hat{A}, \hat{b}) = \operatorname*{argmin}_{A, b} \mathbb{E}\left[(g(x) - Ax - b)^\top (g(x) - Ax - b)\right] \tag{8.210}$$

where $Q$ is the covariance of the resulting errors:

$$Q = \mathbb{E}\left[(g(x) - Ax - b)(g(x) - Ax - b)^\top\right] \tag{8.211}$$


One can show that the resulting optimal parameters are given by Equation (8.206). 
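The least-squares view can be checked numerically. The sketch below fits ordinary least squares with an intercept to samples of $(x, g(x))$. It yields exactly the same $A$, $b$ and residual covariance as Equation (8.206) evaluated on the corresponding ($1/n$-normalized) sample moments. The test function `g` and the sample size are arbitrary illustrations.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 2000
x = rng.multivariate_normal([0.5, -1.0], [[1.0, 0.3], [0.3, 0.5]], size=n)

def g(x):
    """An arbitrary nonlinear test function from R^2 to R^2 (for illustration only)."""
    return np.stack([np.sin(x[:, 0]) + x[:, 1] ** 2, np.tanh(x[:, 1])], axis=1)

gx = g(x)

# (1) Least squares, Eq. (8.210): regress g(x) on [x, 1]
X1 = np.hstack([x, np.ones((n, 1))])
coef, *_ = np.linalg.lstsq(X1, gx, rcond=None)
A_ls, b_ls = coef[:2].T, coef[2]
resid = gx - x @ A_ls.T - b_ls
Q_ls = resid.T @ resid / n                     # Eq. (8.211) with sample averages

# (2) SLR formulas, Eq. (8.206), applied to the sample moments
mu, m = x.mean(axis=0), gx.mean(axis=0)
Sigma = (x - mu).T @ (x - mu) / n
S = (gx - m).T @ (gx - m) / n
C = (x - mu).T @ (gx - m) / n
A_slr = np.linalg.solve(Sigma, C).T            # A = C^T Sigma^{-1}
b_slr = m - A_slr @ mu
Q_slr = S - A_slr @ Sigma @ A_slr.T
```

The two routes agree to floating-point precision, because the OLS normal equations with an intercept are exactly the moment equations of (8.206).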


8.7.5 Iterated posterior linearization filter 


In this section, we present a GGF version of the iterated EKF which we discussed in Section 8.5.2.2.
The resulting method is known as the iterated posterior linearization filter or IPLF [GF+15].
The basic idea is this: when computing the moments for the measurement update step, use the
posterior instead of the prior, so that the observation model is linearized near the location where the
posterior has most of its mass. Of course, this approach cannot be implemented as stated, because it
requires knowledge of the posterior to approximate the posterior. However, we can approximate it
by an iterative method, as shown in Algorithm 16. If we set the number of iterations to $J = 1$, we
recover the "prior linearization filter", which is equivalent to the UKF (Section 8.6.2) if we use
the unscented transform to compute the moments.

Unfortunately the IPLF can diverge. A more robust version, which uses line search to perform
damped (partial) updates, is presented in [Rai+18b].


8.7.6 Iterated posterior linearization smoother


In [GFSS17] they propose a method called the iterated posterior linearization smoother or
IPLS, which extends the IPLF method of Section 8.7.5 to the offline setting. The basic idea is
to linearize the dynamics and observation models by computing expectations with respect to the
previous smoothed posterior, and to iterate this process until convergence. See Algorithm 17 for
the pseudocode. If we use $J = 1$ iterations, and if we use the unscented transform to compute the
moments, then we recover the UKS algorithm from Section 8.6.3.

Unfortunately the IPLS can diverge. A more robust version, which uses line search and Levenberg-Marquardt to update the parameters, is presented in [Lin+21c].

The IPLS is very similar to the iterated EKS, which we discussed in Section 8.5.3. The difference
is that the IEKS uses a first-order Taylor series approximation of the measurement and dynamics
functions at the posterior mean from the previous iteration, rather than using a general Gaussian
moment matching scheme. It thus ignores uncertainty due to the linearization process. Indeed,
[Bel94] showed that IEKS is equivalent to a Gauss-Newton method for computing the MAP estimate
of the smoothing posterior.
Algorithm 16: Iterated posterior linearization filter.

def IPLF-Unscented(f, Q, h, R, y_{1:T}, μ_{0|0}, Σ_{0|0}, J):
    foreach t = 1 : T do
        Predict step:
        (μ_{t|t−1}, Σ_{t|t−1}, −) = UnscentedPredict(μ_{t−1|t−1}, Σ_{t−1|t−1}, f(·, u_t), Q_t)
        Update step:
        μ_{t|t} = μ_{t|t−1}, Σ_{t|t} = Σ_{t|t−1}
        foreach j = 1 : J do
            (m_t, S_t, C_t) = UnscentedPredict(μ_{t|t}, Σ_{t|t}, h(·, u_t), R)
            (μ_{t|t}, Σ_{t|t}) = GaussCondition(μ_{t|t−1}, Σ_{t|t−1}, m_t, S_t, C_t, y_t)
    Return (μ_{t|t}, Σ_{t|t})_{t=1}^{T}
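The following NumPy sketch shows one IPLF measurement update. It makes one assumption about the GaussCondition step explicit. The moments computed at the current posterior are first converted into SLR parameters $(A, b, \Omega)$ via Equation (8.206). Those parameters are then used in an ordinary Kalman update of the prior $(\mu_{t|t-1}, \Sigma_{t|t-1})$ with observation noise $\Omega + R_t$. Function names are hypothetical, and the unscented parameterization is the common $(\alpha, \beta, \kappa)$ one.

```python
import numpy as np

def ut_moments(mu, Sigma, g, alpha=1.0, beta=0.0, kappa=2.0):
    """Unscented moments (m, S_g, C) of g(x) for x ~ N(mu, Sigma), with no added noise."""
    n = mu.shape[0]
    lam = alpha**2 * (n + kappa) - n
    L = np.linalg.cholesky((n + lam) * Sigma)
    pts = np.vstack([mu, mu + L.T, mu - L.T])          # 2n + 1 sigma points
    wm = np.full(2 * n + 1, 0.5 / (n + lam))
    wc = wm.copy()
    wm[0] = lam / (n + lam)
    wc[0] = wm[0] + 1 - alpha**2 + beta
    gx = np.array([np.atleast_1d(g(p)) for p in pts])
    m = wm @ gx
    dg, dx = gx - m, pts - mu
    return m, (wc[:, None] * dg).T @ dg, (wc[:, None] * dx).T @ dg

def iplf_update(mu_pred, Sigma_pred, h, R, y, num_iters=5):
    """Iterated posterior linearization measurement update (a sketch)."""
    mu, Sigma = mu_pred, Sigma_pred                     # first pass linearizes at the prior
    for _ in range(num_iters):
        # SLR of h with respect to the current posterior approximation (Eq. 8.206)
        m, S_h, C = ut_moments(mu, Sigma, h)
        A = np.linalg.solve(Sigma, C).T                 # A = C^T Sigma^{-1}
        b = m - A @ mu
        Omega = S_h - A @ Sigma @ A.T                   # linearization error covariance
        # Kalman update of the prior, with y = A z + b + noise of covariance Omega + R
        S = A @ Sigma_pred @ A.T + Omega + R
        K = np.linalg.solve(S, A @ Sigma_pred).T        # K = Sigma_pred A^T S^{-1}
        mu = mu_pred + K @ (y - A @ mu_pred - b)
        Sigma = Sigma_pred - K @ S @ K.T
        Sigma = 0.5 * (Sigma + Sigma.T)                 # guard against round-off asymmetry
    return mu, Sigma
```

With `num_iters=1` the linearization is done at the prior, which reproduces the UKF update. For a linear $h$ the SLR is exact, so every iteration returns the Kalman filter answer.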


Algorithm 17: Iterated posterior linearization smoother.

def IPLS-Unscented(f, Q, h, R, y_{1:T}, μ_{0|0}, Σ_{0|0}, J):
    (μ_{t|t}, Σ_{t|t})_{t=1}^{T} = UnscentedFilter(f, Q, h, R, y_{1:T}, μ_{0|0}, Σ_{0|0})
    (μ_{t|T} = μ_{t|t}, Σ_{t|T} = Σ_{t|t})_{t=1}^{T}
    foreach j = 1 : J do
        foreach t = 1 : T do
            (μ_{t+1|t}, Σ_{t+1|t}, Σ_{t,t+1|t}) = UnscentedPredict(μ_{t|T}, Σ_{t|T}, f(·, u_{t+1}), Q_{t+1})
            bel_{t,t+1|t} = (μ_{t|t}, μ_{t+1|t}, Σ_{t|t}, Σ_{t+1|t}, Σ_{t,t+1|t})
            bel_{t+1|T} = (μ_{t+1|T}, Σ_{t+1|T})
            (μ_{t|T}, Σ_{t|T}) = GaussSoftCondition(bel_{t,t+1|t}, bel_{t+1|T})
    Return (μ_{t|T}, Σ_{t|T})_{t=1}^{T}


8.7.7 Beyond additive Gaussian noise 


In this section, we extend GGF to handle the case of non-Gaussian / non-additive noise, following 
[TGFS18]. 


8.7.7.1 Conditional moments Gaussian filter 


We assume $p(x) = \mathcal{N}(x|\mu_x, \Sigma_x)$. We allow $p(y|x)$ to be any kind of distribution, provided it
has finite variance. We only require that we can approximate the first and second conditional
moments:

$$m_y(x) = \mathbb{E}[y|x], \quad V_y(x) = \text{Cov}[y|x] \tag{8.212}$$

For example, if $p(y|x) = \text{Poisson}(y|c e^x)$, we have

$$m_y(x) = V_y(x) = c e^x \tag{8.213}$$
As usual we will approximate the conditional distribution by a linear Gaussian model of the form
$p(y|x) = \mathcal{N}(y|Ax + b, Q)$. We can then easily compute the joint $p(x, y)$, the marginal $p(y)$, and the
posterior $p(x|y)$.

The optimal parameters, in the sense of minimizing the mean squared norm of the residual,
$\|y - Ax - b\|$, are given by SLR Equation (8.206), where the moments are computed using

$$m = \mathbb{E}[Y] = \mathbb{E}[\mathbb{E}[Y|X]] = \int p(x) m_y(x)\, dx \tag{8.214}$$
$$S = \mathbb{V}[Y] = \mathbb{E}[\mathbb{V}[Y|X]] + \mathbb{V}[\mathbb{E}[Y|X]] \tag{8.215}$$
$$S = \int p(x) V_y(x)\, dx + \int p(x) (m_y(x) - m)(m_y(x) - m)^\top dx \tag{8.216}$$
$$C = \text{Cov}[X, Y] = \text{Cov}[X, \mathbb{E}[Y|X]] = \int p(x) (x - \mu_x)(m_y(x) - m)^\top dx \tag{8.217}$$

where expectations are taken wrt $p(x) = \mathcal{N}(x|\mu_x, \Sigma_x)$. This is called generalized statistical
linear regression.


In general, we cannot compute these expectations exactly, but we can use the approximate methods
discussed in Section 8.7.3. For example, we can make a Taylor series approximation to the conditional
moments, and then we can compute the unconditional moments analytically. In more detail, we first
compute

$$m_y(x) \approx m_y(\mu_x) + \nabla m_y(\mu_x)\,(x - \mu_x) \tag{8.218}$$
$$V_y(x) \approx V_y(\mu_x) \tag{8.219}$$

We can then compute the unconditional moments using Algorithm 18.


Algorithm 18: Linearized approximation to a non-Gaussian distribution.

def CondLinearizedPredict(μ_x, Σ_x, m_y(·), V_y(·)):
    m = m_y(μ_x)
    J = Jac(m_y)(μ_x)
    S = J Σ_x J^T + V_y(μ_x)
    C = Σ_x J^T
    Return (m, S, C)


Once we have "Gaussianized" the likelihood, we can then use general Gaussian filtering in the
usual way. We call this the conditional moments Gaussian filter or CMGF.
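Here is a sketch of Algorithm 18, applied to the Poisson model of Equation (8.213). We pass the Jacobian of $m_y$ explicitly (`jac_m_y` is a hypothetical argument; autodiff could supply it instead) and treat the scalar $x$ as a length-1 vector.

```python
import numpy as np

def cond_linearized_predict(mu_x, Sigma_x, m_y, V_y, jac_m_y):
    """Sketch of Algorithm 18: Taylor-linearized moments for a general likelihood."""
    m = m_y(mu_x)                          # conditional mean at the prior mean
    J = jac_m_y(mu_x)                      # Jacobian of the conditional mean
    S = J @ Sigma_x @ J.T + V_y(mu_x)      # propagated spread plus conditional variance
    C = Sigma_x @ J.T                      # cross-covariance between x and y
    return m, S, C

# Poisson example of Eq. (8.213): y | x ~ Poisson(c * exp(x)), scalar x
c = 2.0
m_y = lambda x: c * np.exp(x)              # E[y | x]
V_y = lambda x: np.diag(c * np.exp(x))     # Var[y | x], equal to the mean
jac_m_y = lambda x: np.diag(c * np.exp(x)) # d m_y / d x
m, S, C = cond_linearized_predict(np.array([0.1]), np.array([[0.5]]), m_y, V_y, jac_m_y)
```

With $\lambda = c e^{\mu_x}$, this gives $m = \lambda$, $S = \lambda^2 \Sigma_x + \lambda$ and $C = \Sigma_x \lambda$. The extra $+\lambda$ term is the Poisson noise that an additive-noise model would have to supply through a fixed $R$.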


8.7.7.2 Iterated CMGF


As in the IPLF method of Section 8.7.5, we can get improved accuracy by iterating the computation
of the moments, using the most recent posterior estimate for the expectation. We call this the
iterated CMGF. We give the pseudocode in Algorithm 19 for the special case where we use Taylor
series linearization rather than sigma point methods, and where we assume the dynamics have additive
Gaussian noise. If the observation model is also additive Gaussian (so the observation conditional
covariance is a constant), the resulting algorithm is equivalent to the IEKF in Section 8.5.2.2.


Algorithm 19: Iterated conditional moments Gaussian filter using Taylor series linearization.

def ICMGF-Extended(f, Q, m_y(·), V_y(·), y_{1:T}, μ_{0|0}, Σ_{0|0}, J):
    foreach t = 1 : T do
        Predict step:
        (μ_{t|t−1}, Σ_{t|t−1}, −) = LinearizedPredict(μ_{t−1|t−1}, Σ_{t−1|t−1}, f, Q_t)
        Update step:
        μ_{t|t} = μ_{t|t−1}, Σ_{t|t} = Σ_{t|t−1}
        foreach j = 1 : J do
            (m_t, S_t, C_t) = CondLinearizedPredict(μ_{t|t}, Σ_{t|t}, m_y, V_y)
            (μ_{t|t}, Σ_{t|t}) = GaussCondition(μ_{t|t−1}, Σ_{t|t−1}, m_t, S_t, C_t, y_t)
    Return (μ_{t|t}, Σ_{t|t})_{t=1}^{T}
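In the Taylor-series case, each iteration only needs the current posterior mean $\mu$. Applying SLR to the moments of Algorithm 18 at that point gives $A = J$, $b = m_y(\mu) - J\mu$ and linearization noise $V_y(\mu)$, with $J$ evaluated at $\mu$. The NumPy sketch below shows the resulting update loop, under that reading of the GaussCondition step (names are ours):

```python
import numpy as np

def icmgf_update(mu_pred, Sigma_pred, m_y, V_y, jac_m_y, y, num_iters=5):
    """Iterated CMGF measurement update with Taylor linearization (a sketch)."""
    mu, Sigma = mu_pred.copy(), Sigma_pred.copy()
    for _ in range(num_iters):
        J = jac_m_y(mu)                          # linearize at the current posterior mean
        y_hat = m_y(mu) + J @ (mu_pred - mu)     # linearized prediction at the prior mean
        S = J @ Sigma_pred @ J.T + V_y(mu)       # conditional variance evaluated at mu
        K = np.linalg.solve(S, J @ Sigma_pred).T # K = Sigma_pred J^T S^{-1}
        mu = mu_pred + K @ (y - y_hat)
        Sigma = Sigma_pred - K @ S @ K.T         # not needed by the next Taylor step
    return mu, Sigma
```

For a linear-Gaussian observation model ($m_y(x) = Hx$, $V_y(x) = R$), every iteration returns the Kalman update, consistent with the IEKF equivalence noted above.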


8.7.8 Other extensions 


Various extensions of the (conditional) IPLF and IPLS have been proposed. 

In [HPR19] they extend IPLS to belief propagation in Forney factor graphs (Section 4.6.1.2), 
which enables the method to be applied to a large class of graphical models beyond SSMs. In 
particular, they give a general linearization formulation (including explicit message update rules) for 
nonlinear approximate Gaussian BP (Section 9.3.3) where the linearization can be Jacobian-based 
(“EKF-style”), statistical (moment matching / quadrature filtering / sigma points), or anything else. 
They also show how any such linearization method can benefit from iterations. 

In [Kam+22], they present a method based on approximate expectation propagation (Section 10.7),
which is very similar to IPLS, except that the distributions that are used to compute the SLR terms,
needed to compute the Gaussian messages, are different. In particular, rather than using the smoothed
posterior from the last iteration, it uses the "cavity" distribution, which is the current posterior
minus the incoming message that was sent at the last iteration, similar to Section 8.3.4.4. The
advantage of this is that the outgoing message does not double count the evidence. The disadvantage
is that this may be numerically unstable.

In [TGFS18], they extend conditional IPLF to the smoothing (offline) setting to get the conditional 
IPLS algorithm. In [Hos+20b] they use conditional IPLF as a proposal distribution inside of a 
particle filtering algorithm (Section 13.2). 

In [GFTS19], they use conditional posterior linearization to fit a Gaussian process with a Bernoulli
likelihood. In [WSS21], they propose a variety of "Bayes-Newton" methods for approximately
computing Gaussian posteriors for probabilistic models with nonlinear and/or non-Gaussian likelihoods.
This generalizes all of the above methods, and can be applied to SSMs and GPs.
8.8 Other variants of the Kalman filter


In this section, we briefly mention some other variants of Kalman filtering. For a more extensive 
review, see [Li+17e]. 


8.8.1 Ensemble Kalman filter 


The ensemble Kalman filter (EnKF) is a technique developed in the geoscience (meteorology) 
community to perform approximate online inference in large nonlinear systems. In particular, it 
is mostly used for problems where the hidden state represents an unknown physical quantity (e.g., 
temperature and pressure) at each point on a spatial grid, and the measurements are sparse and 
spatially localized. Combining this information over space and time is called data assimilation. 

The canonical reference is [Eve09], but a more accessible tutorial (using the same Bayesian signal 
processing approach we adopt in this chapter) is in [Rot+17]. 

The key idea is to represent the belief state $p(z_t|y_{1:t})$ by a finite number of samples $Z_{t|t} = \{z^s_{t|t} :
s = 1 : N_s\}$, where each $z^s_{t|t} \in \mathbb{R}^{N_z}$. In contrast to particle filtering (Section 13.2), the samples are
updated in a manner that closely resembles the Kalman filter, so there is no importance sampling
or resampling step. The downside is that the posterior does not converge to the true Bayesian
posterior even as $N_s \to \infty$ [LGMT11], except in the linear-Gaussian case. However, sometimes the
performance of EnKF can be better for a small number of samples (although this depends of course on
the PF proposal distribution).
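The posterior mean and covariance can be derived from the ensemble of samples as follows:

$$\mu_{t|t} = \frac{1}{N_s} \sum_{s=1}^{N_s} z^s_{t|t} = \frac{1}{N_s} Z_{t|t} \mathbf{1} \tag{8.220}$$
$$\Sigma_{t|t} = \frac{1}{N_s - 1} \sum_{s=1}^{N_s} (z^s_{t|t} - \mu_{t|t})(z^s_{t|t} - \mu_{t|t})^\top = \frac{1}{N_s - 1} \tilde{Z}_{t|t} \tilde{Z}_{t|t}^\top \tag{8.221}$$

where $\tilde{Z}_{t|t} = Z_{t|t} - \mu_{t|t} \mathbf{1}^\top$.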


We update the samples as follows. For the time update, we first draw $N_s$ system noise variables
$q^s_t \sim \mathcal{N}(0, Q_t)$, and then we pass these, and the previous state estimate, through the dynamics
model to get the one-step-ahead state predictions, $z^s_{t|t-1} = f(z^s_{t-1|t-1}, q^s_t)$, from which we get
$Z_{t|t-1} = \{z^s_{t|t-1}\}$, which has size $N_z \times N_s$. Next we draw $N_s$ observation noise variables $r^s_t \sim \mathcal{N}(0, R_t)$,
and use them to compute the one-step-ahead observation predictions, $y^s_{t|t-1} = h(z^s_{t|t-1}, r^s_t)$ and
$Y_{t|t-1} = \{y^s_{t|t-1}\}$, which has size $N_y \times N_s$. Finally we compute the measurement update using

$$Z_{t|t} = Z_{t|t-1} + K_t (y_t \mathbf{1}^\top - Y_{t|t-1}) \tag{8.222}$$

which is the analog of Equation (8.74).


We now discuss how to compute $K_t$, which is the analog of the Kalman gain matrix in Equation (8.72). First note that we can write the exact Kalman gain matrix (in the linear-Gaussian case)
as $K_t = \Sigma_{t|t-1} H^\top S_t^{-1} = C_t S_t^{-1}$, where $S_t$ is the covariance of the measurements, and $C_t$ is the
cross-covariance between the state and output predictions. In the EnKF, we approximate $S_t$ and $C_t$
empirically as follows. First we compute the deviations from predictions:

$$\tilde{Z}_{t|t-1} = Z_{t|t-1}\left(I - \tfrac{1}{N_s} \mathbf{1}\mathbf{1}^\top\right), \quad \tilde{Y}_{t|t-1} = Y_{t|t-1}\left(I - \tfrac{1}{N_s} \mathbf{1}\mathbf{1}^\top\right) \tag{8.223}$$

Then we compute the sample covariance matrices

$$\hat{S}_t = \frac{1}{N_s - 1} \tilde{Y}_{t|t-1} \tilde{Y}_{t|t-1}^\top, \quad \hat{C}_t = \frac{1}{N_s - 1} \tilde{Z}_{t|t-1} \tilde{Y}_{t|t-1}^\top \tag{8.224}$$

Finally we compute

$$K_t = \hat{C}_t \hat{S}_t^{-1} \tag{8.225}$$
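Here is a sketch of the EnKF measurement update, Equations (8.222)-(8.225). It assumes the caller has already formed the forecast ensembles $Z_{t|t-1}$ and $Y_{t|t-1}$ (the latter including sampled observation noise), stored as arrays with one column per ensemble member. The function name is ours.

```python
import numpy as np

def enkf_update(Z_pred, Y_pred, y):
    """EnKF measurement update.

    Z_pred: (N_z, N_s) predicted state ensemble Z_{t|t-1}.
    Y_pred: (N_y, N_s) predicted observation ensemble Y_{t|t-1}.
    y: (N_y,) observed measurement.
    """
    Ns = Z_pred.shape[1]
    Z_dev = Z_pred - Z_pred.mean(axis=1, keepdims=True)   # Eq. (8.223)
    Y_dev = Y_pred - Y_pred.mean(axis=1, keepdims=True)
    S_hat = Y_dev @ Y_dev.T / (Ns - 1)                    # Eq. (8.224)
    C_hat = Z_dev @ Y_dev.T / (Ns - 1)
    K = np.linalg.solve(S_hat, C_hat.T).T                 # Eq. (8.225): K = C S^{-1}
    return Z_pred + K @ (y[:, None] - Y_pred)             # Eq. (8.222)
```

Only $N_y \times N_y$ and $N_z \times N_y$ matrices are formed, never an $N_z \times N_z$ covariance. In a linear-Gaussian model with a large ensemble, the updated ensemble mean and covariance approach the Kalman filter posterior.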


We now compare the computational complexity to the KF algorithm. Recall that $N_z$ is the number
of latent dimensions, $N_y$ is the number of observed dimensions, and $N_s$ is the number of samples.
We will assume $N_z > N_s > N_y$, as occurs in most geospatial problems. The EnKF time update takes
$O(N_z^2 N_s)$ operations, and the measurement update takes $O(N_z N_y N_s)$. By contrast, in the KF, the
time update takes $O(N_z^3)$ operations, and the measurement update takes $O(N_z^2 N_y)$. So we see that
the EnKF is faster for high dimensional state spaces, because it uses a low-rank approximation to
the posterior covariance.

Unfortunately, if $N_s$ is too small, the EnKF can become overconfident, and the filter can diverge.
Various heuristics (e.g., covariance inflation) have been proposed to fix this. However, most of these
methods are ad hoc. A variety of more principled solutions have also been proposed, see e.g.,
[FK13b; Rei13].


8.8.2 Robust Kalman filters 


In practice we often have noise that is non-Gaussian. A common example is when we have clutter, or
outliers, in the observation model, or sudden changes in the process model. In this case, we might
use the Laplace distribution [Ara+09] or the Student-t distribution [Ara10; ROG13; Ara+17] as
noise models.

[Hua+17b] proposes a variational Bayes (Section 10.2.3) approach that allows the dynamical
prior and the observation model to both be (linear) Student-t distributions, but where the posterior is
approximated at each step using a Gaussian, conditional on the noise scale matrix, which is modeled
using an inverse Wishart distribution. An extension of this, to handle mixture distributions, can be
found in [Hua+19].


8.8.3 Dual EKF 


In this section, we briefly discuss one approach to estimating the parameters of an SSM. In an offline
setting, we can use EM, SGD or Bayesian inference to compute an approximation to $p(\theta|y_{1:T})$ (see
Section 29.8). In the online setting, we want to compute $p(\theta_t|y_{1:t})$. We can do this by adding the
parameters to the state space, possibly with an artificial dynamics, $p(\theta_t|\theta_{t-1}) = \mathcal{N}(\theta_t|\theta_{t-1}, \epsilon I)$,
and then performing joint inference of states and parameters. The latent variables at each step
now contain the latent states, $z_t$, and the latent parameters, $\theta_t$. One approach to performing
approximate inference in such a model is to use the dual EKF, in which one EKF performs state
estimation and the other EKF performs parameter estimation [WN01].
Figure 8.13: Illustration of the predict-update-project cycle of assumed density filtering. $q_t \in \mathcal{Q}$ is a tractable
distribution, whereas we may have $p_{t|t-1} \notin \mathcal{Q}$ and $p_t \notin \mathcal{Q}$.


8.9 Assumed density filtering 


In this section, we discuss assumed density filtering or ADF [May79]. In this approach, we 
assume the posterior has a specific form (e.g., a Gaussian). At each step, we update the previous 
posterior with the new likelihood; the result will often not have the desired form (e.g., will no longer 
be Gaussian), so we project it to the closest approximating distribution of the required type. 

In more detail, we assume (by induction) that our prior $q_{t-1}(z_{t-1}) \approx p(z_{t-1}|y_{1:t-1})$ satisfies
$q_{t-1} \in \mathcal{Q}$, where $\mathcal{Q}$ is a family of tractable distributions. We can update the prior with the new
measurement to get the approximate posterior as follows. First we compute the one-step-ahead
predictive distribution

$$p_{t|t-1}(z_t|y_{1:t-1}) = \int p(z_t|z_{t-1})\, q_{t-1}(z_{t-1})\, dz_{t-1} \tag{8.226}$$

Then we update this prior with the likelihood for step $t$ to get the posterior

$$p_t(z_t|y_{1:t}) = \frac{1}{Z_t} p(y_t|z_t)\, p_{t|t-1}(z_t) \tag{8.227}$$

where

$$Z_t = \int p(y_t|z_t)\, p_{t|t-1}(z_t)\, dz_t \tag{8.228}$$

is the normalization constant. Unfortunately, we often find that the resulting posterior is no longer in
our tractable family, $p_t(z_t) \notin \mathcal{Q}$. So after Bayesian updating we seek the best tractable approximation
by computing

$$q_t(z_t|y_{1:t}) = \operatorname*{argmin}_{q \in \mathcal{Q}} D_{\mathbb{KL}}\big(p_t(z_t|y_{1:t}) \,\|\, q(z_t)\big) \tag{8.229}$$

This minimizes the Kullback-Leibler divergence from the approximation $q(z_t)$ to the "exact" posterior
$p_t(z_t)$, and can be thought of as projecting $p$ onto the space of tractable distributions. Thus the
overall algorithm consists of three steps (predict, update, and project), as sketched in Figure 8.13.
Figure 8.14: A taxonomy of filtering algorithms. Adapted from Figure 2 of [Wüt+16].


Computing $\min_q D_{\mathbb{KL}}(p \| q)$ is known as moment projection, since the optimal $q$ should have
the same moments as $p$ (see Section 5.1.3.4). So in the Gaussian case, we just need to set the
mean and covariance of $q$ so they are the same as the mean and covariance of $p$. We will give
some examples of this below. By contrast, computing $\min_q D_{\mathbb{KL}}(q \| p)$, as in variational inference
(Section 10.1), is known as information projection, and will result in mode seeking behavior (see
Section 5.1.3.3), rather than trying to capture overall moments.


8.9.1 Connection with Gaussian filtering 


When $\mathcal{Q}$ is the set of Gaussian distributions, there is a close connection between ADF and Gaussian
filtering, which we discussed in Section 8.7. GF corresponds to solving the following optimization
problem

$$q_{t|t-1}(z_t, y_t) = \operatorname*{argmin}_{q \in \mathcal{Q}} D_{\mathbb{KL}}\big(p(z_t, y_t|y_{1:t-1}) \,\|\, q(z_t, y_t|y_{1:t-1})\big) \tag{8.230}$$

which can be solved by moment matching (see Section 8.7). We then condition this joint distribution
on the event $Y_t = y_t$, where $Y_t$ is the unknown random variable and $y_t$ is its observed value. This gives
$p_t(z_t|y_{1:t})$, which is easy to compute, due to the Gaussian assumption. By contrast, in Gaussian ADF,
we first compute the (locally) exact posterior $p_t(z_t|y_{1:t})$, and then approximate it with $q_t(z_t|y_{1:t})$ by
projecting into $\mathcal{Q}$. Thus ADF approximates the conditional $p_t(z_t|y_{1:t})$, whereas GF approximates
the joint $p_{t|t-1}(z_t, y_t|y_{1:t-1})$, from which we derive $p_t(z_t|y_{1:t})$ by conditioning.

ADF is more accurate than GF, since it directly approximates the posterior, but it is more
computationally demanding, for reasons explained in [Wüt+16]. However, in [Kam+22] they propose
an approximate form of expectation propagation (which is a generalization of ADF) in which the
messages are computed using the same local joint Gaussian approximation as used in Gaussian
filtering. See Figure 8.14 for a summary of how these different methods relate.

Figure 8.15: ADF for a switching linear dynamical system with 2 discrete states. (a) GPB2 method. (b)
IMM method.


8.9.2 ADF for SLDS (Gaussian sum filter) 


In this section, we apply ADF to inference in switching linear dynamical systems (SLDS, Section 29.9), 
which are a combination of HMM and LDS models. The resulting method is known as the Gaussian 
sum filter (see e.g., [Cro+11; Wil+17]). 

A Gaussian sum filter approximates the belief state at each step by a mixture of $K$ Gaussians.
This can be implemented by running $K$ Kalman filters in parallel. This is particularly well suited
to switching SSMs. We now describe one version of this algorithm, known as the "second order
generalized pseudo Bayes filter" (GPB2) [BSF88]. We assume that the prior belief state $b_{t-1}$ is
a mixture of $K$ Gaussians, one per discrete state:

$$b^i_{t-1} \triangleq p(z_{t-1}, m_{t-1} = i|y_{1:t-1}) = \pi^i_{t-1|t-1} \mathcal{N}(z_{t-1}|\mu^i_{t-1|t-1}, \Sigma^i_{t-1|t-1}) \tag{8.231}$$

where $i \in \{1, \ldots, K\}$. We then pass this through the $K$ different linear models to get

$$b^{ij}_t \triangleq p(z_t, m_{t-1} = i, m_t = j|y_{1:t}) = \pi^{ij}_{t|t} \mathcal{N}(z_t|\mu^{ij}_{t|t}, \Sigma^{ij}_{t|t}) \tag{8.232}$$

where the weights $\pi^{ij}_{t|t}$ depend on the mode transition probabilities $A_{ij} = p(m_t = j|m_{t-1} = i)$. Finally, for each value of $j$, we collapse
the $K$ Gaussians (one per value of $i$) into a single Gaussian to give

$$b^j_t \triangleq p(z_t, m_t = j|y_{1:t}) = \pi^j_{t|t} \mathcal{N}(z_t|\mu^j_{t|t}, \Sigma^j_{t|t}) \tag{8.233}$$

See Figure 8.15a for a sketch.
The optimal way to approximate a mixture of Gaussians with a single Gaussian is given by
$q = \operatorname{argmin}_q D_{\mathbb{KL}}(p \| q)$, where $p(z) = \sum_k \pi^k \mathcal{N}(z|\mu^k, \Sigma^k)$ and $q(z) = \mathcal{N}(z|\mu, \Sigma)$. This can be
solved by moment matching, that is,

$$\mu = \mathbb{E}[z] = \sum_k \pi^k \mu^k \tag{8.234}$$
$$\Sigma = \text{Cov}[z] = \sum_k \pi^k \left(\Sigma^k + (\mu^k - \mu)(\mu^k - \mu)^\top\right) \tag{8.235}$$


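In the graphical model literature, this is called weak marginalization [Lau92], since it preserves
the first two moments. Applying these equations to our model, we can go from $b^{ij}_t$ to $b^j_t$ as follows
(where we drop the $t$ subscript for brevity):

$$\pi^j = \sum_i \pi^{ij} \tag{8.236}$$
$$\pi^{i|j} = \frac{\pi^{ij}}{\sum_{i'} \pi^{i'j}} \tag{8.237}$$
$$\mu^j = \sum_i \pi^{i|j} \mu^{ij} \tag{8.238}$$
$$\Sigma^j = \sum_i \pi^{i|j} \left(\Sigma^{ij} + (\mu^{ij} - \mu^j)(\mu^{ij} - \mu^j)^\top\right) \tag{8.239}$$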
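The collapse step is a few lines of NumPy. In this sketch (the name is ours), the weights are normalized inside the function. For GPB2, one can therefore pass $\pi^{ij}$ for a fixed $j$ directly, and $\pi^j$ is then just their sum.

```python
import numpy as np

def collapse_mixture(weights, means, covs):
    """Moment-match a Gaussian mixture with a single Gaussian (Eqs. 8.234-8.235).

    weights: (K,), means: (K, d), covs: (K, d, d).
    """
    w = weights / weights.sum()                          # e.g. pi^{i|j} = pi^{ij} / pi^j
    mu = np.einsum("k,kd->d", w, means)                  # weighted mean
    diffs = means - mu
    Sigma = (np.einsum("k,kde->de", w, covs)             # within-component spread
             + np.einsum("k,kd,ke->de", w, diffs, diffs))  # spread of the means
    return mu, Sigma
```

For instance, an equal mixture of $\mathcal{N}(-1, 1)$ and $\mathcal{N}(1, 1)$ collapses to $\mathcal{N}(0, 2)$. The variance grows because the component means are spread out.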


This algorithm requires running $K^2$ filters at each step. A cheaper alternative, known as interacting
multiple models or IMM [BSF88], can be obtained by first collapsing the prior to a single
Gaussian (by moment matching), and then updating it using $K$ different Kalman filters, one per
value of $m_t$. See Figure 8.15b for a sketch.


8.9.3 ADF for online logistic regression 


In this section we discuss the application of ADF to online Bayesian parameter inference for a binary 
logistic regression model, based on [Zoe07]. The overall approach is similar to the online linear 
regression case (discussed in Section 29.7.2), but approximates the posterior after each update step, 
which is necessary since the likelihood is not conjugate to the prior. 

We assume our model has the following form: 


$$p(y_t|x_t, w_t) = \text{Ber}(y_t|\sigma(x_t^\top w_t)) \tag{8.240}$$
$$p(w_t|w_{t-1}) = \mathcal{N}(w_t|w_{t-1}, Q) \tag{8.241}$$

where $Q$ is the covariance of the process noise, which allows the parameters to change slowly over
time. We will assume $Q = \epsilon I$; we can also set $\epsilon = 0$, as in the recursive least squares method
(Section 29.7.2), if we believe the parameters will not change. See Figure 8.16 for an illustration of
the model.

As our approximating family, we will use diagonal Gaussians, for computational efficiency. Thus
the prior is the posterior from the previous timestep, and has the form

$$p(w_{t-1}|\mathcal{D}_{1:t-1}) \approx p_{t-1}(w_{t-1}) = \prod_j \mathcal{N}(w^j_{t-1}|\mu^j_{t-1|t-1}, \tau^j_{t-1|t-1}) \tag{8.242}$$

where $\mu^j_{t-1|t-1}$ and $\tau^j_{t-1|t-1}$ are the posterior mean and variance for parameter $j$ given past data.
Now we discuss how to update this prior.

Figure 8.16: A dynamic logistic regression model. $w_t$ are the regression weights at time $t$, and $\eta_t = w_t^\top x_t$.
Compare to Figure 29.24a.

First we compute the one-step-ahead predictive density $p_{t|t-1}(w_t)$ using the standard linear-Gaussian update, i.e., $\mu_{t|t-1} = \mu_{t-1|t-1}$ and $\tau_{t|t-1} = \tau_{t-1|t-1} + \epsilon$ (since $Q = \epsilon I$), where we can set $\epsilon = 0$ if there
is no drift.

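Now we concentrate on the measurement update step. Define the scalar sum (corresponding to the
logits, if we are using binary classification) as $\eta_t = w_t^\top x_t$. If $p_{t|t-1}(w_t) = \prod_j \mathcal{N}(w^j_t|\mu^j_{t|t-1}, \tau^j_{t|t-1})$,
then we can compute the 1d prior predictive distribution for $\eta_t$ as follows:

$$p(\eta_t|\mathcal{D}_{1:t-1}, x_t) \approx p_{t|t-1}(\eta_t) = \mathcal{N}(\eta_t|m_{t|t-1}, v_{t|t-1}) \tag{8.243}$$
$$m_{t|t-1} = \sum_j x^j_t \mu^j_{t|t-1} \tag{8.244}$$
$$v_{t|t-1} = \sum_j (x^j_t)^2 \tau^j_{t|t-1} \tag{8.245}$$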


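The posterior for the 1d $\eta_t$ is given by

$$p(\eta_t|\mathcal{D}_{1:t}) \approx p_t(\eta_t) = \mathcal{N}(\eta_t|m_t, v_t) \tag{8.246}$$
$$m_t = \int \eta_t \frac{1}{Z_t} p(y_t|\eta_t)\, p_{t|t-1}(\eta_t)\, d\eta_t \tag{8.247}$$
$$v_t = \int \eta_t^2 \frac{1}{Z_t} p(y_t|\eta_t)\, p_{t|t-1}(\eta_t)\, d\eta_t - m_t^2 \tag{8.248}$$
$$Z_t = \int p(y_t|\eta_t)\, p_{t|t-1}(\eta_t)\, d\eta_t \tag{8.249}$$

where $p(y_t|\eta_t) = \text{Ber}(y_t|\sigma(\eta_t))$. These integrals are one dimensional, and so can be efficiently computed
using Gaussian quadrature, as explained in [Zoe07; KB00].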
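Here is a sketch of the scalar computation in Equations (8.247)-(8.249). It assumes Gauss-Hermite quadrature with NumPy's `hermegauss` nodes; the function name and the number of nodes are our own choices.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def adf_logit_moments(y, m_prior, v_prior, order=20):
    """Moments of eta under prior N(m_prior, v_prior) times likelihood Ber(y | sigmoid(eta))."""
    nodes, weights = hermegauss(order)                     # weight function exp(-x^2 / 2)
    eta = m_prior + np.sqrt(v_prior) * nodes               # quadrature points in eta-space
    lik = sigmoid(eta) if y == 1 else 1.0 - sigmoid(eta)   # p(y | eta)
    w = weights * lik / np.sqrt(2 * np.pi)                 # normalized weights times likelihood
    Z = w.sum()                                            # Eq. (8.249)
    m = (w @ eta) / Z                                      # Eq. (8.247)
    v = (w @ eta**2) / Z - m**2                            # Eq. (8.248)
    return m, v, Z
```

With a symmetric prior $\mathcal{N}(0, 1)$ and $y_t = 1$, the evidence is $Z_t = 0.5$ by symmetry. The posterior mean moves to a positive value, and the variance shrinks, because the logistic likelihood is log-concave.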
[Figure panels (a), (b), (c): estimates of $w_0$, $w_1$, $w_2$ against the number of samples, comparing batch (Laplace) and online (ADF) results.]

Figure 8.17: Bayesian inference applied to a 2d binary logistic regression problem, $p(y = 1|x) = \sigma(w_0 +
w_1 x_1 + w_2 x_2)$. We show the training data and the posterior predictive produced by different methods. (a)
Offline MCMC approximation. (b) Offline Laplace approximation. (c) Online ADF approximation at the
final step of inference. Generated by adf_logistic_regression_demo.ipynb.


Having inferred $p_t(\eta_t)$, we need to compute $p_t(w|\eta_t)$. This can be done as follows. Define $\delta_m$ as
the change in the mean and $\delta_v$ as the change in the variance:

$$m_t = m_{t|t-1} + \delta_m, \quad v_t = v_{t|t-1} + \delta_v \tag{8.250}$$

Using the fact that $p(\eta_t|w) = \mathcal{N}(\eta_t|w^\top x_t, 0)$ is a linear Gaussian system, with prior $p(w) =
p(w|\mu_{t|t-1}, \tau_{t|t-1})$ and "soft evidence" $p(\eta_t) = \mathcal{N}(m_t, v_t)$, we can derive the posterior for $p(w|\mathcal{D}_t)$ as
follows:

$$p_t(w^j_t) = \mathcal{N}(w^j_t|\mu^j_{t|t}, \tau^j_{t|t}) \tag{8.251}$$
$$\mu^j_{t|t} = \mu^j_{t|t-1} + a_j \delta_m \tag{8.252}$$
$$\tau^j_{t|t} = \tau^j_{t|t-1} + a_j^2 \delta_v \tag{8.253}$$
$$a_j \triangleq \frac{x^j_t \tau^j_{t|t-1}}{\sum_{j'} (x^{j'}_t)^2 \tau^{j'}_{t|t-1}} \tag{8.254}$$

Thus we see that the parameters which correspond to inputs $j$ with larger magnitude (big $|x^j_t|$) or
larger uncertainty (big $\tau^j_{t|t-1}$) get updated most, due to a large $a_j$ factor, which makes intuitive sense.
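Putting the pieces together, one ADF step for the diagonal-Gaussian model costs $O(D)$. The sketch below is a NumPy version; the function name, argument names and quadrature order are assumptions. The $\sqrt{2\pi}$ quadrature normalization cancels in the ratios, so it is omitted.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

def adf_logreg_step(mu, tau, x, y, eps=0.0, order=20):
    """One ADF step for online logistic regression with a diagonal Gaussian posterior.

    mu, tau: posterior means and variances of the weights from the previous step.
    x: input vector (include a constant 1 for the bias); y in {0, 1}.
    eps: process-noise variance (Q = eps * I); eps = 0 means static weights.
    """
    mu_pred, tau_pred = mu, tau + eps                  # predict step
    m_pred = x @ mu_pred                               # Eq. (8.244)
    v_pred = (x**2) @ tau_pred                         # Eq. (8.245)
    nodes, weights = hermegauss(order)                 # 1d posterior of eta by quadrature
    eta = m_pred + np.sqrt(v_pred) * nodes
    p1 = 1.0 / (1.0 + np.exp(-eta))
    w = weights * (p1 if y == 1 else 1.0 - p1)
    m_post = (w @ eta) / w.sum()                       # Eq. (8.247)
    v_post = (w @ eta**2) / w.sum() - m_post**2        # Eq. (8.248)
    a = x * tau_pred / v_pred                          # Eq. (8.254)
    mu_new = mu_pred + a * (m_post - m_pred)           # Eq. (8.252)
    tau_new = tau_pred + a**2 * (v_post - v_pred)      # Eq. (8.253)
    return mu_new, tau_new
```

Calling this once per example, in any order, gives the online estimates; a weight whose input is zero is left untouched, because its $a_j$ is zero.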

As an example, we consider a 2d binary classification problem. We sequentially compute the
posterior using the ADF, and compare to the offline estimate computed using a Laplace approximation.
In Figure 8.17 we plot the posterior marginals over the 3 parameters as a function of "time" (i.e., after
conditioning on each training example one at a time). We see that we converge to the offline MAP estimate. In
Figure 8.18, we show the results of performing sequential Bayesian updating in a different ordering
of the data. We still converge to approximately the same answer. In Figure 8.19, we see that the
resulting posterior predictive distributions from the Laplace estimate and ADF estimate (at the end
of training) are similar.

Note that the whole algorithm only takes O(D) time and space per step, the same as SGD. However, 
unlike SGD, there are no step-size parameters, since the diagonal covariance implicitly specifies the 
size of the update for each dimension. Furthermore, we get a posterior approximation, not just a 
point estimate. 


[Figure panels (a), (b), (c): estimates of $w_0$, $w_1$, $w_2$ against the number of samples, comparing batch (Laplace) and online (ADF) results.]

Figure 8.18: Same as Figure 8.17, except the order in which the data is visited is different. Generated by
adf_logistic_regression_demo.ipynb.


[Figure panels: (a) Laplace Predictive distribution; (b) ADF Predictive distribution.]

Figure 8.19: Predictive distribution for the binary logistic regression problem. (a) Result from Laplace
approximation. (b) Result from ADF at the final step. Generated by adf_logistic_regression_demo.ipynb.


The overall approach is very similar to the generalized posterior linearization filter of Section 8.7.7.2, 
which uses quadrature (or the unscented transform) to compute a Gaussian approximation to the 
joint $p(y_t, \mathbf{w}_t | \mathcal{D}_{1:t-1})$, from which we can easily compute $p(\mathbf{w}_t | \mathcal{D}_{1:t})$. However, ADF approximates 
the posterior rather than the joint, as explained in Section 8.9.1. 


8.9.4 ADF for online DNNs 


In Section 17.5.3, we show how to use ADF to recursively approximate the posterior over the 
parameters of a deep neural network in an online fashion. This generalizes Section 8.9.3 to the case 
of nonlinear models. 


8.10 Other inference methods for SSMs 


There are a variety of other inference algorithms that can be applied to SSMs, which we discuss 
elsewhere in this book. We give a very brief summary below, mostly focused on the case of offline 
inference (smoothing). 


8.10.1 Expectation propagation 


In Section 10.7 we discuss the expectation propagation (EP) algorithm, which can be viewed as 
an iterative version of ADF (Section 8.9). In particular, at each step we combine each exact local 
likelihood factor with approximate factors from both the past filtering distribution and the future 
smoothed posterior; these factors are combined to compute the locally exact posterior, which is then 
projected back to the tractable family (e.g., Gaussian), before moving to the next time step. This 
process can be iterated for increased accuracy. In many cases the local EP update is intractable, 
but we can make a local Gaussian approximation, similar to the one in general Gaussian filtering 
(Section 8.7), as explained in [Kam+22]. 


8.10.2 Variational inference 


EP can be viewed as locally minimizing the inclusive KL, $D_{\mathrm{KL}}(p(z_t|y_{1:T}) \,\|\, q(z_t|y_{1:T}))$, for each time 
step $t$. An alternative approach is to globally minimize the exclusive KL, $D_{\mathrm{KL}}(q(z_{1:T}|y_{1:T}) \,\|\, p(z_{1:T}|y_{1:T}))$; 
this is called variational inference, and is explained in Chapter 10. The difference between these two 
objectives is discussed in more detail in Section 5.1.3.3, but from a practical point of view, the main 
advantage of VI is that we can derive a tractable lower bound to the objective, and can then optimize 
it using stochastic optimization. This method is guaranteed to converge, unlike EP. For more details 
on VI applied to SSMs (both state estimation and parameter estimation), see e.g., [CWS21; Cou+20; 
Cou+21; BFY20; FLMM21; Cam+21]. 


8.10.3 MCMC 


In Chapter 12 we discuss Markov chain Monte Carlo (MCMC) methods, which can be used to draw 
samples from intractable posteriors. In the case of SSMs, this includes both the distribution over 
states, $p(z_{1:T}|y_{1:T})$, and the distribution over parameters, $p(\theta|y_{1:T})$. In some cases, such as when 
using HMMs or linear-Gaussian SSMs, we can perform blocked Gibbs sampling, in which we use 
forwards filtering backwards sampling to sample an entire sequence from $p(z_{1:T}|y_{1:T}, \theta)$, followed by 
sampling the parameters, $p(\theta|z_{1:T}, y_{1:T})$ (see e.g., [CK96; Sco02; CMR05] for details). Alternatively 
we can marginalize out the hidden states and just compute the parameter posterior $p(\theta|y_{1:T})$. When 
state inference is intractable, we can use gradient-based HMC methods (assuming the states are 
continuous), although this does not scale well to long sequences. 
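The forwards filtering backwards sampling step used inside blocked Gibbs sampling can be sketched for a discrete HMM as follows. This is a minimal NumPy sketch under our own assumptions (known parameters θ = (π, A) and a precomputed matrix of local likelihoods); the function name ffbs and the toy numbers are hypothetical. The usage code checks the sample frequencies against the exact joint posterior obtained by enumeration.

```python
import itertools
import numpy as np


def ffbs(pi, A, lik, rng):
    """Draw z_{1:T} ~ p(z_{1:T} | y_{1:T}) for a discrete HMM (0-indexed time).

    pi: (K,) initial distribution; A: (K, K) with A[i, j] = p(z_{t+1}=j | z_t=i);
    lik: (T, K) with lik[t, k] = p(y_t | z_t = k).
    """
    T, K = lik.shape
    alpha = np.zeros((T, K))
    # Forwards filtering: alpha[t] = p(z_t | y_{1:t}).
    alpha[0] = pi * lik[0]
    alpha[0] /= alpha[0].sum()
    for t in range(1, T):
        alpha[t] = (alpha[t - 1] @ A) * lik[t]
        alpha[t] /= alpha[t].sum()
    # Backwards sampling: z_T ~ p(z_T | y_{1:T}), then
    # z_t ~ p(z_t | z_{t+1}, y_{1:t}), proportional to alpha[t] * A[:, z_{t+1}].
    z = np.zeros(T, dtype=int)
    z[-1] = rng.choice(K, p=alpha[-1])
    for t in range(T - 2, -1, -1):
        p = alpha[t] * A[:, z[t + 1]]
        z[t] = rng.choice(K, p=p / p.sum())
    return z


# Usage: compare sample frequencies with the exact joint posterior (T = 3, K = 2).
pi = np.array([0.6, 0.4])
A = np.array([[0.9, 0.1], [0.2, 0.8]])
lik = np.array([[0.7, 0.2], [0.1, 0.6], [0.5, 0.5]])

exact = {}
for zs in itertools.product(range(2), repeat=3):
    p = pi[zs[0]] * lik[0, zs[0]]
    for t in range(1, 3):
        p *= A[zs[t - 1], zs[t]] * lik[t, zs[t]]
    exact[zs] = p
norm = sum(exact.values())

rng = np.random.default_rng(0)
draws = [tuple(int(v) for v in ffbs(pi, A, lik, rng)) for _ in range(10000)]
max_err = max(abs(draws.count(zs) / len(draws) - p / norm) for zs, p in exact.items())
print(max_err)
```

In a full blocked Gibbs sampler, each call to ffbs would alternate with a draw of θ from $p(\theta|z_{1:T}, y_{1:T})$, which depends on the chosen conjugate priors and is not shown here.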


8.10.4 Particle filtering 


In Section 13.2 we discuss particle filtering, which is a form of sequential Bayesian inference for SSMs 
which replaces the assumption that the posterior is (approximately) Gaussian with a more flexible 
representation, namely a set of weighted samples called “particles” (see e.g., [Aru+02; DJ11; NLS19]). 
Essentially the technique amounts to a form of importance sampling, combined with resampling steps to prevent 
“particle impoverishment”, which refers to many samples receiving negligible weight because they are 
too improbable under the posterior (a problem that grows with time). Particle filtering is widely used because 
it is very flexible, and has good theoretical properties. In practice it may require many samples to 
get a good approximation, but we can use heuristic methods, such as the extended or unscented 
Kalman filters, as proposal distributions, which can improve the efficiency significantly. In the offline 
setting, we can use particle smoothing (Section 13.5) or SMC (sequential Monte Carlo) samplers 
(Section 13.6). 
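A minimal bootstrap particle filter shows the importance sampling and resampling steps. The sketch below is ours: it assumes a 1d linear-Gaussian SSM, z_t = a z_{t-1} + N(0, q) and y_t = z_t + N(0, r), uses the dynamics as the proposal, and resamples at every step. Because this model is linear-Gaussian, a Kalman filter gives the exact filtered means for comparison. Function names and parameter values are illustrative.

```python
import numpy as np


def bootstrap_pf(ys, a, q, r, n_particles, rng):
    """Bootstrap PF for z_t = a z_{t-1} + N(0, q), y_t = z_t + N(0, r), z_0 ~ N(0, 1)."""
    z = rng.normal(0.0, 1.0, n_particles)
    means = []
    for y in ys:
        # Propose from the dynamics (the "bootstrap" proposal).
        z = a * z + rng.normal(0.0, np.sqrt(q), n_particles)
        # Importance weights are then proportional to the likelihood p(y_t | z_t).
        logw = -0.5 * (y - z) ** 2 / r
        w = np.exp(logw - logw.max())
        w /= w.sum()
        means.append(np.sum(w * z))  # estimate of E[z_t | y_{1:t}]
        # Multinomial resampling: low-weight particles tend to be dropped.
        z = z[rng.choice(n_particles, size=n_particles, p=w)]
    return np.array(means)


def kalman_means(ys, a, q, r):
    """Exact filtered means for the same model, for comparison."""
    m, v = 0.0, 1.0
    means = []
    for y in ys:
        m, v = a * m, a * a * v + q          # predict
        k = v / (v + r)                      # Kalman gain
        m, v = m + k * (y - m), (1 - k) * v  # update
        means.append(m)
    return np.array(means)


# Usage: simulate data, then compare the particle and Kalman filtered means.
rng = np.random.default_rng(0)
a, q, r, T = 0.9, 0.5, 1.0, 50
z_true, ys = rng.normal(), []
for _ in range(T):
    z_true = a * z_true + rng.normal(0.0, np.sqrt(q))
    ys.append(z_true + rng.normal(0.0, np.sqrt(r)))
pf = bootstrap_pf(ys, a, q, r, 2000, rng)
kf = kalman_means(ys, a, q, r)
max_err = np.max(np.abs(pf - kf))
print(max_err)
```

Replacing the dynamics proposal with one built from an extended or unscented Kalman filter step, as mentioned above, would change only the proposal line and add a correction term to the weights.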
9 Inference for graphical models 


9.1 Introduction 


In this chapter we consider posterior inference (i.e., computing marginals, modes, samples, etc.) for 
probability distributions that can be represented by a PGM with some kind of sparse graph structure 
(i.e., it is not a fully connected graph). The algorithms we discuss will leverage the conditional 
independence properties encoded in the graph structure (discussed in Chapter 4) in order to perform 
efficient inference. In particular, we will use the principle of dynamic programming (DP), which 
finds an optimal solution by solving subproblems and then combining them. 

DP can be implemented by computing local quantities for each node (or clique) in the graph, 
and then sending messages to neighboring nodes (or cliques) so that all nodes (cliques) can come 
to an overall consensus about the global solution. Hence these are known as message passing 
algorithms. Each message can be interpreted as a probability distribution about the value of a node 
given evidence from part of the graph. These distributions are often called belief states, so these 
algorithms are also called belief propagation (BP) algorithms. 

The methods we discuss generalize the algorithms from Chapter 8, which only work with chain 
structured graphical models, to work with PGMs with arbitrary graph structure. This requires that 
we specify a message passing schedule. For chain structured models, a natural approach is to 
send messages forwards in time, and then backwards in time, as we discussed in Chapter 8. We can 
generalize this approach to work with trees, as we discuss in Section 9.2. For general graphs, there 
may be cycles or loops, as we discuss in Section 9.3. However, sending messages on loopy graphs 
may give incorrect answers. In such cases, we may wish to convert the graph to a tree, and then 
send messages on it, using the methods discussed in Section 9.4 and Section 9.5. We can also pose 
the inference problem as an optimization problem, as we discuss in Section 9.6. 


9.2 Belief propagation on trees 


The forwards-backwards algorithm for HMMs (Section 8.2.4) and the Kalman smoother algorithm 
for LDS (Section 8.3.3) can both be interpreted as message passing algorithms applied to chain 
structured graphical models. In this section, we generalize these algorithms to work with trees. 


Figure 9.1: An undirected tree and two equivalent directed trees. 


9.2.1 Directed vs undirected trees 


Consider a pairwise undirected graphical model, which can be written as follows: 

$$p^*(\mathbf{z}) = p(\mathbf{z}|\mathbf{y}) \propto \prod_{s \in \mathcal{V}} \psi_s(z_s|y_s) \prod_{(s,t) \in \mathcal{E}} \psi_{s,t}(z_s, z_t) \tag{9.1}$$

where $\psi_{s,t}(z_s, z_t)$ are the pairwise clique potentials, one per edge, $\psi_s(z_s|y_s)$ are the local evidence 
potentials, one per node, $\mathcal{V}$ is the set of nodes, and $\mathcal{E}$ is the set of edges. (We will henceforth drop 
the conditioning on the observed values $\mathbf{y}$ for brevity.) 

Now suppose the corresponding graph structure is a tree, such as the one in Figure 9.1a. We can 
always convert this into a directed tree by picking an arbitrary node as the root, and then “picking 
the tree up by the root” and orienting all the edges away from the root. For example, if we pick node 
1 as the root we get Figure 9.1b. This corresponds to the following directed graphical model: 

$$p^*(\mathbf{z}) \propto p^*(z_1)\, p^*(z_2|z_1)\, p^*(z_3|z_2)\, p^*(z_4|z_2) \tag{9.2}$$

However, if we pick node 2 as the root, we get Figure 9.1c. This corresponds to the following directed 
graphical model: 

$$p^*(\mathbf{z}) \propto p^*(z_2)\, p^*(z_1|z_2)\, p^*(z_3|z_2)\, p^*(z_4|z_2) \tag{9.3}$$


Since these graphs express the same conditional independence properties, they represent the same 
family of probability distributions, and hence we are free to use any of these parameterizations. 

To make the model more symmetric, it is preferable to use an undirected tree. If we define the 
potentials as (possibly unnormalized) marginals (i.e., $\psi_s(z_s) \propto p^*(z_s)$ and $\psi_{s,t}(z_s, z_t) \propto p^*(z_s, z_t)$), 
then we can write 

$$p^*(\mathbf{z}) \propto \prod_{s \in \mathcal{V}} p^*(z_s) \prod_{(s,t) \in \mathcal{E}} \frac{p^*(z_s, z_t)}{p^*(z_s)\, p^*(z_t)} \tag{9.4}$$

For example, for Figure 9.1a we have 

$$p^*(z_1, z_2, z_3, z_4) \propto p^*(z_1)\, p^*(z_2)\, p^*(z_3)\, p^*(z_4) \frac{p^*(z_1, z_2)\, p^*(z_2, z_3)\, p^*(z_2, z_4)}{p^*(z_1) p^*(z_2)\; p^*(z_2) p^*(z_3)\; p^*(z_2) p^*(z_4)} \tag{9.5}$$
9.2. BELIEF PROPAGATION ON TREES 


// Collect to root
for each node s in post-order
    bel_s(z_s) ∝ ψ_s(z_s) ∏_{t ∈ ch_s} m_{t→s}(z_s)
    t = parent(s)
    m_{s→t}(z_t) = Σ_{z_s} ψ_{st}(z_s, z_t) bel_s(z_s)

// Distribute from root
for each node t in pre-order
    s = parent(t)
    m_{s→t}(z_t) = Σ_{z_s} ψ_{st}(z_s, z_t) bel_s(z_s) / m_{t→s}(z_s)
    bel_t(z_t) ∝ bel_t(z_t) m_{s→t}(z_t)


Figure 9.2: Belief propagation on a pairwise, rooted tree. 


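To see the equivalence with the directed representation, we can cancel terms to get 

$$p^*(z_1, z_2, z_3, z_4) \propto p^*(z_1, z_2) \frac{p^*(z_2, z_3)}{p^*(z_2)} \frac{p^*(z_2, z_4)}{p^*(z_2)} \tag{9.6}$$

$$= p^*(z_1)\, p^*(z_2|z_1)\, p^*(z_3|z_2)\, p^*(z_4|z_2) \tag{9.7}$$

$$= p^*(z_2)\, p^*(z_1|z_2)\, p^*(z_3|z_2)\, p^*(z_4|z_2) \tag{9.8}$$

where $p^*(z_t|z_s) = p^*(z_s, z_t)/p^*(z_s)$. 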
Thus a tree can be represented as either an undirected or directed graph. Both representations 
can be useful, as we will see. 


9.2.2 Sum-product algorithm 


In this section, we assume that our model is an undirected tree, as in Equation (9.1). However, we 
will pick an arbitrary node as a root, and orient all the edges downwards away from this root, so that 
each node (other than the root) has a unique parent. For a directed, rooted tree, we can compute various node orderings. 
In particular, in a pre-order, we traverse from the root to the left subtree and then to the right subtree, 
top to bottom. In a post-order, we traverse from the left subtree to the right subtree and then to 
the root, bottom to top. We will use both of these below. 

We now present the sum-product algorithm for trees. We first send messages from the leaves 
to the root. This is the generalization of the forwards pass from Section 8.2.2. Let $m_{s\to t}(z_t)$ denote 
the message from node s to node t. This summarizes the belief state about $z_t$ given all the evidence 
in the tree below the s-t edge. Consider a node s in the post-order. We update its belief state by 
combining the incoming messages from all its children with its own local evidence: 

$$\mathrm{bel}_s(z_s) \propto \psi_s(z_s) \prod_{t \in \mathrm{ch}_s} m_{t\to s}(z_s) \tag{9.9}$$

To compute the outgoing message that s should send to its parent t, we pass the local belief through 


Figure 9.3: Illustration of how the top-down message from s to t is computed during BP on a tree. The $u_i$ 
nodes are the other children of s, besides t. Square nodes represent clique potentials. 


the pairwise potential linking s and t, and then marginalize out s to get 


$$m_{s\to t}(z_t) = \sum_{z_s} \psi_{st}(z_s, z_t)\, \mathrm{bel}_s(z_s) \tag{9.10}$$


At the root of the tree, $\mathrm{bel}_t(z_t) = p(z_t|\mathbf{y})$ will have seen all the evidence. It can then send messages 
back down to the leaves. The message that s sends to its child t should be the product of its local evidence and 
all the messages that s received from its other neighbors u (its parent, if any, and its other children), passed through 
the pairwise potential, and then marginalized: 

$$m_{s\to t}(z_t) = \sum_{z_s} \psi_s(z_s)\, \psi_{st}(z_s, z_t) \prod_{u \in \mathrm{nbr}_s \setminus t} m_{u\to s}(z_s) \tag{9.11}$$


See Figure 9.3. Instead of multiplying all-but-one of the messages that s has received, we can multiply 
all of them and then divide out by the $t \to s$ message from child t. The advantage of this is that 
the product of all the messages has already been computed in Equation (9.9), so we don’t need to 
recompute that term. Thus we get 

$$m_{s\to t}(z_t) = \sum_{z_s} \psi_{st}(z_s, z_t) \frac{\mathrm{bel}_s(z_s)}{m_{t\to s}(z_s)} \tag{9.12}$$


We can think of $\mathrm{bel}_s(z_s)$ as the new updated posterior $p(z_s|\mathbf{y})$ given all the evidence, and $m_{t\to s}(z_s)$ 
as the prior predictive $p(z_s|\mathbf{y}_t^-)$, where $\mathbf{y}_t^-$ is all the evidence in the subtree rooted at t. Thus the 
ratio contains the new evidence that t did not already know about from its own subtree. We use this 
to update the belief state at node t to get: 

$$\mathrm{bel}_t(z_t) \propto \mathrm{bel}_t(z_t)\, m_{s\to t}(z_t) \tag{9.13}$$

(Note that Equation (9.10) is a special case of Equation (9.12) in which we do not divide out by $m_{t\to s}$, since in the 
upwards pass, s has not yet received any message from its parent t.) This is analogous to the backwards 


smoothing equation in Equation (8.24), with $\alpha_t(i)$ replaced by $\mathrm{bel}_t(z_t = i)$, $A(i, j)$ replaced by 
$\psi_{st}(z_s = j, z_t = i)$, $\gamma_{t+1}(j)$ replaced by $\mathrm{bel}_s(z_s = j)$, and $\alpha_{t+1|t}(j)$ replaced by $m_{t\to s}(z_s = j)$. 

See Figure 9.2 for the overall pseudocode. This can be generalized to directed trees with multiple 
root nodes (known as polytrees) as described in Supplementary Section 9.1.1. 
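The two passes of Figure 9.2 translate directly into NumPy for discrete nodes. This is a minimal sketch under our own conventions (node potentials as vectors, edge potentials as matrices indexed [z_s, z_t], the tree given as a parent array, and beliefs and messages normalized for numerical stability); all names are hypothetical. The usage code checks the beliefs against brute-force enumeration on the tree of Figure 9.1 (nodes relabeled 0 to 3), rooted first at node 1 and then at node 2 of the figure.

```python
import itertools
import numpy as np


def edge_pot(psi_edge, s, t):
    """Return the edge potential as a matrix indexed [z_s, z_t]."""
    return psi_edge[(s, t)] if (s, t) in psi_edge else psi_edge[(t, s)].T


def tree_bp(psi_node, psi_edge, parent):
    """Sum-product BP on a rooted tree (Figure 9.2); parent[root] = -1."""
    n = len(parent)
    children = [[c for c in range(n) if parent[c] == s] for s in range(n)]
    # Pre-order (parents before children); its reverse visits children first.
    order, stack = [], [parent.index(-1)]
    while stack:
        s = stack.pop()
        order.append(s)
        stack.extend(children[s])
    bel, msg = {}, {}
    # Collect to root: Equations (9.9) and (9.10).
    for s in reversed(order):
        b = psi_node[s].copy()
        for c in children[s]:
            b = b * msg[(c, s)]
        bel[s] = b / b.sum()
        t = parent[s]
        if t >= 0:
            m = edge_pot(psi_edge, s, t).T @ bel[s]  # sum over z_s
            msg[(s, t)] = m / m.sum()
    # Distribute from root: Equations (9.12) and (9.13).
    for t in order:
        s = parent[t]
        if s >= 0:
            m = edge_pot(psi_edge, s, t).T @ (bel[s] / msg[(t, s)])
            msg[(s, t)] = m / m.sum()
            b = bel[t] * msg[(s, t)]
            bel[t] = b / b.sum()
    return bel


# Usage: the tree of Figure 9.1 with edges 0-1, 1-2, 1-3 and random potentials.
rng = np.random.default_rng(0)
K = 3
psi_node = [rng.uniform(0.1, 1.0, K) for _ in range(4)]
psi_edge = {(0, 1): rng.uniform(0.1, 1.0, (K, K)),
            (1, 2): rng.uniform(0.1, 1.0, (K, K)),
            (1, 3): rng.uniform(0.1, 1.0, (K, K))}
bel = tree_bp(psi_node, psi_edge, [-1, 0, 1, 1])      # rooted at node 0
bel_alt = tree_bp(psi_node, psi_edge, [1, -1, 1, 1])  # rooted at node 1

# Brute-force marginals for comparison.
joint = np.zeros((K,) * 4)
for z in itertools.product(range(K), repeat=4):
    p = np.prod([psi_node[s][z[s]] for s in range(4)])
    for (s, t), pot in psi_edge.items():
        p *= pot[z[s], z[t]]
    joint[z] = p
joint /= joint.sum()
exact = [joint.sum(axis=tuple(ax for ax in range(4) if ax != s)) for s in range(4)]
```

Changing the root changes the message schedule but not the resulting marginals, in line with the discussion of directed vs undirected trees in Section 9.2.1.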


9.2.3 Max-product algorithm 


In Section 9.2.2 we described the sum-product algorithm, which computes the posterior marginals: 

$$\mathrm{bel}_i(k) = \gamma_i(k) = p(z_i = k|\mathbf{y}) = \sum_{\mathbf{z}_{-i}} p(z_i = k, \mathbf{z}_{-i}|\mathbf{y}) \tag{9.14}$$

We can replace the sum operation with the max operation to get max-product belief propagation. 
The result of this computation is a set of max marginals for each node: 

$$\zeta_i(k) = \max_{\mathbf{z}_{-i}} p(z_i = k, \mathbf{z}_{-i}|\mathbf{y}) \tag{9.15}$$

We can derive two different kinds of “MAP” estimates from these local quantities. The first is 
$\hat{z}_i = \arg\max_k \gamma_i(k)$; this is known as the maximizer of the posterior marginal or MPM estimate 
(see e.g., [MMP87; SM12]); let $\hat{\mathbf{z}} = [\hat{z}_1, \ldots, \hat{z}_{N_z}]$ be the sequence of such estimates. The second is 
$\check{z}_i = \arg\max_k \zeta_i(k)$; we call this the maximizer of the max marginal or MMM estimate; let 
$\check{\mathbf{z}} = [\check{z}_1, \ldots, \check{z}_{N_z}]$. 

An interesting question is: what, if anything, do these estimates have to do with the “true” MAP 
estimate, $\mathbf{z}^* = \arg\max_{\mathbf{z}} p(\mathbf{z}|\mathbf{y})$? We discuss this below. 


9.2.3.1 Connection between MMM and MAP 


In [YW04], they showed that, if the max marginals are unique and computed exactly (e.g., if the 
graph is a tree), then $\check{\mathbf{z}} = \mathbf{z}^*$. This means we can recover the global MAP estimate by running max-product 
BP and then setting each node to its local max (i.e., using the MMM estimate). 

However, if there are ties in the max marginals (corresponding to the case where there is more 
than one globally optimal solution), this “local stitching” process may result in global inconsistencies. 

If we have a tree-structured model, we can use a traceback procedure, analogous to the Viterbi 
algorithm (Section 8.2.7), in which we clamp nodes to their optimal values while working backwards 
from the root. For details, see e.g., [KF09a, p. 569]. 

Unfortunately, traceback does not work on general graphs. An alternative, iterative approach, 
proposed in [YW04], is as follows. First we run max-product BP, and clamp all nodes which have unique 
max marginals to their optimal values; we then clamp a single ambiguous node to an optimal value, 
and condition on all these clamped values as extra evidence, and perform more rounds of message 
passing, until all ties are broken. This may require many rounds of inference, although the number 
of non-clamped (hidden) variables gets reduced at each round. 


9.2.3.2 Connection between MPM and MAP 


In this section, we discuss the MPM estimate, $\hat{\mathbf{z}}$, which computes the maximum of the posterior 
marginals. In general, this does not correspond to the MAP estimate, even if the posterior marginals 
are exact. To see why, note that MPM just looks at the belief state for each node given all the visible 
evidence, but ignores any dependencies or constraints that might exist in the prior. 

To illustrate why this could be a problem, consider the error correcting code example from 
Section 5.5, where we defined $p(\mathbf{z}, \mathbf{y}) = p(z_1)p(z_2)p(z_3|z_1, z_2)\prod_{i=1}^{3} p(y_i|z_i)$, where all variables are 
binary. The priors $p(z_1)$ and $p(z_2)$ are uniform. The conditional term $p(z_3|z_1, z_2)$ is deterministic, 
and computes the parity of $(z_1, z_2)$. In particular, we have $p(z_3 = 1|z_1, z_2) = \mathbb{I}(\mathrm{odd}(z_1, z_2))$, so that 
the total number of 1s in the block $z_{1:3}$ is even. The likelihood terms $p(y_i|z_i)$ represent a bit flipping 
noisy channel model, with noise level $\alpha = 0.2$. 

Suppose we observe $\mathbf{y} = (1, 0, 0)$. In this case, the exact posterior marginals are as follows:¹ $\gamma_1 = 
[0.3469, 0.6531]$, $\gamma_2 = [0.6531, 0.3469]$, $\gamma_3 = [0.6531, 0.3469]$. The exact max marginals are all the same, 
namely $\zeta_i = [0.3265, 0.3265]$. Finally, the 3 global MAP estimates are $\mathbf{z}^* \in \{[0, 0, 0], [1, 1, 0], [1, 0, 1]\}$, 
each of which corresponds to a single bit flip from the observed vector. The MAP estimates are all 
valid code words (they have an even number of 1s), and hence are sensible hypotheses about the 
value of $\mathbf{z}$. By contrast, the MPM estimate is $\hat{\mathbf{z}} = [1, 0, 0]$, which is not a legal codeword. (And in 
this example, the MMM estimate is not well defined, since the max marginals are not unique.) 
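The numbers quoted above can be reproduced by brute-force enumeration over the 2³ joint states. This sketch is our own (the notebook cited in the footnote may be organized differently) and assumes the model exactly as stated: uniform priors on z_1 and z_2, deterministic parity for z_3, and independent bit flips with α = 0.2.

```python
import itertools
import numpy as np

alpha = 0.2        # bit-flip probability of the noisy channel
y = (1, 0, 0)      # observed bits

# Unnormalized posterior p(z) p(y | z) for every joint state z = (z1, z2, z3).
post = {}
for z in itertools.product([0, 1], repeat=3):
    prior = 0.25 if z[2] == (z[0] ^ z[1]) else 0.0  # p(z1) p(z2) p(z3 | z1, z2)
    lik = np.prod([1 - alpha if zi == yi else alpha for zi, yi in zip(z, y)])
    post[z] = prior * lik
total = sum(post.values())
post = {z: p / total for z, p in post.items()}

# Posterior marginals gamma_i(k) (sum out the rest) and max marginals zeta_i(k) (max out the rest).
gamma = np.array([[sum(p for z, p in post.items() if z[i] == k) for k in (0, 1)] for i in range(3)])
zeta = np.array([[max(p for z, p in post.items() if z[i] == k) for k in (0, 1)] for i in range(3)])

mpm = tuple(int(k) for k in gamma.argmax(axis=1))
best = max(post.values())
map_set = sorted(z for z, p in post.items() if np.isclose(p, best))
print(np.round(gamma, 4))  # rows [0.3469 0.6531], [0.6531 0.3469], [0.6531 0.3469]
print(np.round(zeta, 4))   # every entry 0.3265
print(mpm, map_set)        # (1, 0, 0) [(0, 0, 0), (1, 0, 1), (1, 1, 0)]
```

The three tied MAP codewords each have posterior probability 0.128/0.392 ≈ 0.3265, which is why every max marginal equals that value and the MMM estimate is not unique.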

So, which method is better? This depends on our loss function, as we discuss in Section 34.1. If 
we want to minimize the prediction error of each $z_i$, also called bit error, we should compute the 
MPM. If we want to minimize the prediction error for the entire sequence $\mathbf{z}$, also called word error, 
we should use MAP, since this can take global constraints into account. 

For example, suppose we are performing speech recognition and someone says “recognize speech”. 
MPM decoding may return “wreck a nice beach”, since locally it may be that “beach” is the most 
probable interpretation of “speech” when viewed in isolation (see Figure 34.1). However, MAP 
decoding would infer that “recognize speech” is the more likely overall interpretation, by taking into 
account the language model prior, p(z). 

On the other hand, if we don’t have strong constraints, the MPM estimate can be more robust 
[MMP87; SM12], since it marginalizes out the other nodes, rather than maxing them out. For 
example, in the casino HMM example in Figure 8.4, we see that the MPM method makes 49 bit 
errors (out of a total possible of T = 300), and the MAP path makes 60 errors. 


9.2.3.3 Connection between MPE and MAP 


In the graphical models literature, computing the jointly most likely setting of all the latent variables, 
$\mathbf{z}^* = \arg\max_{\mathbf{z}} p(\mathbf{z}|\mathbf{y})$, is known as the most probable explanation or MPE [Pea88]. In that 
literature, the term “MAP” is used to refer to the case where we maximize some of the hidden 
variables, and marginalize (sum out) the rest. For example, if we maximize a single node, $z_i$, but 
sum out all the others, $\mathbf{z}_{-i}$, we get the MPM $\hat{z}_i = \arg\max_{z_i} \sum_{\mathbf{z}_{-i}} p(\mathbf{z}|\mathbf{y})$. 

We can generalize the MPM estimate to compute the best guess for a set of query variables Q, 
given evidence on a set of visible variables V, marginalizing out the remaining variables R, to get 

$$\hat{\mathbf{z}}_Q = \arg\max_{\mathbf{z}_Q} \sum_{\mathbf{z}_R} p(\mathbf{z}_Q, \mathbf{z}_R|\mathbf{z}_V) \tag{9.16}$$

(Here $\mathbf{z}_R$ are called nuisance variables, since they are not of interest, and are not observed.) In 
[Pea88], this is called a MAP estimate, but we will call it an MPM estimate, to avoid confusion with 
the ML usage of the term “MAP” (where we maximize everything jointly). 


1. See error_correcting_code_demo.ipynb for the code. 
9.3 Loopy belief propagation 


In this section, we extend belief propagation to work on graphs with cycles or loops; this is called 
loopy belief propagation or LBP. Unfortunately, this method may not converge, and even if it 
does, it is not clear if the resulting estimates are valid. Indeed, Judea Pearl, who invented belief 
propagation for trees, wrote the following about loopy BP in 1988: 


When loops are present, the network is no longer singly connected and local propagation 
schemes will invariably run into trouble ... If we ignore the existence of loops and permit the 
nodes to continue communicating with each other as if the network were singly connected, 
messages may circulate indefinitely around the loops and the process may not converge to a 
stable equilibrium ... Such oscillations do not normally occur in probabilistic networks ... 
which tend to bring all messages to some stable equilibrium as time goes on. However, this 
asymptotic equilibrium is not coherent, in the sense that it does not represent the posterior 
probabilities of all nodes of the network — [Pea88, p.195] 


Despite these reservations, Pearl advocated the use of belief propagation in loopy networks as an 
approximation scheme (J. Pearl, personal communication). [MWJ99] found empirically that it works 
on various graphical models, and it is now used in many real world applications, some of which we 
discuss below. In addition, there is now some theory justifying its use in certain cases, as we discuss 
below. (For more details, see e.g., [Yed11].) 


9.3.1 Loopy BP for pairwise undirected graphs 


In this section, we assume (for notational simplicity) that our model is an undirected pairwise PGM, 
as in Equation (9.1). However, unlike Section 9.2.2, we do not assume the graph is a tree. We can 
apply the same message passing equations as before. However, since there is no natural node ordering, 
we will do this in a parallel, synchronous way. The basic idea is that all nodes receive messages from 
their neighbors in parallel, they then update their belief states, and finally they send new messages 
back out to their neighbors. This message passing process repeats until convergence. This kind of 
computing architecture is called a systolic array, due to its resemblance to a beating heart. 

More precisely, we initialize all messages to the all 1’s vector. Then, in parallel, each node absorbs 
messages from all its neighbors using 

$$\mathrm{bel}_s(z_s) \propto \psi_s(z_s) \prod_{t \in \mathrm{nbr}_s} m_{t\to s}(z_s) \tag{9.17}$$

Then, in parallel, each node sends messages to each of its neighbors: 

$$m_{s\to t}(z_t) = \sum_{z_s} \psi_s(z_s)\, \psi_{st}(z_s, z_t) \prod_{u \in \mathrm{nbr}_s \setminus t} m_{u\to s}(z_s) \tag{9.18}$$

The $m_{s\to t}$ message is computed by multiplying together all incoming messages, except the one sent by 
the recipient, and then passing through the $\psi_{st}$ potential. We continue this process until convergence. 
If the graph is a tree, the method is guaranteed to converge after $D(G)$ iterations, where $D(G)$ is the 
diameter of the graph, that is, the largest distance between any two nodes. 
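The synchronous updates in Equations (9.17) and (9.18) can be sketched in NumPy for a discrete pairwise MRF. The data layout, the stopping rule, the random potentials, and the function names are our own assumptions. On a chain (a tree) the result matches brute-force enumeration; after adding one edge to close a loop, LBP still converges on this small example, but its beliefs are generally only approximate.

```python
import itertools
import numpy as np


def loopy_bp(psi_node, psi_edge, max_iter=200, tol=1e-10):
    """Synchronous sum-product loopy BP on a discrete pairwise MRF.

    psi_node: list of (K,) local evidence vectors; psi_edge: dict (s, t) -> (K, K)
    matrix indexed [z_s, z_t]. Returns normalized beliefs and the iteration count.
    """
    n, K = len(psi_node), len(psi_node[0])
    pot = {}
    for (s, t), P in psi_edge.items():      # store both directions, indexed [z_s, z_t]
        pot[(s, t)], pot[(t, s)] = P, P.T
    nbrs = [[t for (u, t) in pot if u == s] for s in range(n)]
    msg = {e: np.ones(K) for e in pot}      # all-ones initialization
    for it in range(1, max_iter + 1):
        new = {}
        for (s, t) in pot:                  # Equation (9.18), all edges in parallel
            prod = psi_node[s].copy()
            for u in nbrs[s]:
                if u != t:                  # exclude the message from the recipient
                    prod = prod * msg[(u, s)]
            m = pot[(s, t)].T @ prod        # sum over z_s
            new[(s, t)] = m / m.sum()
        delta = max(np.abs(new[e] - msg[e]).max() for e in pot)
        msg = new
        if delta < tol:
            break
    bel = []
    for s in range(n):                      # Equation (9.17)
        b = psi_node[s].copy()
        for u in nbrs[s]:
            b = b * msg[(u, s)]
        bel.append(b / b.sum())
    return bel, it


def exact_marginals(psi_node, psi_edge):
    """Brute-force node marginals by enumerating all joint states."""
    n, K = len(psi_node), len(psi_node[0])
    joint = np.zeros((K,) * n)
    for z in itertools.product(range(K), repeat=n):
        p = np.prod([psi_node[s][z[s]] for s in range(n)])
        for (s, t), P in psi_edge.items():
            p *= P[z[s], z[t]]
        joint[z] = p
    joint /= joint.sum()
    return [joint.sum(axis=tuple(ax for ax in range(n) if ax != s)) for s in range(n)]


# Usage: a 4-node chain, and the same chain closed into a single loop.
rng = np.random.default_rng(0)
psi_node = [rng.uniform(0.5, 1.5, 2) for _ in range(4)]
chain = {(0, 1): rng.uniform(0.5, 1.5, (2, 2)),
         (1, 2): rng.uniform(0.5, 1.5, (2, 2)),
         (2, 3): rng.uniform(0.5, 1.5, (2, 2))}
cycle = dict(chain)
cycle[(3, 0)] = rng.uniform(0.5, 1.5, (2, 2))

bel_chain, it_chain = loopy_bp(psi_node, chain)
bel_cycle, it_cycle = loopy_bp(psi_node, cycle)
err_chain = max(np.abs(b - e).max() for b, e in zip(bel_chain, exact_marginals(psi_node, chain)))
err_cycle = max(np.abs(b - e).max() for b, e in zip(bel_cycle, exact_marginals(psi_node, cycle)))
print(it_chain, err_chain, it_cycle, err_cycle)
```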
Figure 9.4: Message passing on a bipartite factor graph. Square nodes represent factors, and circles represent 
variables. The $y_i$ nodes correspond to the neighbors $x'$ of $f$ other than $x$. From Figure 6 of [KFL01]. Used 
with kind permission of Brendan Frey. 


9.3.2 Loopy BP for factor graphs 


To implement loopy BP for general graphs, including those with higher-order clique potentials 
(beyond pairwise), it is useful to use a factor graph representation described in Section 4.6.1. In 
this section, we summarize the BP equations for the bipartite version of factor graphs, as derived in 
[KFL01].² For a version that works for Forney factor graphs, see [Loe+07]. 

In the case of bipartite factor graphs, we have two kinds of messages: variables to factors 

$$m_{x\to f}(x) = \prod_{h \in \mathrm{nbr}(x) \setminus \{f\}} m_{h\to x}(x) \tag{9.19}$$

and factors to variables 

$$m_{f\to x}(x) = \sum_{\mathbf{x}'} f(x, \mathbf{x}') \prod_{x' \in \mathrm{nbr}(f) \setminus \{x\}} m_{x'\to f}(x') \tag{9.20}$$

Here $\mathrm{nbr}(x)$ are all the factors that are connected to variable $x$, and $\mathrm{nbr}(f)$ are all the variables that 
are connected to factor $f$. These messages are illustrated in Figure 9.4. At convergence, we can 
compute the final beliefs as a product of incoming messages: 

$$\mathrm{bel}(x) \propto \prod_{f \in \mathrm{nbr}(x)} m_{f\to x}(x) \tag{9.21}$$


The order in which the messages are sent can be determined using various heuristics, such as 
computing a spanning tree, and picking an arbitrary node as root. Alternatively, the update 
ordering can be chosen adaptively using residual belief propagation [EMK06]. Or fully parallel, 
asynchronous implementations can be used. 
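Equations (9.19) to (9.21) can be implemented for discrete factors of any arity in a few dozen lines of NumPy. This sketch is our own and much simpler than the libraries referenced in footnote 2: each factor is a tuple of variable indices plus a table with one axis per variable, messages are updated in repeated sweeps (all variable-to-factor messages, then all factor-to-variable messages), and the function names are hypothetical. The test graph is tree-structured (it contains a three-way factor), so BP is exact there and can be checked against enumeration.

```python
import numpy as np


def factor_graph_bp(card, factors, n_iter=50):
    """Sum-product BP on a bipartite factor graph.

    card: list of variable cardinalities; factors: list of (vars, table), where
    table has one axis per variable in vars (in that order).
    """
    var_msg, fac_msg = {}, {}  # m_{x->f}(x) keyed (x, f); m_{f->x}(x) keyed (f, x)
    for f, (vs, _) in enumerate(factors):
        for x in vs:
            var_msg[(x, f)] = np.ones(card[x])
            fac_msg[(f, x)] = np.ones(card[x])
    nbr = [[f for f, (vs, _) in enumerate(factors) if x in vs] for x in range(len(card))]
    for _ in range(n_iter):
        # Variable-to-factor messages, Equation (9.19).
        for (x, f) in var_msg:
            m = np.ones(card[x])
            for h in nbr[x]:
                if h != f:
                    m = m * fac_msg[(h, x)]
            var_msg[(x, f)] = m / m.sum()
        # Factor-to-variable messages, Equation (9.20).
        for f, (vs, table) in enumerate(factors):
            for i, x in enumerate(vs):
                t = table
                for j, y in enumerate(vs):
                    if j != i:  # broadcast the message from y along axis j
                        shape = [1] * len(vs)
                        shape[j] = card[y]
                        t = t * var_msg[(y, f)].reshape(shape)
                m = t.sum(axis=tuple(j for j in range(len(vs)) if j != i))
                fac_msg[(f, x)] = m / m.sum()
    # Beliefs, Equation (9.21).
    bel = []
    for x in range(len(card)):
        b = np.ones(card[x])
        for f in nbr[x]:
            b = b * fac_msg[(f, x)]
        bel.append(b / b.sum())
    return bel


# Usage: a tree-structured factor graph with unary, three-way and pairwise factors.
rng = np.random.default_rng(0)
card = [2, 3, 2, 2]
factors = [((0,), rng.uniform(0.1, 1.0, (2,))),
           ((0, 1, 2), rng.uniform(0.1, 1.0, (2, 3, 2))),
           ((2, 3), rng.uniform(0.1, 1.0, (2, 2))),
           ((3,), rng.uniform(0.1, 1.0, (2,)))]
bel = factor_graph_bp(card, factors)

# Brute force: multiply all factors into the full joint table (each factor lists
# its variables in increasing order, so a reshape aligns the axes).
joint = np.ones(card)
for vs, table in factors:
    shape = [1] * len(card)
    for x in vs:
        shape[x] = card[x]
    joint = joint * table.reshape(shape)
joint /= joint.sum()
exact = [joint.sum(axis=tuple(a for a in range(len(card)) if a != x)) for x in range(len(card))]
```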


2. For an efficient JAX implementation of these equations for discrete factor graphs, see https://github.com/vicariousinc/PGMax. For the Gaussian case, see https://github.com/probml/pgm-jax. 
Figure 9.5: Interpolating noisy data using Gaussian belief propagation applied to a 1d MRF. Generated by 
gauss-bp-1d-line.ipynb. 


9.3.3 Gaussian belief propagation 


It is possible to generalize (loopy) belief propagation to the Gaussian case, by using the “calculus for 
linear Gaussian models” in Section 2.2.7 to compute the messages and beliefs. Note that computing 
the posterior mean in a linear-Gaussian system is equivalent to solving a linear system, so these 
methods are also useful for linear algebra. See e.g., [PL03; Bic09; Du+18] for details. 
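Since the posterior mean of a Gaussian MRF solves a linear system, Gaussian BP can be checked directly against `np.linalg.solve`. Below is a minimal sketch of synchronous Gaussian BP in information form for $p(\mathbf{z}) \propto \exp(-\tfrac{1}{2}\mathbf{z}^\top J \mathbf{z} + \mathbf{h}^\top \mathbf{z})$, where each message is a scalar precision and a scalar potential. The parameterization is the usual one, but the code, the names, and the two small example models (a chain and a diagonally dominant loop) are our own.

```python
import numpy as np


def gaussian_bp(J, h, n_iter=200):
    """Synchronous Gaussian BP for p(z) proportional to exp(-0.5 z^T J z + h^T z)."""
    n = len(h)
    edges = [(s, t) for s in range(n) for t in range(n) if s != t and J[s, t] != 0]
    Lam = {e: 0.0 for e in edges}  # message s -> t: precision term
    eta = {e: 0.0 for e in edges}  # message s -> t: potential (linear) term
    for _ in range(n_iter):
        new_Lam, new_eta = {}, {}
        for (s, t) in edges:
            # Node term plus all incoming messages except the one from t.
            A = J[s, s] + sum(Lam[(u, v)] for (u, v) in edges if v == s and u != t)
            b = h[s] + sum(eta[(u, v)] for (u, v) in edges if v == s and u != t)
            # Integrate z_s out of exp(-0.5 A z_s^2 + b z_s - J_st z_s z_t).
            new_Lam[(s, t)] = -J[s, t] ** 2 / A
            new_eta[(s, t)] = -J[s, t] * b / A
        Lam, eta = new_Lam, new_eta
    prec = np.array([J[s, s] + sum(Lam[(u, v)] for (u, v) in edges if v == s) for s in range(n)])
    mean = np.array([h[s] + sum(eta[(u, v)] for (u, v) in edges if v == s) for s in range(n)]) / prec
    return mean, 1.0 / prec


# Usage: a 5-node chain, and the same chain closed into a loop.
n = 5
J_chain = 2.0 * np.eye(n) - 0.8 * (np.eye(n, k=1) + np.eye(n, k=-1))
J_loop = J_chain.copy()
J_loop[0, n - 1] = J_loop[n - 1, 0] = -0.8  # still diagonally dominant: 2 > 0.8 + 0.8
h = np.arange(1.0, n + 1)

mean_c, var_c = gaussian_bp(J_chain, h)
mean_l, var_l = gaussian_bp(J_loop, h)
print(np.abs(mean_l - np.linalg.solve(J_loop, h)).max())
print(np.abs(var_l - np.diag(np.linalg.inv(J_loop))).max())
```

On the chain both means and variances match the exact answer; on the loop the means agree at convergence while the variances generally do not, consistent with the results for Gaussian models described in Section 9.3.5.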

As an example of Gaussian BP, consider the problem of interpolating noisy data in 1d, as discussed 
in [OED21]. In particular, let $f : \mathbb{R} \to \mathbb{R}$ be an unknown function for which we get N noisy 
measurements $y_i$ at locations $x_i$. We want to estimate $z_i = f(g_i)$ at G grid locations $g_i$. Let $x_i$ be 
the closest location to $g_i$. Then we assume the measurement factor is as follows: 

$$\phi_i(z_{i-1}, z_i) = \exp\left(-\frac{1}{2\sigma^2}(\hat{y}_i - y_i)^2\right) \tag{9.22}$$

$$\hat{y}_i = (1 - \gamma_i)\, z_{i-1} + \gamma_i\, z_i \tag{9.23}$$

$$\gamma_i = \frac{x_i - g_{i-1}}{g_i - g_{i-1}} \tag{9.24}$$

Here $\hat{y}_i$ is the predicted measurement. The potential function makes the unknown function values 
$z_{i-1}$ and $z_i$ move closer to the observation, based on how far these grid points are from where the 
measurement was taken. In addition, we add a pairwise smoothness potential, that encodes the prior 
that $z_i$ should be close to $z_{i-1}$ and $z_{i+1}$: 

$$\psi_i(z_{i-1}, z_i) = \exp\left(-\frac{1}{2\tau^2}\delta_i^2\right) \tag{9.25}$$

$$\delta_i = z_i - z_{i-1} \tag{9.26}$$

The overall model is 

$$p(\mathbf{z}|\mathbf{x}, \mathbf{y}, \mathbf{g}, \sigma^2, \tau^2) \propto \prod_{i=1}^{G} \phi_i(z_{i-1}, z_i)\, \psi_i(z_{i-1}, z_i) \tag{9.27}$$
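Because every factor in this model is Gaussian in z, the posterior can also be computed by assembling the precision matrix and solving one linear system; Gaussian BP on this chain returns the same marginals by message passing. The sketch assumes details that the text leaves open: each measurement is matched to the grid interval that contains it, the interpolation weight is the linear one within that interval, the measurement and smoothness potentials are Gaussian with scales σ and τ, and the sine-wave data and hyperparameter values are made up. The function name interpolate_1d is hypothetical.

```python
import numpy as np


def interpolate_1d(x, y, g, sigma, tau):
    """Exact Gaussian posterior over z_j = f(g_j) for the 1d interpolation model.

    Measurement factors use linear interpolation between the two grid points that
    bracket each x_i; smoothness factors penalize (z_j - z_{j-1})^2 / (2 tau^2).
    """
    G = len(g)
    J = np.zeros((G, G))  # posterior precision matrix
    h = np.zeros(G)       # information vector (precision times mean)
    for xi, yi in zip(x, y):
        j = int(np.clip(np.searchsorted(g, xi), 1, G - 1))  # x_i lies in [g[j-1], g[j]]
        gam = (xi - g[j - 1]) / (g[j] - g[j - 1])
        a = np.zeros(G)
        a[j - 1], a[j] = 1.0 - gam, gam                      # predicted measurement = a @ z
        J += np.outer(a, a) / sigma ** 2
        h += a * yi / sigma ** 2
    for j in range(1, G):
        d = np.zeros(G)
        d[j], d[j - 1] = 1.0, -1.0                           # delta_j = z_j - z_{j-1}
        J += np.outer(d, d) / tau ** 2
    # J is tridiagonal (a chain), so Gaussian BP on this model would be exact;
    # here we simply solve the linear system directly.
    return np.linalg.solve(J, h), np.diag(np.linalg.inv(J))


# Usage: noisy samples of a sine wave (made-up data and hyperparameters).
rng = np.random.default_rng(0)
x = np.linspace(0.05, 2 * np.pi - 0.05, 30)
y = np.sin(x) + rng.normal(0.0, 0.1, size=x.shape)
g = np.linspace(0.0, 2 * np.pi, 40)
mean, var = interpolate_1d(x, y, g, sigma=0.1, tau=0.2)
max_err = np.max(np.abs(mean - np.sin(g)))
print(max_err)
```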
Figure 9.6: (a) A simple loopy graph. (b) The computation tree, rooted at node 1, after 4 rounds of message 
passing. Nodes 2 and 3 occur more often in the tree because they have higher degree than nodes 1 and 4. 
From Figure 8.2 of [WJ08]. Used with kind permission of Martin Wainwright. 


Suppose the true underlying function is a sine wave. We show some sample data in Figure 9.5(a). 
We then apply Gaussian BP. Since this model is a chain, and the model is linear-Gaussian, the 
resulting posterior marginals, shown in Figure 9.5(b), are exact. We see that the method has inferred 
the underlying sine shape just based on a smoothness prior. 

To perform message passing in models with non-linear (but Gaussian) potentials, we can generalize 
the extended Kalman filter techniques from Section 8.5.2 and the moment matching techniques 
(based on quadrature / sigma points) from Section 8.7 and Section 8.7.4 from chains to general 
factor graphs (see e.g., [MHH14; PHR18; HPR19]). To extend to the non-Gaussian case, we can 
use non-parametric BP or particle BP (see e.g., [Sud+03; Isa03; Sud+10; Pac+14]), which uses 
ideas from particle filtering (Section 13.2). 


9.3.4 Convergence 


Loopy BP may not converge, or may only converge slowly. In this section, we discuss some techniques 
that increase the chances of convergence, and the speed of convergence. 


9.3.4.1 When will LBP converge? 


The details of the analysis of when LBP will converge are beyond the scope of this chapter, but 
we briefly sketch the basic idea. The key analysis tool is the computation tree, which visualizes 
the messages that are passed as the algorithm proceeds. Figure 9.6 gives a simple example. In the 
first iteration, node 1 receives messages from nodes 2 and 3. In the second iteration, it receives one 
message from node 3 (via node 2), one from node 2 (via node 3), and two messages from node 4 (via 
nodes 2 and 3). And so on. 

The key insight is that T iterations of LBP are equivalent to exact computation in a computation 
tree of height T + 1. If the strengths of the connections on the edges are sufficiently weak, then the 
influence of the leaves on the root will diminish over time, and convergence will occur. See [MK05; 
WJ08] and references therein for more information. 
Figure 9.7: Illustration of the behavior of loopy belief propagation on an 11 × 11 Ising grid with random 
potentials, $w_{ij} \sim \mathrm{Unif}(-C, C)$, where C = 11. For larger C, inference becomes harder. (a) Percentage of 
messages that have converged vs time for 3 different update schedules: dotted = damped synchronous (few 
nodes converge), dashed = undamped asynchronous (half the nodes converge), solid = damped asynchronous 
(all nodes converge). (b-f) Marginal beliefs of certain nodes vs time. Solid straight line = truth, dashed = 
synchronous, solid = damped asynchronous. From Figure 11.C.1 of [KF09a]. Used with kind permission of 
Daphne Koller. 


9.3.4.2 Making LBP converge 


Although the theoretical convergence analysis is very interesting, in practice, when faced with a 
model where LBP is not converging, what should we do? 

One simple way to increase the chance of convergence is to use damping. That is, at iteration k, 
we use an update of the form 

$$m^k_{t\to s}(z_s) = \lambda\, \tilde{m}_{t\to s}(z_s) + (1 - \lambda)\, m^{k-1}_{t\to s}(z_s) \tag{9.28}$$

where $\tilde{m}_{t\to s}(z_s)$ is the standard undamped message, and $0 < \lambda \le 1$ is the damping factor. Clearly 
if $\lambda = 1$ this reduces to the standard scheme, but for $\lambda < 1$, this partial updating scheme can help 
improve convergence. Using a value such as $\lambda \approx 0.5$ is standard practice. The benefits of this 
approach are shown in Figure 9.7, where we see that damped updating results in convergence much 
more often than undamped updating (see [ZLG20] for some analysis of the benefits of damping). 
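Damping is a one-line change to the message update. The sketch below adds the damped update to synchronous loopy BP on a small binary 4-cycle; the model, the coupling values, and λ = 0.5 are illustrative assumptions, and the function names are hypothetical. Any fixed point of the damped iteration is also a fixed point of the undamped one, so when both versions converge they return the same beliefs, which is what the usage code checks; on harder models, such as those in Figure 9.7, the damped version is the one more likely to converge at all.

```python
import numpy as np


def lbp_damped(psi_node, psi_edge, lam=0.5, max_iter=1000, tol=1e-10):
    """Synchronous loopy BP on a discrete pairwise MRF with damped messages."""
    n, K = len(psi_node), len(psi_node[0])
    pot = {}
    for (s, t), P in psi_edge.items():
        pot[(s, t)], pot[(t, s)] = P, P.T
    nbrs = [[t for (u, t) in pot if u == s] for s in range(n)]
    msg = {e: np.ones(K) / K for e in pot}
    converged = False
    for _ in range(max_iter):
        new = {}
        for (s, t) in pot:
            prod = psi_node[s].copy()
            for u in nbrs[s]:
                if u != t:
                    prod = prod * msg[(u, s)]
            m = pot[(s, t)].T @ prod
            m = m / m.sum()                                  # undamped message
            new[(s, t)] = lam * m + (1 - lam) * msg[(s, t)]  # damped message
        delta = max(np.abs(new[e] - msg[e]).max() for e in pot)
        msg = new
        if delta < tol:
            converged = True
            break
    bel = []
    for s in range(n):
        b = psi_node[s].copy()
        for u in nbrs[s]:
            b = b * msg[(u, s)]
        bel.append(b / b.sum())
    return bel, converged


def ising_pot(w):
    """Pairwise potential exp(w * z_s * z_t) for spins z in {-1, +1} (index 0 is -1)."""
    return np.exp(w * np.array([[1.0, -1.0], [-1.0, 1.0]]))


# Usage: a 4-cycle with one antiferromagnetic edge (illustrative values only).
psi_node = [np.array([1.2, 0.8]), np.array([1.0, 1.0]),
            np.array([0.7, 1.3]), np.array([1.0, 1.0])]
psi_edge = {(0, 1): ising_pot(0.5), (1, 2): ising_pot(0.5),
            (2, 3): ising_pot(0.5), (3, 0): ising_pot(-0.5)}
bel_damped, ok_damped = lbp_damped(psi_node, psi_edge, lam=0.5)
bel_plain, ok_plain = lbp_damped(psi_node, psi_edge, lam=1.0)
print(ok_damped, ok_plain)
```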

It is possible to devise methods, known as double loop algorithms, which are guaranteed 
to converge to a local minimum of the same objective that LBP is minimizing [Yui01; WT01]. 
Unfortunately, these methods are rather slow and complicated, and the accuracy of the resulting 
marginals is usually not much greater than with standard LBP. (Indeed, oscillating marginals are 
sometimes a sign that the LBP approximation itself is a poor one.) Consequently, these techniques 
are not very widely used. 


9.3.4.3 Increasing the convergence rate with adaptive scheduling 


The standard approach when implementing LBP is to perform synchronous updates, where all 
nodes absorb messages in parallel, and then send out messages in parallel. That is, the new messages 
at iteration k + 1 are computed in parallel using 

$$\mathbf{m}^{k+1} = \left(f_1(\mathbf{m}^k), \ldots, f_E(\mathbf{m}^k)\right) \tag{9.29}$$

where E is the number of edges, and $f_i(\mathbf{m})$ is the function that computes the message for edge i 
given all the old messages. This is analogous to the Jacobi method for solving linear systems of 
equations. 

It is well known [Ber97] that the Gauss-Seidel method, which performs asynchronous updates 
in a fixed round-robin fashion, typically converges faster when solving linear systems of equations. We can 
apply the same idea to LBP, using updates of the form 

$$m_i^{k+1} = f_i\left(\{m_j^{k+1} : j < i\}, \{m_j^k : j > i\}\right) \tag{9.30}$$

where the message for edge i is computed using new messages (iteration k + 1) from edges earlier in 
the ordering, and using old messages (iteration k) from edges later in the ordering. 

This raises the question of what order to update the messages in. One simple idea is to use a fixed 
or random order. The benefits of this approach are shown in Figure 9.7, where we see that (damped) 
asynchronous updating results in convergence much more often than synchronous updating. 


However, we can do even better by using an adaptive ordering. The intuition is that we should focus our computational efforts on those variables that are most uncertain. [EMK06] proposed a technique known as residual belief propagation, in which messages are scheduled to be sent according to the norm of the difference from their previous value. That is, we define the residual of new message $m_{s \to t}$ at iteration $k$ to be

$$r(s,t,k) = \|\log m_{s \to t} - \log m^k_{s \to t}\|_\infty = \max_j \left|\log \frac{m_{s \to t}(j)}{m^k_{s \to t}(j)}\right| \tag{9.31}$$


We can store messages in a priority queue, and always send the one with the highest residual. When a message is sent from $s$ to $t$, all of the other messages that depend on $m_{s \to t}$ (i.e., messages of the form $m_{t \to u}$ where $u \in \mathrm{nbr}(t) \setminus s$) need to be recomputed; their residuals are recomputed, and they are added back to the queue. In [EMK06], they showed (experimentally) that this method converges more often, and much faster, than either synchronous updating or asynchronous updating with a fixed order.
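The scheduling loop above can be sketched for sum-product BP on a pairwise MRF with strictly positive potentials. This is a minimal illustration, not the [EMK06] implementation. The function `residual_bp` and its dictionary-based inputs are hypothetical. Python's `heapq` with lazy deletion stands in for a priority queue that supports key updates, and the residual is Equation (9.31).

```python
import heapq
import itertools
import numpy as np

def residual_bp(node_pot, edge_pot, tol=1e-8, max_updates=10_000):
    """Sum-product LBP with residual scheduling (a sketch).

    node_pot: dict node -> (K,) positive local potentials.
    edge_pot: dict (s, t) -> (K_s, K_t) positive potentials psi(x_s, x_t).
    Returns normalized beliefs and the number of messages actually sent.
    """
    nbrs = {s: [] for s in node_pot}
    pot = {}
    for (s, t), P in edge_pot.items():
        nbrs[s].append(t)
        nbrs[t].append(s)
        pot[(s, t)], pot[(t, s)] = P, P.T            # pot[(s, t)] is indexed [x_s, x_t]
    msgs = {(s, t): np.ones(len(node_pot[t])) / len(node_pot[t])
            for s in nbrs for t in nbrs[s]}

    def compute(s, t):
        """New message s -> t from local evidence and all messages into s except t -> s."""
        b = node_pot[s].copy()
        for u in nbrs[s]:
            if u != t:
                b = b * msgs[(u, s)]
        m = b @ pot[(s, t)]                          # sum over x_s
        return m / m.sum()

    heap, pending, version = [], {}, {e: 0 for e in msgs}
    tiebreak = itertools.count()                     # unique key, so arrays are never compared

    def schedule(s, t):
        new = compute(s, t)
        r = np.max(np.abs(np.log(new) - np.log(msgs[(s, t)])))   # residual, Eq. (9.31)
        pending[(s, t)] = new
        version[(s, t)] += 1
        heapq.heappush(heap, (-r, next(tiebreak), version[(s, t)], (s, t)))

    for (s, t) in msgs:
        schedule(s, t)

    n_sent = 0
    while heap and n_sent < max_updates:
        neg_r, _, ver, (s, t) = heapq.heappop(heap)
        if ver != version[(s, t)]:
            continue                                 # stale queue entry, skip it
        if -neg_r < tol:
            break                                    # largest residual is tiny: converged
        msgs[(s, t)] = pending[(s, t)]               # send the highest-residual message
        n_sent += 1
        for u in nbrs[t]:                            # messages t -> u (u != s) depend on it
            if u != s:
                schedule(t, u)

    beliefs = {}
    for s in nbrs:
        b = node_pot[s].copy()
        for u in nbrs[s]:
            b = b * msgs[(u, s)]
        beliefs[s] = b / b.sum()
    return beliefs, n_sent

# Usage on a 3-node chain, where BP fixed points are the exact marginals.
rng = np.random.default_rng(0)
node_pot = {s: rng.uniform(0.5, 2.0, size=2) for s in range(3)}
edge_pot = {(0, 1): rng.uniform(0.5, 2.0, size=(2, 2)),
            (1, 2): rng.uniform(0.5, 2.0, size=(2, 2))}
beliefs, n_sent = residual_bp(node_pot, edge_pot)

joint = np.einsum("a,b,c,ab,bc->abc", node_pot[0], node_pot[1], node_pot[2],
                  edge_pot[(0, 1)], edge_pot[(1, 2)])
joint /= joint.sum()
exact = [joint.sum(axis=(1, 2)), joint.sum(axis=(0, 2)), joint.sum(axis=(0, 1))]
```

Lazy deletion keeps the sketch short. A dedicated indexed priority queue would avoid leaving stale entries in the heap.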


A refinement of residual BP was presented in [SM07]. In this paper, they use an upper bound on the residual of a message instead of the actual residual. This means that messages are only computed if they are going to be sent; they are not just computed for the purposes of evaluating the residual. This was observed to be about five times faster than residual BP, although the quality of the final results is similar.
Figure 9.8: (a) Clusters superimposed on a 3x3 lattice graph. (b) Corresponding hyper-graph. Nodes represent clusters, and edges represent set containment. From Figure 4.5 of [WJ08]. Used with kind permission of Martin Wainwright.


9.3.5 Accuracy 


For a graph with a single loop, one can show that the max-product version of LBP will find the 
correct MAP estimate, if it converges [Wei00]. For more general graphs, one can bound the error in 
the approximate marginals computed by LBP, as shown in [WJW03; IFW05; Vin+10b]. 

Much stronger results are available in the case of Gaussian models. In particular, it can be shown that, if the method converges, the means are exact, although the variances are not (typically the beliefs are overconfident). See e.g., [WF01a; JMW06; Bic09; Du+18] for details.
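The Gaussian claim can be checked numerically. Below is a minimal sketch of synchronous Gaussian BP in information form on a 4-node cycle, a graph with a loop. The function name `gaussian_bp`, the coupling strength -0.2, and the vector `h` are illustrative assumptions. The couplings are weak and attractive, so the iterations converge. The converged means match $J^{-1}h$, while the BP variances come out smaller than the true variances.

```python
import numpy as np

def gaussian_bp(J, h, n_iters=200):
    """Synchronous Gaussian BP for p(x) ∝ exp(-0.5 x^T J x + h^T x) (a sketch).

    Edges are the nonzero off-diagonal entries of the precision matrix J.
    Message i -> j is stored as a precision P[i, j] and a linear term m[i, j].
    """
    N = len(h)
    nbrs = [[j for j in range(N) if j != i and J[i, j] != 0] for i in range(N)]
    P = np.zeros((N, N))
    m = np.zeros((N, N))
    for _ in range(n_iters):
        P_new, m_new = np.zeros_like(P), np.zeros_like(m)
        for i in range(N):
            for j in nbrs[i]:
                # Combine the local terms with all incoming messages except j -> i.
                P_ij = J[i, i] + sum(P[k, i] for k in nbrs[i] if k != j)
                h_ij = h[i] + sum(m[k, i] for k in nbrs[i] if k != j)
                P_new[i, j] = -J[j, i] * J[i, j] / P_ij
                m_new[i, j] = -J[j, i] * h_ij / P_ij
        P, m = P_new, m_new
    prec = np.array([J[i, i] + P[:, i].sum() for i in range(N)])
    mean = np.array([h[i] + m[:, i].sum() for i in range(N)]) / prec
    return mean, 1.0 / prec

# A 4-cycle with weak attractive couplings (hypothetical values).
N = 4
J = np.eye(N)
for i in range(N):
    J[i, (i + 1) % N] = J[(i + 1) % N, i] = -0.2
h = np.array([1.0, 0.5, -0.3, 0.2])

mean_bp, var_bp = gaussian_bp(J, h)
Sigma = np.linalg.inv(J)
mean_exact, var_exact = Sigma @ h, np.diag(Sigma)
```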


9.3.6 Generalized belief propagation 


We can improve the accuracy of loopy BP by clustering together nodes that form a tight loop. This is known as the cluster variational method, or generalized belief propagation [YFW00].

The result of clustering is a hyper-graph, which is a graph where there are hyper-edges between sets of vertices instead of between single vertices. Note that a junction tree (Section 9.5) is a kind of hyper-graph. We can represent a hyper-graph using a poset (partially ordered set) diagram, where each node represents a hyper-edge, and there is an arrow $e_1 \to e_2$ if $e_2 \subset e_1$. See Figure 9.8 for an example.

If we allow the size of the largest hyper-edge in the hyper-graph to be as large as the treewidth of the graph, then we can represent the hyper-graph as a tree, and the method will be exact, just as LBP is exact on ordinary trees (with treewidth 1). In this way, we can define a continuum of approximations, from LBP all the way to exact inference. See Supplementary Section 10.4.3.3 for more information.


9.3.7 Convex BP 


In Supplementary Section 10.4.3 we analyse LBP from a variational perspective, and show that the resulting optimization problem, for both standard and generalized BP, is non-convex. However, it is possible to create a convex version of BP, as we explain in Supplementary Section 10.4.4, which has the advantage that it will always converge.
Figure 9.9: (a) A simple factor graph representation of a (2,3) low-density parity check code. Each message bit (hollow round circle) is connected to two parity factors (solid black squares), and each parity factor is connected to three bits. Each parity factor has the form $\psi_{stu}(x_s, x_t, x_u) = \mathbb{I}(x_s \oplus x_t \oplus x_u = 1)$, where $\oplus$ is the xor operator. The local evidence factors for each hidden node are not shown. (b) A larger example of a random LDPC code. We see that this graph is "locally tree-like", meaning there are no short cycles; rather, each cycle has length $\sim \log m$, where $m$ is the number of nodes. This gives us a hint as to why loopy BP works so well on such graphs. (Note, however, that some error correcting code graphs have short loops, so this is not the full explanation.) From Figure 2.9 of [WJ08]. Used with kind permission of Martin Wainwright.


9.3.8 Application: error correcting codes 


LBP was first proposed by Judea Pearl in his 1988 book [Pea88]. He recognized that applying BP to 
loopy graphs might not work, but recommended it as a heuristic. 

However, the main impetus behind the interest in LBP arose when McEliece, MacKay, and Cheng 
[MMC98] showed that a popular algorithm for error correcting codes, known as turbocodes [BGT93], 
could be viewed as an instance of LBP applied to a certain kind of graph. 

We introduced error correcting codes in Section 5.5. Recall that the basic idea is to send the source message $x \in \{0,1\}^m$ over a noisy channel, and for the receiver to try to infer it given noisy measurements $y \in \{0,1\}^m$ or $y \in \mathbb{R}^m$. That is, the receiver needs to compute $x^* = \arg\max_x p(x|y) = \arg\max_x p(x)\, p(y|x)$.


It is standard to represent $p(x)$ as a factor graph (Section 4.6.1), which can easily represent any deterministic relationships (parity constraints) between the bits. A factor graph is a bipartite graph with $x_i$ nodes on one side, and factors on the other. A graph in which each node is connected to $n$ factors, and in which each factor is connected to $k$ nodes, is called an $(n,k)$ code. Figure 9.9(a) shows a simple example of a (2,3) code, where each bit (hollow round circle) is connected to two parity factors (solid black squares), and each parity factor is connected to three bits. Each parity factor has the form

$$\psi_{stu}(x_s, x_t, x_u) = \begin{cases} 1 & \text{if } x_s \oplus x_t \oplus x_u = 1 \\ 0 & \text{otherwise} \end{cases} \tag{9.32}$$


If the degrees of the parity checks and variable nodes remain bounded as the blocklength $m$ increases, this is called a low-density parity check code, or LDPC code. (Turbocodes are constructed in a similar way.)

Figure 9.10: Factor graphs for affinity propagation. Circles are variables, squares are factors. Each $c_i$ node has $N$ possible states. From Figure S2 of [FD07a]. Used with kind permission of Brendan Frey.

Figure 9.9(b) shows an example of a randomly constructed LDPC code. This graph is "locally tree-like", meaning there are no short cycles; rather, each cycle has length $\sim \log m$. This fact is important to the success of LBP, which is only guaranteed to work on tree-structured graphs. Using methods such as these, people have been able to approach the limit given by Shannon's channel coding theorem, meaning they have produced codes with very little redundancy for a given amount of noise in the channel. See e.g., [MMC98; Mac03] for more details. Such codes are widely used, e.g., in modern cellphones.
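To make LBP on such codes concrete, here is a minimal sketch of the sum-product message that a single parity factor from Equation (9.32) sends to one of its bits. The function name `factor_to_bit_message` is hypothetical. The incoming messages about the other two bits are assumed to be normalized length-2 vectors. The message is computed by brute-force summation over $(x_t, x_u)$, which is cheap for a degree-3 factor.

```python
import numpy as np

# Parity factor psi_stu from Eq. (9.32): 1 iff x_s xor x_t xor x_u == 1, else 0.
psi = np.zeros((2, 2, 2))
for xs in range(2):
    for xt in range(2):
        for xu in range(2):
            psi[xs, xt, xu] = float(xs ^ xt ^ xu == 1)

def factor_to_bit_message(psi, m_t, m_u):
    """Sum-product message psi -> x_s: sum over x_t, x_u of psi * m_t(x_t) * m_u(x_u)."""
    msg = np.einsum("stu,t,u->s", psi, m_t, m_u)
    return msg / msg.sum()

# If both other bits are believed to be 0 with probability 0.9 and 0.8, the
# constraint pushes x_s toward 1: p(x_s=1) ∝ 0.9*0.8 + 0.1*0.2 = 0.74.
print(factor_to_bit_message(psi, np.array([0.9, 0.1]), np.array([0.8, 0.2])))
```

In a full decoder, each bit multiplies such messages from all of its parity factors with its local channel evidence. This repeats until the beliefs stabilize.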


9.3.9 Application: Affinity propagation 


In this section, we discuss affinity propagation [FD07a], which can be seen as an improvement to K-medoids clustering. It takes as input a pairwise similarity matrix. The idea is that each data point must choose another data point as its exemplar or centroid; some data points will choose themselves as centroids, and this will automatically determine the number of clusters. More precisely, let $c_i \in \{1, \ldots, N\}$ represent the centroid for datapoint $i$. The goal is to maximize the following function

$$I(c) = \sum_{i=1}^N S(i, c_i) + \sum_{k=1}^N \delta_k(c) \tag{9.33}$$

where $S(i, c_i)$ is the similarity between data point $i$ and its centroid $c_i$. The second term is a penalty term that is $-\infty$ if some data point $i$ has chosen $k$ as its exemplar (i.e., $c_i = k$), but $k$ has not chosen itself as an exemplar (i.e., we do not have $c_k = k$). More formally,

$$\delta_k(c) = \begin{cases} -\infty & \text{if } c_k \neq k \text{ but } \exists i : c_i = k \\ 0 & \text{otherwise} \end{cases} \tag{9.34}$$
This encourages “representative” samples to vote for themselves as centroids, thus encouraging 
clustering behavior. 

The objective function can be represented as a factor graph. We can either use $N$ nodes, each with $N$ possible values, as shown in Figure 9.10, or we can use $N^2$ binary nodes (see [GF09] for the details). We will assume the former representation.
Figure 9.11: Example of affinity propagation. Each point is color-coded by how much it wants to be an exemplar (red is the most, green is the least). This can be computed by summing up all the incoming availability messages and the self-similarity term. The darkness of the $i \to k$ arrow reflects how much point $i$ wants to belong to exemplar $k$. From Figure 1 of [FD07a]. Used with kind permission of Brendan Frey.


We can find a strong local maximum of the objective by using max-product loopy belief propagation (Section 9.3). Referring to the model in Figure 9.10, each variable node $c_i$ sends a message to each factor node $\delta_k$. It turns out that this vector of $N$ numbers can be reduced to a scalar message, denoted $r_{i \to k}$, known as the responsibility. This is a measure of how much $i$ thinks $k$ would make a good exemplar, compared to all the other exemplars $i$ has looked at. In addition, each factor node $\delta_k$ sends a message to each variable node $c_i$. Again this can be reduced to a scalar message, $a_{i \leftarrow k}$, known as the availability. This is a measure of how strongly $k$ believes it should be an exemplar for $i$, based on all the other data points $k$ has looked at.
As usual with loopy BP, the method might oscillate, and convergence is not guaranteed. However, by using damping, the method is very reliable in practice. If the graph is densely connected, message passing takes $O(N^2)$ time, but with sparse similarity matrices, it only takes $O(E)$ time, where $E$ is the number of edges or non-zero entries in $S$.


The number of clusters can be controlled by scaling the diagonal terms $S(i,i)$, which reflect how much each data point wants to be an exemplar. Figure 9.11 gives a simple example of some 2d data, where the negative Euclidean distance was used to measure similarity. The $S(i,i)$ values were set to be the median of all the pairwise similarities. The result is 3 clusters. Many other results are reported in [FD07a], who show that the method significantly outperforms K-medoids.
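The responsibility and availability messages have simple closed forms, given in [FD07a]. The sketch below implements them with damping. The function name, the toy 2d data, the damping value 0.5, and the fixed iteration count are assumptions for illustration. As in the Figure 9.11 example, it uses negative Euclidean distance as similarity and sets the preferences $S(k,k)$ to the median pairwise similarity. Exemplars are the points $k$ with $a(k,k) + r(k,k) > 0$.

```python
import numpy as np

def affinity_propagation(S, damping=0.5, n_iters=300):
    """Damped responsibility/availability updates (a sketch).

    S: (N, N) similarities whose diagonal holds the preferences S(k, k).
    Returns the exemplar indices and the exemplar assigned to each point.
    """
    N = S.shape[0]
    R = np.zeros((N, N))    # R[i, k] holds the responsibility r_{i -> k}
    A = np.zeros((N, N))    # A[i, k] holds the availability a_{i <- k}
    rows = np.arange(N)
    for _ in range(n_iters):
        # r(i,k) = s(i,k) - max_{k' != k} [a(i,k') + s(i,k')]
        AS = A + S
        best = np.argmax(AS, axis=1)
        first = AS[rows, best]
        AS[rows, best] = -np.inf
        second = AS.max(axis=1)
        max_other = np.repeat(first[:, None], N, axis=1)
        max_other[rows, best] = second               # for k = best, use the runner-up
        R = damping * R + (1 - damping) * (S - max_other)
        # a(i,k) = min(0, r(k,k) + sum_{i' not in {i,k}} max(0, r(i',k))) for i != k
        # a(k,k) = sum_{i' != k} max(0, r(i',k))
        Rp = np.maximum(R, 0)
        np.fill_diagonal(Rp, np.diag(R))             # keep r(k,k) itself unclipped
        A_new = Rp.sum(axis=0)[None, :] - Rp         # drop the i' = i term from each column sum
        self_avail = np.diag(A_new).copy()
        A_new = np.minimum(A_new, 0)
        np.fill_diagonal(A_new, self_avail)
        A = damping * A + (1 - damping) * A_new
    exemplars = np.flatnonzero(np.diag(A + R) > 0)
    if exemplars.size == 0:
        return exemplars, np.full(N, -1)
    labels = exemplars[np.argmax(S[:, exemplars], axis=1)]
    labels[exemplars] = exemplars
    return exemplars, labels

# Three tight, well-separated groups of points (hypothetical data).
X = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1],
              [5.0, 5.0], [5.1, 5.0], [5.0, 5.1],
              [10.0, 0.0], [10.1, 0.0], [10.0, 0.1]])
S = -np.linalg.norm(X[:, None, :] - X[None, :, :], axis=-1)   # negative Euclidean distance
off_diag = S[~np.eye(len(X), dtype=bool)]
np.fill_diagonal(S, np.median(off_diag))                       # preferences = median similarity
exemplars, labels = affinity_propagation(S)
```

Raising the preferences toward 0 tends to produce more exemplars. Lowering them tends to produce fewer.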


9.3.10 Emulating BP with graph neural nets 


There is a close connection between message passing in PGMs and message passing in graph neural networks (GNNs), which we discuss in Section 16.3.6. However, for PGMs, the messages are computed using (non-learned) update equations that work for any model; all that is needed is the graph structure $G$, model parameters $\theta$, and evidence $v$. By contrast, GNNs are trained to emulate specific functions using labeled input-output pairs.

It is natural to wonder what happens if we train a GNN on the exact posterior marginals derived 
from a small PGM, and then apply that trained GNN to a different test PGM. In [Yoo+18; Zha+19d], 
they show this method can work quite well if the test PGM is similar in structure to the one used for 
training. 

An alternative approach is to start with a known PGM, and then “unroll” the BP message passing 
algorithm to produce a layered feedforward model, whose connectivity is derived from the graph. The 
resulting network can then be trained discriminatively for some end-task (not necessarily computing 
posterior marginals). Thus the BP procedure applied to the PGM just provides a way to design the 
neural network structure. This method is called deep unfolding (see e.g., [HLRW14]), and can 
often give very good results. (See also [SW20] for a more recent version of this approach, called 
“neural enhanced BP”.) 

These neural methods are useful if the PGM is fixed, and we want to repeatedly perform inference 
or prediction with it, using different values of the evidence, but where the set of nodes which are 
observed is always the same. This is an example of amortized inference, where we train a model 
to emulate the results of running an iterative optimization scheme (see Section 10.3.6 for more 
discussion). 


9.4 The variable elimination (VE) algorithm 


In this section, we discuss an algorithm to compute a posterior marginal $p(z_Q|y)$ for any query set $Q$, assuming $p$ is defined by a graphical model. Unlike loopy BP, it is guaranteed to give the correct answers even if the graph has cycles. We assume all the hidden nodes are discrete, although a version of the algorithm can be created for the Gaussian case by using the rules for sum and product defined in Section 2.2.7.


9.4.1 Derivation of the algorithm 


We will explain the algorithm by applying it to an example. Specifically, we consider the student network from Section 4.2.2.2. Suppose we want to compute $p(J = 1)$, the marginal probability that a person will get a job. Since we have 8 binary variables, we could simply enumerate over all possible assignments to all the variables (except for $J$), adding up the probability of each joint instantiation:

$$p(J) = \sum_L \sum_S \sum_G \sum_H \sum_I \sum_D \sum_C p(C, D, I, G, S, L, J, H) \tag{9.35}$$


However, this would take $O(2^7)$ time. We can be smarter by pushing sums inside products. This is the key idea behind the variable elimination algorithm [ZP96], also called bucket elimination [Dec96], or, in the context of genetic pedigree trees, the peeling algorithm [CTS78].
Figure 9.12: Example of the elimination process, in the order $C, D, I, H, G, S, L$. When we eliminate $I$ (figure c), we add a fill-in edge between $G$ and $S$, since they are not connected. Adapted from Figure 9.10 of [KF09a].


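In our example, we get

$$\begin{aligned}
p(J) &= \sum_{L,S,G,H,I,D,C} p(C, D, I, G, S, L, J, H) \\
&= \sum_{L,S,G,H,I,D,C} \psi_C(C)\psi_D(D,C)\psi_I(I)\psi_G(G,I,D)\psi_S(S,I)\psi_L(L,G) \\
&\qquad \times \psi_J(J,L,S)\psi_H(H,G,J) \\
&= \sum_L \sum_S \psi_J(J,L,S) \sum_G \psi_L(L,G) \sum_H \psi_H(H,G,J) \sum_I \psi_S(S,I)\psi_I(I) \\
&\qquad \times \sum_D \psi_G(G,I,D) \sum_C \psi_C(C)\psi_D(D,C)
\end{aligned}$$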


We now evaluate this expression, working right to left as shown in Table 9.1. First we multiply together all the terms in the scope of the $\sum_C$ operator to create the temporary factor

$$\tau_1'(C, D) = \psi_C(C)\psi_D(D, C) \tag{9.36}$$

Then we marginalize out $C$ to get the new factor

$$\tau_1(D) = \sum_C \tau_1'(C, D) \tag{9.37}$$

Next we multiply together all the terms in the scope of the $\sum_D$ operator and then marginalize out $D$ to create

$$\tau_2'(G, I, D) = \psi_G(G, I, D)\tau_1(D) \tag{9.38}$$

$$\tau_2(G, I) = \sum_D \tau_2'(G, I, D) \tag{9.39}$$

And so on.
$$\begin{aligned}
&\sum_L \sum_S \psi_J(J,L,S) \sum_G \psi_L(L,G) \sum_H \psi_H(H,G,J) \sum_I \psi_S(S,I)\psi_I(I) \sum_D \psi_G(G,I,D) \underbrace{\sum_C \psi_C(C)\psi_D(D,C)}_{\tau_1(D)} \\
&\sum_L \sum_S \psi_J(J,L,S) \sum_G \psi_L(L,G) \sum_H \psi_H(H,G,J) \sum_I \psi_S(S,I)\psi_I(I) \underbrace{\sum_D \psi_G(G,I,D)\tau_1(D)}_{\tau_2(G,I)} \\
&\sum_L \sum_S \psi_J(J,L,S) \sum_G \psi_L(L,G) \sum_H \psi_H(H,G,J) \underbrace{\sum_I \psi_S(S,I)\psi_I(I)\tau_2(G,I)}_{\tau_3(G,S)} \\
&\sum_L \sum_S \psi_J(J,L,S) \sum_G \psi_L(L,G) \underbrace{\sum_H \psi_H(H,G,J)}_{\tau_4(G,J)} \tau_3(G,S) \\
&\sum_L \sum_S \psi_J(J,L,S) \underbrace{\sum_G \psi_L(L,G)\tau_4(G,J)\tau_3(G,S)}_{\tau_5(J,L,S)} \\
&\sum_L \underbrace{\sum_S \psi_J(J,L,S)\tau_5(J,L,S)}_{\tau_6(J,L)} \\
&\underbrace{\sum_L \tau_6(J,L)}_{\tau_7(J)}
\end{aligned}$$

Table 9.1: Eliminating variables from Figure 4.38 in the order $C, D, I, H, G, S, L$ to compute $p(J)$.
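The elimination steps in Table 9.1 map directly onto a sequence of small tensor contractions. In the sketch below, the factor tables are random positive numbers with the scopes used in the text, standing in for the real CPDs (an assumption for illustration). Each `einsum` call produces one $\tau$ factor. The final result matches a brute-force sum over the full product, which is Equation (9.35).

```python
import numpy as np

rng = np.random.default_rng(0)
K = 2   # all variables binary, as in the student network
# Hypothetical random positive factors with the scopes used in the text.
psi_C = rng.uniform(size=K)             # psi_C[C]
psi_D = rng.uniform(size=(K, K))        # psi_D[D, C]
psi_I = rng.uniform(size=K)             # psi_I[I]
psi_G = rng.uniform(size=(K, K, K))     # psi_G[G, I, D]
psi_S = rng.uniform(size=(K, K))        # psi_S[S, I]
psi_L = rng.uniform(size=(K, K))        # psi_L[L, G]
psi_J = rng.uniform(size=(K, K, K))     # psi_J[J, L, S]
psi_H = rng.uniform(size=(K, K, K))     # psi_H[H, G, J]

# Brute force: sum the full product over every variable except J.
brute = np.einsum("c,dc,i,gid,si,lg,jls,hgj->j",
                  psi_C, psi_D, psi_I, psi_G, psi_S, psi_L, psi_J, psi_H)

# Variable elimination in the order C, D, I, H, G, S, L (Table 9.1).
tau1 = np.einsum("c,dc->d", psi_C, psi_D)              # tau_1(D)
tau2 = np.einsum("gid,d->gi", psi_G, tau1)             # tau_2(G, I)
tau3 = np.einsum("si,i,gi->gs", psi_S, psi_I, tau2)    # tau_3(G, S)
tau4 = np.einsum("hgj->gj", psi_H)                     # tau_4(G, J)
tau5 = np.einsum("lg,gj,gs->jls", psi_L, tau4, tau3)   # tau_5(J, L, S)
tau6 = np.einsum("jls,jls->jl", psi_J, tau5)           # tau_6(J, L)
tau7 = np.einsum("jl->j", tau6)                        # tau_7(J)

p_J = tau7 / tau7.sum()   # normalize, since these toy factors are not true CPDs
```

With genuine CPDs the normalization step is a no-op, because the factors already define $p(J)$.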


The above technique can be used to compute any marginal of interest, such as p(J) or p(J, H). To 
compute a conditional, we can take a ratio of two marginals, where the visible variables have been 
clamped to their known values (and hence don’t need to be summed over). For example, 


$$p(J = j \mid I = 1, H = 0) = \frac{p(J = j, I = 1, H = 0)}{\sum_{j'} p(J = j', I = 1, H = 0)} \tag{9.40}$$
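Equation (9.40) can be illustrated with a tiny chain $J - L - I - H$, using hypothetical random factors. Clamping the evidence simply means slicing the factor tables at the observed values instead of summing over them. Only the remaining hidden variable $L$ is summed out. The unknown normalizer cancels in the ratio, as the comparison with the explicitly normalized joint shows.

```python
import numpy as np

rng = np.random.default_rng(1)
# Hypothetical unnormalized factors for a chain J - L - I - H (all binary).
phi_JL = rng.uniform(size=(2, 2))   # phi_JL[J, L]
phi_LI = rng.uniform(size=(2, 2))   # phi_LI[L, I]
phi_IH = rng.uniform(size=(2, 2))   # phi_IH[I, H]

# Clamp I = 1 and H = 0 by slicing the factors; only L is summed out.
numer = np.einsum("jl,l->j", phi_JL, phi_LI[:, 1]) * phi_IH[1, 0]   # ∝ p(J=j, I=1, H=0)
cond = numer / numer.sum()       # Eq. (9.40); the partition function cancels

# Check against conditioning the explicitly normalized joint table.
joint = np.einsum("jl,li,ih->jlih", phi_JL, phi_LI, phi_IH)
joint /= joint.sum()
marg = joint[:, :, 1, 0].sum(axis=1)   # p(J=j, I=1, H=0)
cond_check = marg / marg.sum()
```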


9.4.2 Computational complexity of VE 


The running time of VE is clearly exponential in the size of the largest factor, since we have to sum 
over all of the corresponding variables. Some of the factors come from the original model (and are 
thus unavoidable), but new factors may also be created in the process of summing out. For example, 
in Table 9.1, we created a factor involving G, I and S; but these nodes were not originally present 
together in any factor. 

The order in which we perform the summation is known as the elimination order. This can have a large impact on the size of the intermediate factors that are created. For example, consider the ordering in Table 9.1: the largest created factor (beyond the original ones in the model) has size 3, corresponding to $\tau_5(J, L, S)$. Now consider the ordering in Table 9.2: now the largest factors are $\tau_1(I, D, L, J, H)$ and $\tau_2(D, L, S, J, H)$, which are much bigger.
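$$\begin{aligned}
&\sum_D \sum_C \psi_C(C)\psi_D(D,C) \sum_H \sum_L \sum_S \psi_J(J,L,S) \sum_I \psi_I(I)\psi_S(S,I) \underbrace{\sum_G \psi_G(G,I,D)\psi_L(L,G)\psi_H(H,G,J)}_{\tau_1(I,D,L,J,H)} \\
&\sum_D \sum_C \psi_C(C)\psi_D(D,C) \sum_H \sum_L \sum_S \psi_J(J,L,S) \underbrace{\sum_I \psi_I(I)\psi_S(S,I)\tau_1(I,D,L,J,H)}_{\tau_2(D,L,S,J,H)} \\
&\sum_D \sum_C \psi_C(C)\psi_D(D,C) \sum_H \sum_L \underbrace{\sum_S \psi_J(J,L,S)\tau_2(D,L,S,J,H)}_{\tau_3(D,L,J,H)} \\
&\sum_D \sum_C \psi_C(C)\psi_D(D,C) \sum_H \underbrace{\sum_L \tau_3(D,L,J,H)}_{\tau_4(D,J,H)} \\
&\sum_D \sum_C \psi_C(C)\psi_D(D,C) \underbrace{\sum_H \tau_4(D,J,H)}_{\tau_5(D,J)} \\
&\sum_D \underbrace{\sum_C \psi_C(C)\psi_D(D,C)\tau_5(D,J)}_{\tau_6(D,J)} \\
&\underbrace{\sum_D \tau_6(D,J)}_{\tau_7(J)}
\end{aligned}$$

Table 9.2: Eliminating variables from Figure 4.38 in the order $G, I, S, L, H, C, D$.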


We can determine the size of the largest factor graphically, without worrying about the actual numerical values of the factors, by running the VE algorithm "symbolically". When we eliminate a variable $z_i$, we connect together all variables that share a factor with $z_i$ (to reflect the new temporary factor $\tau_i'$). The edges created by this process are called fill-in edges. For example, Figure 9.12 shows the fill-in edges introduced when we eliminate in the $C, D, I, \ldots$ order. The first two steps do not introduce any fill-ins, but when we eliminate $I$, we connect $G$ and $S$, to capture the temporary factor

$$\tau_3'(G, S, I) = \psi_S(S, I)\psi_I(I)\tau_2(G, I) \tag{9.41}$$


Let $\mathcal{G}_\prec$ be the (undirected) graph induced by applying variable elimination to $\mathcal{G}$ using elimination ordering $\prec$. The temporary factors generated by VE correspond to maximal cliques in the graph $\mathcal{G}_\prec$. For example, with ordering $(C, D, I, H, G, S, L)$, the maximal cliques are as follows:

$$\{C, D\}, \{D, I, G\}, \{G, L, S, J\}, \{G, J, H\}, \{G, I, S\} \tag{9.42}$$

It is clear that the time complexity of VE is

$$\sum_{c \in \mathcal{C}(\mathcal{G}_\prec)} K^{|c|} \tag{9.43}$$
where $\mathcal{C}(\mathcal{G}_\prec)$ are the (maximal) cliques in graph $\mathcal{G}_\prec$, $|c|$ is the size of the clique $c$, and we assume for notational simplicity that all the variables have $K$ states each.

Let us define the induced width of a graph given elimination ordering $\prec$, denoted $w_\prec$, as the size of the largest factor (i.e., the largest clique in the induced graph) minus 1. Then it is easy to see that the complexity of VE with ordering $\prec$ is $O(K^{w_\prec + 1})$. The smallest possible induced width for a graph is known as its treewidth. Unfortunately, finding the corresponding optimal elimination order is an NP-complete problem [Yan81; ACP87]. See Section 9.4.3 for a discussion of some approximate methods for finding good elimination orders.
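Symbolic elimination is easy to sketch in code. The function below is a hypothetical helper, not from the source. It works on the moralized student network, with one edge for every pair of variables that share a factor. It records the scope of each temporary factor and the fill-in edges, then reports the maximal cliques, the induced width $w_\prec$, and the cost from Equation (9.43). $J$ is placed last in each ordering because it is the query variable.

```python
from itertools import combinations

def symbolic_elimination(edges, order, K=2):
    """Eliminate variables without numbers; return cliques, fill-ins, width, cost."""
    adj = {v: set() for v in order}
    for u, v in edges:
        adj[u].add(v)
        adj[v].add(u)
    cliques, fill_ins = [], []
    for v in order:
        nbrs = adj.pop(v)                             # remove v from the graph
        cliques.append(frozenset(nbrs | {v}))         # scope of the temporary factor
        for u in nbrs:
            adj[u].discard(v)
        for a, b in combinations(sorted(nbrs), 2):    # connect v's remaining neighbours
            if b not in adj[a]:
                adj[a].add(b)
                adj[b].add(a)
                fill_ins.append((a, b))
    maximal = {c for c in cliques if not any(c < d for d in cliques)}
    width = max(len(c) for c in maximal) - 1          # induced width w_<
    cost = sum(K ** len(c) for c in maximal)          # Eq. (9.43)
    return maximal, fill_ins, width, cost

# Moralized student network: one edge per pair of variables sharing a factor.
edges = [("C", "D"), ("D", "I"), ("D", "G"), ("I", "G"), ("I", "S"), ("G", "L"),
         ("L", "J"), ("L", "S"), ("J", "S"), ("H", "G"), ("H", "J"), ("G", "J")]
maximal, fill_ins, width, cost = symbolic_elimination(edges, list("CDIHGSLJ"))
_, _, width_bad, _ = symbolic_elimination(edges, list("GISLHCDJ"))   # Table 9.2 order
print(sorted("".join(sorted(c)) for c in maximal), fill_ins, width, width_bad)
```

The first ordering recovers the cliques of Equation (9.42), with the single fill-in edge $G$-$S$ and width 3. The Table 9.2 ordering has width 5.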


9.4.3 Picking a good elimination order 


Many algorithms take time (or space) which is exponential in the treewidth of the corresponding graph. For example, this applies to Cholesky decompositions of sparse matrices, as well as to einsum contractions (see https://github.com/dgasmith/opt_einsum). Hence we would like to find an elimination ordering that minimizes the width. We say that an ordering $\pi$ is a perfect elimination ordering if it does not introduce any fill-in edges. Every graph that is already triangulated (e.g., a tree) has a perfect elimination ordering. We call such graphs decomposable.

In general, we will need to add fill-in edges to ensure the resulting graph is decomposable. Different 
orderings can introduce different numbers of fill-in edges, which affects the width of the resulting 
chordal graph; for example, compare Table 9.1 to Table 9.2. 

Choosing an elimination ordering with minimal width is NP-complete [Yan81; ACP87]. It is common to use a greedy approximation known as the min-fill heuristic, which works as follows: eliminate any node which would not result in any fill-ins (i.e., all of whose uneliminated neighbors already form a clique); if there is no such node, eliminate the node which would result in the minimum number of fill-in edges. When nodes have different weights (e.g., representing different numbers of states), we can use the min-weight heuristic, where we try to minimize the weight of the created cliques at each step.
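A direct transcription of the min-fill heuristic follows, as a sketch. The function name is hypothetical, and ties are broken alphabetically, which is an arbitrary choice. Because nodes with zero fill-ins have the smallest score, the "eliminate any node with no fill-ins" rule falls out of taking the minimum. The returned width is the induced width of the ordering the heuristic finds.

```python
from itertools import combinations

def min_fill_order(edges, nodes):
    """Greedy min-fill elimination ordering; returns (order, induced width)."""
    adj = {v: set() for v in nodes}
    for u, v in edges:
        adj[u].add(v)
        adj[v].add(u)

    def n_fill(v):
        """Number of fill-in edges that eliminating v would add right now."""
        return sum(1 for a, b in combinations(adj[v], 2) if b not in adj[a])

    order, width = [], 0
    while adj:
        v = min(sorted(adj), key=n_fill)           # zero-fill nodes win automatically
        width = max(width, len(adj[v]))            # clique size minus 1
        for a, b in combinations(adj[v], 2):       # add the fill-in edges
            adj[a].add(b)
            adj[b].add(a)
        for u in adj.pop(v):
            adj[u].discard(v)
        order.append(v)
    return order, width

student = [("C", "D"), ("D", "I"), ("D", "G"), ("I", "G"), ("I", "S"), ("G", "L"),
           ("L", "J"), ("L", "S"), ("J", "S"), ("H", "G"), ("H", "J"), ("G", "J")]
order, width = min_fill_order(student, "CDGHIJLS")
cycle_order, cycle_width = min_fill_order([("A", "B"), ("B", "C"), ("C", "D"), ("D", "A")], "ABCD")
print(order, width, cycle_width)
```

On the student network, this greedy choice reaches width 3, the same as the ordering in Table 9.1. On a 4-cycle, it reaches width 2.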

Of course, many other methods are possible. See [Heg06] for a general survey. [Kja90; Kja92] compared simulated annealing with the above greedy method, and found that it sometimes works better (although it is much slower). [MJ97] approximate the discrete optimization problem by a continuous optimization problem. [BG96] present a randomized approximation algorithm. [Gil88] present the nested dissection order, which is always within $O(\log N)$ of optimal. [Ami01] discuss various constant-factor approximation algorithms. [Dav+04] present the AMD or approximate minimum degree ordering algorithm, which is implemented in Matlab.³ The METIS library can be used for finding elimination orderings for large graphs; this implements the nested dissection algorithm [GT86]. For a planar graph with $N$ nodes, the resulting treewidth will have the optimal size of $O(N^{1/2})$.


9.4.4 Computational complexity of exact inference 


We have seen that variable elimination takes $O(N K^{w+1})$ time to compute the marginals for a graph with $N$ nodes and treewidth $w$, where each variable has $K$ states. If the graph is densely connected, then $w = O(N)$, and so inference will take time exponential in $N$.


3. See the description of the symamd command at https://bit.ly/31N6E2b. ("sym" stands for symbolic, "amd" stands for approximate minimum degree.)
Figure 9.13: Encoding a 3-SAT problem on $n$ variables and $m$ clauses as a DGM. The $Q_s$ variables are binary random variables. The $C_j$ variables are deterministic functions of the $Q_s$'s, and compute the truth value of each clause. The $A_j$ nodes are a chain of AND gates, to ensure that the CPT for the final $x$ node has bounded size. The double rings denote nodes with deterministic CPDs. From Figure 9.1 of [KF09a]. Used with kind permission of Daphne Koller.


Of course, just because some particular algorithm is slow doesn’t mean that there isn’t some 
smarter algorithm out there. Unfortunately, this seems unlikely, since it is easy to show that exact 
inference for discrete graphical models is NP-hard [DL93]. The proof is a simple reduction from the 
satisfiability problem. In particular, note that we can encode any 3-SAT problem as a DPGM with 
deterministic links, as shown in Figure 9.13. We clamp the final node, x, to be on, and we arrange 
the CPTs so that p(x = 1) > 0 iff there is a satisfying assignment. Computing any posterior marginal 
requires evaluating the normalization constant, p(x = 1), so inference in this model implicitly solves 
the SAT problem. 

In fact, exact inference is #P-hard [Rot96], which is even harder than NP-hard. The intuitive reason for this is that to compute the normalizing constant, we have to count how many satisfying assignments there are. (By contrast, MAP estimation is provably easier for some model classes [GPS89], since, intuitively speaking, it only requires finding one satisfying assignment, not counting all of them.) Furthermore, even approximate inference is computationally hard in general [DL93; Rot96].
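The counting argument can be made concrete with a toy formula. The 3-SAT instance below is hypothetical. The sketch assumes the construction of Figure 9.13 with uniform priors $p(Q_s = 1) = 0.5$ and deterministic clause and AND nodes. Under those assumptions, $p(x = 1)$ is exactly the number of satisfying assignments divided by $2^n$. Computing the normalizing constant therefore counts solutions, not merely detects one.

```python
from itertools import product

# A hypothetical 3-SAT formula over Q0..Q3; each literal is (variable index, is_positive).
clauses = [[(0, True), (1, False), (2, True)],
           [(1, True), (2, False), (3, True)],
           [(0, False), (2, True), (3, False)]]

def clause_true(clause, q):
    """A clause holds if any of its literals agrees with the assignment q."""
    return any(q[i] == pos for i, pos in clause)

n = 4
n_sat = sum(all(clause_true(c, q) for c in clauses)
            for q in product([False, True], repeat=n))
p_x1 = n_sat / 2 ** n   # p(x = 1) > 0 iff the formula is satisfiable
print(n_sat, p_x1)
```

Each clause rules out 2 of the 16 assignments, and no assignment violates two clauses at once. This leaves 10 satisfying assignments, so $p(x = 1) = 10/16$.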

The above discussion was just concerned with inferring the states of discrete hidden variables. When we have continuous hidden variables, the problem can be even harder, since even a simple two-node graph, of the form $z \to y$, can be intractable to invert if the variables are high dimensional and do not have a conjugate relationship (Section 3.2). Inference in mixed discrete-continuous models can also be hard [LP01].

As a consequence of these hardness results, we often have to resort to approximate inference methods, such as variational inference (Chapter 10) and Monte Carlo inference (Chapter 11).


9.4.5 Drawbacks of VE


Consider using VE to compute all the marginals in a chain-structured graphical model, such as an HMM. We can easily compute the final marginal $p(z_T|y)$ by eliminating all the nodes $z_1$ to $z_{T-1}$ in order. This is equivalent to the forwards algorithm, and takes $O(K^2 T)$ time, as we discussed in Section 8.2.4. But now suppose we want to compute $p(z_{T-1}|y)$. We have to run VE again, at a cost of $O(K^2 T)$ time. So the total cost to compute all the marginals is $O(K^2 T^2)$. However, we know that we can solve this problem in $O(K^2 T)$ time using the forwards-backwards algorithm, as we discussed in Section 8.2.4. The difference is that FB caches the messages computed on the forwards pass, so it can reuse them later. (Caching previously computed results is the core idea behind dynamic programming.)

Figure 9.14: Sending multiple messages along a tree. (a) $z_1$ is root. (b) $z_2$ is root. (c) $z_4$ is root. (d) All of the messages needed to compute all singleton marginals. Adapted from Figure 4.8 of [Jor07].
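The caching argument can be seen in a short forwards-backwards sketch for an HMM. The function name and the random model sizes are assumptions for illustration. A single forwards pass and a single backwards pass, each costing $O(K^2 T)$, yield every smoothed marginal. The result is checked against brute-force enumeration of all $K^T$ state sequences, which corresponds to the naive approach of summing everything out separately for each marginal.

```python
import itertools
import numpy as np

def forwards_backwards(init, trans, lik):
    """All smoothed marginals p(z_t | y_{1:T}) in O(K^2 T) time (a sketch).

    init: (K,) prior p(z_1); trans[i, j] = p(z_{t+1} = j | z_t = i);
    lik: (T, K) local evidence p(y_t | z_t).
    """
    T, K = lik.shape
    alpha = np.zeros((T, K))
    beta = np.ones((T, K))
    alpha[0] = init * lik[0]
    alpha[0] /= alpha[0].sum()
    for t in range(1, T):                       # forwards pass: cache every alpha_t
        alpha[t] = (alpha[t - 1] @ trans) * lik[t]
        alpha[t] /= alpha[t].sum()
    for t in range(T - 2, -1, -1):              # backwards pass, reusing cached beta_{t+1}
        beta[t] = trans @ (lik[t + 1] * beta[t + 1])
        beta[t] /= beta[t].sum()
    post = alpha * beta
    return post / post.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
K, T = 3, 5
init = rng.dirichlet(np.ones(K))
trans = rng.dirichlet(np.ones(K), size=K)       # each row is a distribution
lik = rng.uniform(size=(T, K))
post = forwards_backwards(init, trans, lik)

# Brute force: enumerate all K^T paths, then marginalize each time step.
joint = np.zeros((K,) * T)
for path in itertools.product(range(K), repeat=T):
    p = init[path[0]] * lik[0, path[0]]
    for t in range(1, T):
        p *= trans[path[t - 1], path[t]] * lik[t, path[t]]
    joint[path] = p
joint /= joint.sum()
brute = np.stack([joint.sum(axis=tuple(a for a in range(T) if a != t)) for t in range(T)])
```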

The same problem arises when applying VE to trees. For example, consider the 4-node tree in Figure 9.14. We can compute $p(z_1|y)$ by eliminating $z_{2:4}$; this is equivalent to sending messages up to $z_1$ (the messages correspond to the $\tau$ factors created by VE). Similarly we can compute $p(z_2|y)$, $p(z_3|y)$ and then $p(z_4|y)$. We see that some of the messages used to compute the marginal on one node can be re-used to compute the marginals on the other nodes. By storing the messages for later re-use, we can compute all the marginals in $O(K^2 T)$ time, where $T$ is the number of nodes, as we show in Section 9.2.

The question is: how do we get these benefits of message passing on a tree when the graph is not a 
tree? We give the answer in Section 9.5. 


9.5 The junction tree algorithm (JTA) 


The junction tree algorithm or JTA is a generalization of variable elimination that lets us efficiently compute all the posterior marginals without repeating redundant work, by using dynamic programming, thus avoiding the problems mentioned in Section 9.4.5. The basic idea is to convert the graph into a special kind of tree, known as a junction tree (also called a join tree, or clique tree), and then to run belief propagation (message passing) on this tree. We can create the join tree by running variable elimination "symbolically", as discussed in Section 9.4.2, and adding the generated fill-in edges to the graph. The resulting chordal graph can then be converted to a tree, as explained in Supplementary Section 9.2.1. Once we have a tree, we can perform message passing on it, using a variant of the method in Section 9.2.2. See Supplementary Section 9.2.2 for details.


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 




9.6 Inference as optimization 


In this section, we discuss how to perform posterior inference by solving an optimization problem, 
which is often computationally simpler. See also Supplementary Section 9.3. 


9.6.1 Inference as backpropagation 


In this section, we discuss how to compute posterior marginals in a graphical model using automatic differentiation. For notational simplicity, we focus on undirected graphical models, where the joint can be represented as an exponential family (Section 2.3) as follows:

$$p(x) = \frac{1}{Z} \prod_c \psi_c(x_c) = \exp\left(\sum_c \eta_c^T T(x_c) - \log Z(\eta)\right) = \exp\left(\eta^T T(x) - A(\eta)\right) \tag{9.44}$$

where $\psi_c$ is the potential function for clique $c$, $\eta_c$ are the natural parameters for clique $c$, $T(x_c)$ are the corresponding sufficient statistics, and $A = \log Z$ is the log partition function.

We will consider pairwise models (with node and edge potentials), and discrete variables. The natural parameters are the node and edge log potentials, $\eta = (\{\eta_{s;j}\}, \{\eta_{s,t;j,k}\})$, and the sufficient statistics are node and edge indicator functions, $T(x) = (\{\mathbb{I}(x_s = j)\}, \{\mathbb{I}(x_s = j, x_t = k)\})$. (Note: we use $s, t \in \mathcal{V}$ to index nodes and $j, k \in \mathcal{X}$ to index states.)

The mean of the sufficient statistics are given by

$$\mu = \mathbb{E}[T(x)] = \left(\{p(x_s = j)\}_s, \{p(x_s = j, x_t = k)\}_{s,t}\right) = \left(\{\mu_s\}_s, \{\mu_{st}\}_{s,t}\right) \tag{9.45}$$


The key result, from Equation (2.219), is that $\mu = \nabla_\eta A(\eta)$. Thus as long as we have a function that computes $A(\eta) = \log Z(\eta)$, we can use automatic differentiation (Section 6.2) to compute gradients, and then we can extract the corresponding node marginals from the gradient vector. If we have evidence (known values) on some of the variables, we simply "clamp" the corresponding entries to 0 or 1 in the node potentials.
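The identity $\mu = \nabla_\eta A(\eta)$ can be sanity-checked without an autodiff library. The sketch below swaps in central finite differences for automatic differentiation, so it is slower and only approximate, but it needs nothing beyond numpy. It uses a small chain $x_1 - x_2 - x_3$ with hypothetical random log potentials and compares the numerical gradient of $\log Z$ with marginals read off the normalized joint table.

```python
import numpy as np

rng = np.random.default_rng(0)
K = 3
# Hypothetical log potentials eta = [eta_1, eta_2, eta_3, eta_12, eta_23] for x1 - x2 - x3.
eta = [rng.normal(size=K), rng.normal(size=K), rng.normal(size=K),
       rng.normal(size=(K, K)), rng.normal(size=(K, K))]

def log_partition(eta):
    """A(eta) = log Z(eta): sum of the product of exp(log potentials) over all states."""
    return np.log(np.einsum("a,b,c,ab,bc->", *[np.exp(e) for e in eta]))

def numerical_grad(f, eta, idx, h=1e-6):
    """Central finite-difference gradient of f with respect to the block eta[idx]."""
    g = np.zeros_like(eta[idx])
    for j in np.ndindex(g.shape):
        plus = [e.copy() for e in eta]
        minus = [e.copy() for e in eta]
        plus[idx][j] += h
        minus[idx][j] -= h
        g[j] = (f(plus) - f(minus)) / (2 * h)
    return g

# Brute-force marginals from the normalized joint table.
joint = np.einsum("a,b,c,ab,bc->abc", *[np.exp(e) for e in eta])
joint /= joint.sum()
mu_1 = joint.sum(axis=(1, 2))     # p(x1)
mu_12 = joint.sum(axis=2)         # p(x1, x2)

grad_1 = numerical_grad(log_partition, eta, 0)    # gradient wrt eta_1
grad_12 = numerical_grad(log_partition, eta, 3)   # gradient wrt eta_12
```

With an autodiff library, the finite-difference loop is replaced by a single gradient call, as in Listing 9.1.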

The observation that probabilistic inference can be performed using automatic differentiation has been discovered independently by several groups (e.g., [Dar03; PD03; Eis16; ASM17]). It also lends itself to the development of differentiable approximations to inference (see e.g., [MB18]).


9.6.1.1 Example: inference in a small model 


As a concrete example, consider a small chain structured model $x_1 - x_2 - x_3$, where each node has $K$ states. We can represent the node potentials as $K \times 1$ tensors (tables of numbers), and the edge potentials by $K \times K$ tensors. The partition function is given by

$$Z(\psi) = \sum_{x_1, x_2, x_3} \psi_1(x_1)\psi_2(x_2)\psi_3(x_3)\psi_{12}(x_1, x_2)\psi_{23}(x_2, x_3) \tag{9.46}$$

Let $\eta = \log(\psi)$ be the log potentials, and $A(\eta) = \log Z(\eta)$ be the log partition function. We can compute the single node marginals $\mu_s = p(x_s = 1{:}K)$ using $\mu_s = \nabla_{\eta_s} A(\eta)$, and the pairwise marginals $\mu_{st}(j,k) = p(x_s = j, x_t = k)$ using $\mu_{st} = \nabla_{\eta_{st}} A(\eta)$.
We can compute the partition function $Z$ efficiently using numpy's einsum function, which implements tensor contraction using Einstein summation notation. We label each dimension of the tensors by A, B and C, so einsum knows how to match things up. We then compute gradients using an auto-diff library.⁴ The result is that inference can be done in two lines of Python code, as shown in Listing 9.1:


Listing 9.1: Computing marginals from derivative of log partition function 
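import jax
import jax.numpy as jnp
from jax import grad

# Example log potentials [eta_1, eta_2, eta_3, eta_12, eta_23] for x1 - x2 - x3 (random, for illustration)
K = 3
keys = jax.random.split(jax.random.PRNGKey(0), 5)
logpots = [jax.random.normal(k, s) for k, s in zip(keys, [(K,), (K,), (K,), (K, K), (K, K)])]

logZ_fun = lambda logpots: jnp.log(jnp.einsum("A,B,C,AB,BC->",
                                              *[jnp.exp(lp) for lp in logpots]))
probs = grad(logZ_fun)(logpots)   # list of node and edge marginals, same shapes as logpots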

To perform conditional inference, such as $p(x_i = k \mid x_j = e)$, we multiply in one-hot indicator vectors to clamp $x_j$ to the value $e$ so that the unnormalized joint only assigns non-zero probability to state combinations that are valid. We then sum over all values of the unclamped variables to get the constrained partition function $Z_e$. The gradients will now give us the marginals conditioned on the evidence [Dar03].
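Clamping can be illustrated on the same kind of 3-node chain with hypothetical random potentials and an assumed observation $x_3 = 2$. Multiplying a one-hot vector into the node potential of $x_3$ zeroes out every inconsistent configuration. The derivative of $\log Z_e$ with respect to $\eta_1(a)$ equals the clamped sum with $x_1 = a$ held fixed, divided by $Z_e$. That is exactly the contraction computed below, written out by hand rather than obtained from an autodiff call.

```python
import numpy as np

rng = np.random.default_rng(1)
K, e = 3, 2          # clamp x3 to the hypothetical observed value e = 2
eta = [rng.normal(size=K), rng.normal(size=K), rng.normal(size=K),
       rng.normal(size=(K, K)), rng.normal(size=(K, K))]
pots = [np.exp(x) for x in eta]
pots[2] = pots[2] * np.eye(K)[e]      # multiply in a one-hot indicator for x3 = e

Z_e = np.einsum("a,b,c,ab,bc->", *pots)                  # constrained partition function
# d log Z_e / d eta_1(a): the clamped sum with x1 = a fixed, divided by Z_e.
p_x1_given_e = np.einsum("a,b,c,ab,bc->a", *pots) / Z_e

# Check against conditioning the unclamped joint table directly.
joint = np.einsum("a,b,c,ab,bc->abc", *[np.exp(x) for x in eta])
check = joint[:, :, e].sum(axis=1) / joint[:, :, e].sum()
```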


9.6.2 Perturb and MAP 


In this section, we discuss how to draw posterior samples from a graphical model by leveraging optimization as a subroutine. The basic idea is to make $S$ copies of the model, each of which has slightly perturbed versions of the parameters, $\theta_s = \theta + \epsilon_s$, and then to compute the MAP estimate, $\hat{x}_s = \arg\max_x p(x|y; \theta_s)$. For a suitably chosen noise distribution for $\epsilon_s$, this technique, known as perturb-and-MAP, can be shown to give exact posterior samples in some cases, such as the Gaussian case below [PY10; PY11; PY14].


9.6.2.1 Gaussian case 


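We first consider the case of a Gaussian MRF. Let $x \in \mathbb{R}^N$ be the vector of hidden states with prior

$$p(x) \propto \mathcal{N}(Gx \mid \mu_p, \Sigma_p) \propto \exp\left(-\frac{1}{2} x^T K_p x + h_p^T x\right) \tag{9.47}$$

where $G \in \mathbb{R}^{K \times N}$ is a matrix that represents prior dependencies (e.g., pairwise correlations), $K_p = G^T \Sigma_p^{-1} G$, and $h_p = G^T \Sigma_p^{-1} \mu_p$. Let $y \in \mathbb{R}^M$ be the measurements with likelihood

$$p(y|x) = \mathcal{N}(y \mid Hx + c, \Sigma_n) \propto \exp\left(-\frac{1}{2} x^T K_n x + h_n^T x - \frac{1}{2}(y - c)^T \Sigma_n^{-1} (y - c)\right) \tag{9.48}$$

where $H \in \mathbb{R}^{M \times N}$ represents dependencies between the hidden and visible variables, $K_n = H^T \Sigma_n^{-1} H$ and $h_n = H^T \Sigma_n^{-1}(y - c)$. The posterior is given by the following (cf. one step of the information filter in Section 8.3.4)

$$p(x|y) = \mathcal{N}(x \mid \mu, \Sigma) \tag{9.49}$$

$$\Sigma^{-1} = K = G^T \Sigma_p^{-1} G + H^T \Sigma_n^{-1} H \tag{9.50}$$

$$\mu = K^{-1}\left(G^T \Sigma_p^{-1} \mu_p + H^T \Sigma_n^{-1}(y - c)\right) \tag{9.51}$$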


4. See ugm_inf_autodiff.py for the full (JAX) code, and see https://github.com/srush/ProbTalk for a (PyTorch)
version by Sasha Rush.
where we have assumed $K = K_p + K_n$ is invertible (although the prior or likelihood on their own
may be singular).

The K rows of $G = [g_1^\mathsf{T}; \ldots; g_K^\mathsf{T}]$ and the M rows of $H = [h_1^\mathsf{T}; \ldots; h_M^\mathsf{T}]$ can be combined into the
$L = K + M$ rows of $F = [f_1^\mathsf{T}; \ldots; f_L^\mathsf{T}]$, which define the linear constraints of the system. If we assume that
$\Sigma_p$ and $\Sigma_n$ are diagonal, then the structure of the graphical model is uniquely determined by the
sparsity of F. The resulting posterior factorizes as a product of L Gaussian “experts”:


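$$p(x|y) \propto \prod_{l=1}^{L}\exp\left(-\frac{1}{2}x^\mathsf{T}K_lx + h_l^\mathsf{T}x\right) \propto \prod_{l=1}^{L}\mathcal{N}(f_l^\mathsf{T}x; \mu_l, \Sigma_l) \tag{9.52}$$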


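where $\Sigma_l$ equals $\Sigma_{p,l}$ for $l = 1:K$ and equals $\Sigma_{n,l'}$ for $l = K+1:L$, where $l' = l - K$. Similarly,
$\mu_l = \mu_{p,l}$ for $l = 1:K$ and $\mu_l = (y_{l'} - c_{l'})$ for $l = K+1:L$. The information-form parameters of each expert are
$K_l = f_l\Sigma_l^{-1}f_l^\mathsf{T}$ and $h_l = f_l\Sigma_l^{-1}\mu_l$.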

To apply perturb and MAP, we proceed as follows. First perturb the prior mean by sampling
$\tilde{\mu}_p \sim \mathcal{N}(\mu_p, \Sigma_p)$, and perturb the measurements by sampling $\tilde{y} \sim \mathcal{N}(y, \Sigma_n)$. (Note that this is equivalent
to first perturbing the linear term in each information form potential, using $\tilde{h}_l = h_l + f_l\Sigma_l^{-1/2}\epsilon_l$,
where $\epsilon_l \sim \mathcal{N}(0, 1)$.) Then compute the MAP estimate for x using the perturbed parameters:


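$$\hat{x} = K^{-1}G^\mathsf{T}\Sigma_p^{-1}\tilde{\mu}_p + K^{-1}H^\mathsf{T}\Sigma_n^{-1}(\tilde{y} - c) \tag{9.53}$$
$$= \underbrace{K^{-1}G^\mathsf{T}\Sigma_p^{-1}}_{A}(\mu_p + \epsilon_\mu) + \underbrace{K^{-1}H^\mathsf{T}\Sigma_n^{-1}}_{B}(y + \epsilon_y - c) \tag{9.54}$$
$$= \mu + A\epsilon_\mu + B\epsilon_y \tag{9.55}$$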


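where $\epsilon_\mu \sim \mathcal{N}(0, \Sigma_p)$ and $\epsilon_y \sim \mathcal{N}(0, \Sigma_n)$ are independent. We see that $\mathbb{E}[\hat{x}] = \mu$ and $\mathbb{E}[(\hat{x} - \mu)(\hat{x} - \mu)^\mathsf{T}] = A\Sigma_pA^\mathsf{T} + B\Sigma_nB^\mathsf{T} = K^{-1}(K_p + K_n)K^{-1} = K^{-1} = \Sigma$. Since $\hat{x}$ is an affine function of Gaussian noise, it is itself Gaussian, so matching these two moments means the method produces exact samples.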

This approach is very scalable, since computing the MAP estimate of a sparse GMRF (i.e., the posterior
mean) can be done efficiently using conjugate gradient solvers. Alternatively, we can use loopy belief
propagation (Section 9.3), which can often compute the exact posterior mean (see e.g., [WF01a;
JMW06; Bic09; Du+18]).
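To make Equations (9.53)-(9.55) concrete, here is a minimal NumPy sketch of Gaussian perturb-and-MAP. The model is hypothetical: G encodes first differences with an anchor on the first state, H is random, both covariances are diagonal, and a dense `np.linalg.solve` stands in for the conjugate gradient solver one would use for a large sparse GMRF. Each sample is the MAP estimate under one draw of perturbed prior mean and measurements; the empirical mean and covariance of many samples should match μ and $K^{-1}$.

```python
import numpy as np

rng = np.random.default_rng(0)
N, M = 4, 3                                # hypothetical sizes: hidden states, measurements

# Hypothetical model: G encodes first differences (plus an anchor on x_0), H is random
G = np.eye(N) - np.eye(N, k=-1)
mu_p, Sp = np.zeros(N), np.diag(np.full(N, 0.1))   # diagonal Sigma_p
H, c = rng.normal(size=(M, N)), np.zeros(M)
Sn = np.diag(np.full(M, 0.2))                       # diagonal Sigma_n
y = rng.normal(size=M)

Sp_inv, Sn_inv = np.linalg.inv(Sp), np.linalg.inv(Sn)
K = G.T @ Sp_inv @ G + H.T @ Sn_inv @ H                                  # Eq. (9.50)
mu = np.linalg.solve(K, G.T @ Sp_inv @ mu_p + H.T @ Sn_inv @ (y - c))    # Eq. (9.51)

def perturb_and_map(rng, S):
    """Draw S posterior samples: perturb mu_p and y, then solve for each MAP (Eq. 9.53)."""
    mu_p_tilde = mu_p + rng.standard_normal((S, N)) @ np.linalg.cholesky(Sp).T
    y_tilde = y + rng.standard_normal((S, M)) @ np.linalg.cholesky(Sn).T
    # Row s of rhs is (G^T Sp^-1 mu_p_tilde_s + H^T Sn^-1 (y_tilde_s - c))^T
    rhs = mu_p_tilde @ (Sp_inv @ G) + (y_tilde - c) @ (Sn_inv @ H)
    # A dense solve stands in for conjugate gradients on a large sparse GMRF
    return np.linalg.solve(K, rhs.T).T

samples = perturb_and_map(rng, S=100_000)
print(np.abs(samples.mean(axis=0) - mu).max(), np.abs(np.cov(samples.T) - np.linalg.inv(K)).max())
```

All S perturbed systems share the same matrix K and differ only in their right-hand sides, which is why a single factorization or preconditioner can be reused across samples.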


9.6.2.2 Discrete case


In [PY11; PY14] they extend perturb-and-MAP to the case of discrete graphical models. This
setup is more complicated, and requires the use of Gumbel noise, which can be sampled using
$\epsilon = -\log(-\log(u))$, where $u \sim \text{Unif}(0, 1)$. This noise should be added to all the potentials in the
model, but as a simple approximation, it can just be added to the unary terms, i.e., the local evidence
potentials. Let the score, or unnormalized log probability, of configuration x given inputs c be


$$S(x; c) = \log p(x|c) + \text{const} = \sum_i \log\phi_i(x_i) + \sum_{ij}\log\psi_{ij}(x_i, x_j) \tag{9.56}$$


where we have assumed a pairwise CRF for notational simplicity. If we perturb the local evidence
potentials by adding $\epsilon_{ik}$ to each entry $\log\phi_i(k)$, where k indexes the discrete latent states, we get
$\tilde{S}(x; c)$. We then compute a sample $\tilde{x}$ by solving $\tilde{x} = \operatorname{argmax}_x \tilde{S}(x; c)$. The advantage of this
approach is that it can leverage efficient MAP solvers for discrete models, such as those discussed in
Supplementary Section 9.3. This can in turn be used for parameter learning, and estimating the
partition function [HJ12; Erm+13].
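The following NumPy sketch illustrates the discrete recipe on a hypothetical 3-node binary chain. Gumbel noise $\epsilon = -\log(-\log u)$ is added only to the unary log-potentials, and the MAP configuration is found by brute-force enumeration, which stands in for the efficient MAP solvers mentioned above and is feasible only for tiny models. With the pairwise terms set to zero, the argmax decouples across nodes and each node reduces to the Gumbel-max trick, so the samples are exact and their frequencies should match the unary probabilities.

```python
import itertools
import numpy as np

def score(x, log_unary, log_pair):
    """Score S(x) of Eq. (9.56) for a chain-structured pairwise model x_0 - x_1 - ..."""
    s = sum(log_unary[i, xi] for i, xi in enumerate(x))
    s += sum(log_pair[x[i], x[i + 1]] for i in range(len(x) - 1))
    return s

def perturb_and_map(log_unary, log_pair, rng):
    """Add Gumbel noise to the unary log-potentials only, then return the MAP configuration.
    Brute-force enumeration stands in for an efficient MAP solver (tiny models only)."""
    n, k = log_unary.shape
    u = rng.uniform(size=(n, k))
    noisy_unary = log_unary - np.log(-np.log(u))        # eps = -log(-log(u))
    configs = itertools.product(range(k), repeat=n)
    return max(configs, key=lambda x: score(x, noisy_unary, log_pair))

rng = np.random.default_rng(0)
log_unary = np.log(np.array([[0.2, 0.8], [0.5, 0.5], [0.7, 0.3]]))   # hypothetical
log_pair = np.zeros((2, 2))   # no coupling: unary-only perturbation is then exact
samples = np.array([perturb_and_map(log_unary, log_pair, rng) for _ in range(20000)])
print(samples.mean(axis=0))   # empirical p(x_i = 1); compare with [0.8, 0.5, 0.3]
```

With a non-zero `log_pair`, the same function implements the unary-only approximation described above, so its samples are no longer guaranteed to be exact.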
10 Variational inference


10.1 Introduction 


In this chapter, we discuss variational inference, which reduces posterior inference to optimization.
Note that VI is a large topic; this chapter just gives a high level overview. For more details, see e.g.,
[Jor+98; JJ00; Jaa01; WJ08; Zha+19b; Bro18].


10.1.1 Variational free energy 


Consider a model with unknown variables z, known variables x, and fixed parameters θ. (If the
parameters are unknown, they can be added to z.) Since computing the true posterior $p_\theta(z|x)$ is
assumed intractable, we will use an approximation q(z), which we choose to minimize the following
loss:


$$q = \operatorname{argmin}_{q \in \mathcal{Q}} D_{\mathbb{KL}}(q(z) \,\|\, p_\theta(z|x)) \tag{10.1}$$


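Since we are minimizing over functions (namely distributions q), this is called a variational method.
In practice we pick a parametric family $\mathcal{Q}$, where we use ψ to represent the variational parameters. We compute the best variational parameters (for given x) as follows:

$$\psi^* = \operatorname{argmin}_\psi D_{\mathbb{KL}}(q(z|\psi) \,\|\, p_\theta(z|x)) \tag{10.2}$$
$$= \operatorname{argmin}_\psi \mathbb{E}_{q(z|\psi)}\left[\log q(z|\psi) - \log\left(\frac{p_\theta(x|z)p_\theta(z)}{p_\theta(x)}\right)\right] \tag{10.3}$$
$$= \operatorname{argmin}_\psi \underbrace{\mathbb{E}_{q(z|\psi)}\left[\log q(z|\psi) - \log p_\theta(x|z) - \log p_\theta(z)\right]}_{\mathcal{L}(\psi|\theta, x)} + \log p_\theta(x) \tag{10.4}$$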


The final term $\log p_\theta(x) = \log\int p_\theta(x, z)\,dz$ is generally intractable to compute. Fortunately, it is
independent of ψ, so we can drop it. This leaves us with the first term, which we write using the
following shorthand:¹


$$\mathcal{L}(\psi|\theta, x) = D_{\mathbb{KL}}(q(z|\psi) \,\|\, p_\theta(x, z)) = \mathbb{E}_{q(z|\psi)}[\log q(z|\psi) - \log p_\theta(x, z)] \tag{10.5}$$


1. Technically speaking $D_{\mathbb{KL}}(q(z|\psi) \,\|\, p_\theta(x, z))$ is not a KL divergence, since these are distributions over different
spaces: the first is a conditional distribution over z, the second is a joint distribution over z and x. However, hopefully
this shorthand is clear.


Figure 10.1: Illustration of variational inference. The large oval represents the set of variational distributions
$\mathcal{Q} = \{q(z; \psi) : \psi \in \mathbb{R}^K\}$, where K is the number of variational parameters. The true distribution is the point
p(z|x), which we assume lies outside the set. Our goal is to find the best approximation to p within our
variational family; this is the point ψ* which is closest in KL divergence. We find this point by starting an
optimization procedure from the random initial point $\psi^{\text{init}}$. Adapted from a figure by David Blei.


If we define $\mathcal{E}(z) = -\log p_\theta(z, x)$ as the energy, then we can write

$$\mathcal{L}(\psi|\theta, x) = \mathbb{E}_{q(z|\psi)}[\mathcal{E}(z)] - \mathbb{H}(q) \tag{10.6}$$


where $\mathbb{H}(q)$ is the entropy. In physics, this is known as the variational free energy. We can
interpret this as the expected energy minus the entropy:

$$\text{VFE} = \text{expected energy} - \text{entropy} \tag{10.7}$$


This is an upper bound on the free energy, $-\log p_\theta(x)$, which follows from the fact that

$$D_{\mathbb{KL}}(q \,\|\, p) = \text{VFE}(\psi|\theta, x) + \log p_\theta(x) \geq 0 \tag{10.8}$$


Our goal is to minimize the VFE. See Figure 10.1 for an illustration. 
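As a small numerical check of Equations (10.6)-(10.8), the NumPy sketch below uses a hypothetical discrete latent variable with 4 states and a fixed log joint $\log p_\theta(x, z)$. It computes the VFE as expected energy minus entropy and confirms that it upper-bounds $-\log p_\theta(x)$, that the gap is exactly the KL to the true posterior, and that the bound is tight when q equals the posterior.

```python
import numpy as np

# Hypothetical log joint log p(x, z) for one fixed x and z in {0, 1, 2, 3}
log_joint = np.array([-2.0, -1.0, -3.0, -0.5])
log_px = np.log(np.exp(log_joint).sum())     # log evidence log p(x)
posterior = np.exp(log_joint - log_px)       # exact p(z | x)

def vfe(q):
    """Variational free energy, Eq. (10.6): expected energy minus entropy."""
    energy = -log_joint                      # E(z) = -log p(z, x)
    entropy = -np.sum(q * np.log(q))
    return np.sum(q * energy) - entropy

def kl(q, p):
    return np.sum(q * (np.log(q) - np.log(p)))

q = np.full(4, 0.25)                                  # an arbitrary (uniform) approximation
print(vfe(q) >= -log_px)                              # VFE upper-bounds -log p(x)
print(np.isclose(kl(q, posterior), vfe(q) + log_px))  # Eq. (10.8)
print(np.isclose(vfe(posterior), -log_px))            # tight when q = p(z|x)
```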


10.1.2 Evidence lower bound (ELBO)
The negative of the VFE is known as the evidence lower bound or ELBO function [BKM16]:


$$\text{Ł}(\psi|\theta, x) = \mathbb{E}_{q(z|\psi)}[\log p_\theta(x, z) - \log q(z|\psi)] \tag{10.9}$$
where $\log p_\theta(x, z)$ is the unnormalized log joint. The name “ELBO” arises because

$$\text{Ł}(\psi|\theta, x) \leq \log p_\theta(x) \tag{10.10}$$


where $\log p_\theta(x)$ is called the “evidence”. The inequality follows from Equation (10.8). Therefore
maximizing the ELBO wrt ψ will decrease the original KL, since $\log p_\theta(x)$ is a constant wrt ψ.


(Note: we use the symbol Ł for the ELBO, rather than $\mathcal{L}$, since the latter denotes a loss we want to minimize.)
We can rewrite the ELBO as follows:
$$\text{Ł}(\psi|\theta, x) = \mathbb{E}_{q(z|\psi)}[\log p_\theta(x, z)] + \mathbb{H}(q(z|\psi)) \tag{10.11}$$
We can interpret this as follows:
$$\text{ELBO} = \text{expected log joint} + \text{entropy of posterior} \tag{10.12}$$


The second term encourages the posterior to have maximum entropy, while the first term encourages it
to be a joint MAP configuration.
We can also rewrite the ELBO in the following equivalent way:


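$$\text{Ł}(\psi|\theta, x) = \mathbb{E}_{q(z|\psi)}[\log p_\theta(x|z) + \log p_\theta(z) - \log q(z|\psi)] \tag{10.13}$$
$$= \mathbb{E}_{q(z|\psi)}[\log p_\theta(x|z)] - D_{\mathbb{KL}}(q(z|\psi) \,\|\, p_\theta(z)) \tag{10.14}$$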


We can interpret this as follows:
$$\text{ELBO} = \text{expected log likelihood} - \text{KL from posterior to prior} \tag{10.15}$$

The KL term acts like a regularizer, preventing the posterior from diverging too much from the prior.
(See also Section 21.2.2, where we discuss the ELBO in more detail, in the context of variational
autoencoders.)
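The two decompositions of the ELBO, Equations (10.11) and (10.14), can be compared numerically. In this NumPy sketch the prior, the likelihood values and q are hypothetical numbers for a single observation and a 3-state latent variable.

```python
import numpy as np

# Hypothetical discrete model for one fixed observation x, with z in {0, 1, 2}
log_prior = np.log(np.array([0.5, 0.3, 0.2]))     # log p(z)
log_lik = np.array([-1.2, -0.4, -2.0])            # log p(x | z)
log_joint = log_lik + log_prior                   # log p(x, z)
log_px = np.log(np.exp(log_joint).sum())          # log p(x)

q = np.array([0.2, 0.6, 0.2])                     # hypothetical q(z | psi)
entropy = -np.sum(q * np.log(q))

elbo_joint = np.sum(q * log_joint) + entropy                           # Eq. (10.11)
elbo_lik = np.sum(q * log_lik) - np.sum(q * (np.log(q) - log_prior))   # Eq. (10.14)
print(np.isclose(elbo_joint, elbo_lik), elbo_joint <= log_px)
```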


10.2 Mean field VI 


A common approximation in variational inference is to assume that all the latent variables are 
independent, i.e., 


$$q(z|\psi) = \prod_{j=1}^{J} q_j(z_j) \tag{10.16}$$


where J is the number of hidden variables, and $q_j(z_j)$ is shorthand for $q_{\psi_j}(z_j)$, where $\psi_j$ are the
variational parameters for the j’th distribution. This is called the mean field approximation.
From Equation (10.11), the ELBO becomes 


$$\text{Ł}(\psi) = \int q(z|\psi)\log p_\theta(x, z)\,dz + \sum_{j=1}^{J}\mathbb{H}(q_j) \tag{10.17}$$


since the entropy of a product distribution is the sum of entropies of each component in the product. 
We can either directly optimize this (see e.g., [Baq+16]), or use a coordinate-wise optimization 
scheme, as we discuss in Section 10.2.1. 


10.2.1 Coordinate ascent variational inference (CAVI) 


We now discuss a coordinate ascent method for optimizing the mean field objective, which we call
coordinate ascent variational inference or CAVI.

To derive the update equations, we initially assume there are just 3 discrete latent variables, to 
simplify notation. In this case the ELBO is given by 


$$\text{Ł}(q_1, q_2, q_3) = \sum_{z_1}\sum_{z_2}\sum_{z_3} q_1(z_1)q_2(z_2)q_3(z_3)\log\tilde{p}(z_1, z_2, z_3) + \sum_{j=1}^{3}\mathbb{H}(q_j) \tag{10.18}$$


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 
where we define $\tilde{p}(z) = p_\theta(z, x)$ for brevity. We will optimize this wrt each $q_j$, one at a time, keeping
the others fixed.
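Let us look at the objective for $q_3$:

$$\text{Ł}_3(q_3) = \sum_{z_3} q_3(z_3)\left[\sum_{z_1}\sum_{z_2} q_1(z_1)q_2(z_2)\log\tilde{p}(z_1, z_2, z_3)\right] + \mathbb{H}(q_3) + \text{const} \tag{10.19}$$
$$= \sum_{z_3} q_3(z_3)\left[g_3(z_3) - \log q_3(z_3)\right] + \text{const} \tag{10.20}$$
where
$$g_3(z_3) = \sum_{z_1}\sum_{z_2} q_1(z_1)q_2(z_2)\log\tilde{p}(z_1, z_2, z_3) = \mathbb{E}_{z_{-3}}[\log\tilde{p}(z_1, z_2, z_3)] \tag{10.21}$$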


where $z_{-3} = (z_1, z_2)$ is all variables except $z_3$. Here $g_3(z_3)$ can be interpreted as an expected negative
energy (log probability). We can convert this into an unnormalized probability distribution by
defining

$$\tilde{f}_3(z_3) = \exp(g_3(z_3)) \tag{10.22}$$
which we can normalize to get
$$f_3(z_3) = \frac{\tilde{f}_3(z_3)}{\sum_{z_3'}\tilde{f}_3(z_3')} \propto \exp(g_3(z_3)) \tag{10.23}$$
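Since $g_3(z_3) = \log f_3(z_3) + \text{const}$, we get

$$\text{Ł}_3(q_3) = \sum_{z_3} q_3(z_3)[\log f_3(z_3) - \log q_3(z_3)] + \text{const} = -D_{\mathbb{KL}}(q_3 \,\|\, f_3) + \text{const} \tag{10.24}$$

Since $D_{\mathbb{KL}}(q_3 \,\|\, f_3)$ achieves its minimal value of 0 when $q_3(z_3) = f_3(z_3)$ for all $z_3$, we see that
$q_3^*(z_3) = f_3(z_3)$.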


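Now suppose that the joint distribution is defined by a Markov chain, where $z_1 \to z_2 \to z_3$, so
$z_1 \perp z_3 | z_2$. Hence $\log\tilde{p}(z_1, z_2, z_3) = \log\tilde{p}(z_2, z_3|z_1) + \log\tilde{p}(z_1)$, where the latter term is independent
of $q_3(z_3)$. Thus the ELBO simplifies to
$$\text{Ł}_3(q_3) = \sum_{z_3} q_3(z_3)\left[\sum_{z_2} q_2(z_2)\log\tilde{p}(z_2, z_3)\right] + \mathbb{H}(q_3) + \text{const} \tag{10.25}$$
$$= \sum_{z_3} q_3(z_3)[\log f_3(z_3) - \log q_3(z_3)] + \text{const} \tag{10.26}$$
where
$$f_3(z_3) \propto \exp\left[\sum_{z_2} q_2(z_2)\log\tilde{p}(z_2, z_3)\right] = \exp\left[\mathbb{E}_{z_{\text{mb}_3}}[\log\tilde{p}(z_2, z_3)]\right] \tag{10.27}$$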


Figure 10.2: A grid-structured MRF with hidden nodes $z_i$ and local evidence nodes $x_i$. The prior p(z) is an
undirected Ising model, and the likelihood $p(x|z) = \prod_i p(x_i|z_i)$ is a directed fully factored model.


where $\text{mb}_3 = (z_2)$ is the Markov blanket (Section 4.2.4.3) of $z_3$. As before, the optimal variational
distribution is given by $q_3(z_3) = f_3(z_3)$.

In general, when we have J groups of variables, the optimal variational distribution for the j’th
group is given by

$$q_j(z_j) \propto \exp\left[\mathbb{E}_{z_{\text{mb}_j}}[\log\tilde{p}(z_j, z_{\text{mb}_j})]\right] \tag{10.28}$$


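(Compare to the equation for Gibbs sampling in Equation (12.19).) The CAVI method simply
computes $q_j$ for each dimension j in turn, in an iterative fashion (see Algorithm 20). Convergence is
guaranteed since the bound is concave wrt each of the factors $q_j$ (equivalently, the VFE is convex), so each update maximizes it exactly in that coordinate and it can never decrease; since it is also bounded above by $\log p_\theta(x)$, the sequence of bound values converges [Bis06, p. 466].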


Algorithm 20: Coordinate Ascent Variational Inference (CAVI). 
1 Initialize $q_j(z_j)$ for $j = 1:J$
2 foreach $t = 1:T$ do
3     foreach $j = 1:J$ do
4         Compute $g_j(z_j) = \mathbb{E}_{z_{\text{mb}_j}}[\log\tilde{p}(z_j, z_{\text{mb}_j})]$
5         Compute $q_j(z_j) \propto \exp(g_j(z_j))$


Note that the functional form of the $q_j$ distributions does not need to be specified in advance, but
will be determined by the form of the log joint. This is therefore called free-form VI, as opposed to 
fixed-form, where we explicitly choose a convenient distributional type for q (we discuss fixed-form 
VI in Section 10.3). We give some examples below that will make this clearer. 
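The following NumPy sketch runs Algorithm 20 on a hypothetical joint over three discrete variables stored as a dense table $\log\tilde{p}(z_1, z_2, z_3)$. The helper names (`cavi`, `elbo`, `broadcast_factor`) are illustrative. For simplicity, $g_j$ is computed as an expectation over all other variables rather than only the Markov blanket; the two differ only by an additive constant, which the normalization of $q_j$ removes. Because each update maximizes the ELBO exactly in one coordinate, the recorded ELBO should never decrease and should stay below $\log p_\theta(x)$.

```python
import numpy as np

def broadcast_factor(q, axis, ndim):
    """Reshape a 1-D factor q so that it broadcasts along `axis` of an ndim-D table."""
    shape = [1] * ndim
    shape[axis] = -1
    return q.reshape(shape)

def elbo(log_p, qs):
    """ELBO of a fully factored q, Eq. (10.17): E_q[log p~(z)] + sum_j H(q_j)."""
    q_full = np.ones(log_p.shape)
    for i, q in enumerate(qs):
        q_full = q_full * broadcast_factor(q, i, log_p.ndim)
    return np.sum(q_full * log_p) - sum(np.sum(q * np.log(q)) for q in qs)

def cavi(log_p, n_sweeps=20):
    """Algorithm 20 for a dense table log_p[z_1, ..., z_J] = log p~(z).
    Expectations are brute-force sums, so this only suits tiny state spaces."""
    J = log_p.ndim
    qs = [np.full(k, 1.0 / k) for k in log_p.shape]     # uniform initialization
    history = []
    for _ in range(n_sweeps):
        for j in range(J):
            w = np.ones(log_p.shape)                    # product of q_i for i != j
            for i in range(J):
                if i != j:
                    w = w * broadcast_factor(qs[i], i, J)
            other_axes = tuple(i for i in range(J) if i != j)
            g = np.sum(w * log_p, axis=other_axes)      # g_j(z_j) = E_{q_-j}[log p~(z)]
            qj = np.exp(g - g.max())                    # q_j(z_j) proportional to exp(g_j(z_j))
            qs[j] = qj / qj.sum()
        history.append(elbo(log_p, qs))
    return qs, history

rng = np.random.default_rng(0)
log_p = rng.normal(size=(2, 3, 2))      # hypothetical log p~(z_1, z_2, z_3)
qs, history = cavi(log_p)
log_Z = np.log(np.exp(log_p).sum())     # log p(x), the value the ELBO is bounded by
print(history[0], history[-1], log_Z)
```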


10.2.2 Example: CAVI for the Ising model 


In this section, we apply CAVI to perform mean field inference in an Ising model (Section 4.3.2.1),
which is a kind of Markov random field defined on binary random variables, $z_i \in \{-1, +1\}$, arranged
in a 2d grid.

Originally Ising models were developed as models of atomic spins for magnetic materials, although 
we will apply them to an image denoising problem. Specifically, let $z_i$ be the hidden value of pixel i,
and $x_i \in \mathbb{R}$ be the observed noisy value. See Figure 10.2 for the graphical model.
Let $L_i(z_i) = \log p(x_i|z_i)$ be the log likelihood for the i’th pixel (aka the local evidence for node i
in the graphical model). The overall likelihood has the form

$$p(x|z) = \prod_i p(x_i|z_i) = \exp\left(\sum_i L_i(z_i)\right) \tag{10.29}$$


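Our goal is to approximate the posterior p(z|x). We will use an Ising model for the prior:

$$p(z) = \frac{1}{Z_0}\exp(-\mathcal{E}_0(z)) \tag{10.30}$$
$$\mathcal{E}_0(z) = -\sum_{i \sim j} W_{ij}z_iz_j \tag{10.31}$$

where we sum over each edge $i \sim j$. Therefore the posterior has the form

$$p(z|x) = \frac{1}{Z(x)}\exp(-\mathcal{E}(z)) \tag{10.32}$$
$$\mathcal{E}(z) = \mathcal{E}_0(z) - \sum_i L_i(z_i) \tag{10.33}$$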


We will now make the following fully factored approximation:

$$q(z) = \prod_i q_i(z_i) = \prod_i \text{Ber}(z_i|\mu_i) \tag{10.34}$$

where $\mu_i = \mathbb{E}_{q_i}[z_i]$ is the mean value of node i. To derive the update for the variational parameter $\mu_i$,
we first compute the unnormalized log joint, $\log\tilde{p}(z) = -\mathcal{E}(z)$, dropping terms that do not involve $z_i$:
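$$\log\tilde{p}(z) = z_i\sum_{j\in\text{nbr}_i} W_{ij}z_j + L_i(z_i) + \text{const} \tag{10.35}$$
This only depends on the states of the neighboring nodes. Hence
$$q_i(z_i) \propto \exp\left(\mathbb{E}_{q_{-i}(z)}[\log\tilde{p}(z)]\right) = \exp\left(z_i\sum_{j\in\text{nbr}_i} W_{ij}\mu_j + L_i(z_i)\right) \tag{10.36}$$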


where $q_{-i}(z) = \prod_{j\neq i} q(z_j)$. Thus we replace the states of the neighbors by their average values.
(Note that this replaces binary variables with continuous ones.)
We now simplify this expression. Let $m_i = \sum_{j\in\text{nbr}_i} W_{ij}\mu_j$ be the mean field influence on node i.
Also, let $L_i^+ = L_i(+1)$ and $L_i^- = L_i(-1)$. The approximate marginal posterior is given by
$$q_i(z_i = 1) = \frac{e^{m_i + L_i^+}}{e^{m_i + L_i^+} + e^{-m_i + L_i^-}} = \frac{1}{1 + e^{-2m_i + L_i^- - L_i^+}} = \sigma(2a_i) \tag{10.37}$$
$$a_i \triangleq m_i + 0.5(L_i^+ - L_i^-) \tag{10.38}$$


Figure 10.3: Example of image denoising using mean field (with parallel updates and a damping factor of 0.5).
We use an Ising prior with $W_{ij} = 1$ and a Gaussian noise model with σ = 2. We show the results after 1, 3
and 15 iterations across the image. Compare to Figure 12.3, which shows the results of using Gibbs sampling.
Generated by ising_image_denoise_demo.ipynb.


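Similarly, we have $q_i(z_i = -1) = \sigma(-2a_i)$. From this we can compute the new mean for site i:
$$\mu_i = \mathbb{E}_{q_i}[z_i] = q_i(z_i = +1)\cdot(+1) + q_i(z_i = -1)\cdot(-1) \tag{10.39}$$
$$= \frac{1}{1 + e^{-2a_i}} - \frac{1}{1 + e^{2a_i}} = \frac{e^{a_i}}{e^{a_i} + e^{-a_i}} - \frac{e^{-a_i}}{e^{-a_i} + e^{a_i}} = \tanh(a_i) \tag{10.40}$$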


We can turn the above equations into a fixed point algorithm by writing

$$\mu_i^t = \tanh\left(\sum_{j\in\text{nbr}_i} W_{ij}\mu_j^{t-1} + 0.5(L_i^+ - L_i^-)\right) \tag{10.41}$$


Following [MWJ99], we can use damped updates of the following form to improve convergence:

$$\mu_i^t = (1 - \lambda)\mu_i^{t-1} + \lambda\tanh\left(\sum_{j\in\text{nbr}_i} W_{ij}\mu_j^{t-1} + 0.5(L_i^+ - L_i^-)\right) \tag{10.42}$$

for $0 < \lambda < 1$. We can update all the nodes in parallel, or update them asynchronously.

Figure 10.3 shows the method in action, applied to a 2d Ising model with homogeneous attractive
potentials, $W_{ij} = 1$. We use parallel updates with a damping factor of λ = 0.5. (If we don’t use
damping, we tend to get “checkerboard” artefacts.)
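Below is a minimal NumPy sketch of the damped parallel update in Equation (10.42) applied to denoising. Several details are assumptions made for illustration: a Gaussian noise model $x_i \sim \mathcal{N}(z_i, \sigma^2)$, for which $0.5(L_i^+ - L_i^-) = x_i/\sigma^2$; zero padding, so pixels on the border simply have fewer neighbors; initialization from the local evidence alone; and a synthetic 32×32 image whose left half is -1 and right half is +1. The settings mirror Figure 10.3 ($W_{ij} = 1$, σ = 2, λ = 0.5, 15 iterations), but this is not the ising_image_denoise_demo.ipynb code.

```python
import numpy as np

def ising_mean_field(x, sigma=2.0, W=1.0, damping=0.5, n_iters=15):
    """Damped parallel mean field updates (Eq. (10.42)) for Ising-prior image denoising.

    Assumes a Gaussian noise model x_i ~ N(z_i, sigma^2) with z_i in {-1, +1}, so that
    0.5 * (L_i^+ - L_i^-) = x_i / sigma^2, and homogeneous couplings W_ij = W on a
    4-connected grid (neighbors outside the image contribute 0)."""
    half_llr = x / sigma**2                       # 0.5 * (L_i^+ - L_i^-)
    mu = np.tanh(half_llr)                        # initialize from local evidence only
    for _ in range(n_iters):
        p = np.pad(mu, 1)                         # zero padding at the border
        m = W * (p[:-2, 1:-1] + p[2:, 1:-1] + p[1:-1, :-2] + p[1:-1, 2:])  # m_i
        mu = (1 - damping) * mu + damping * np.tanh(m + half_llr)
    return mu

rng = np.random.default_rng(0)
z_true = np.ones((32, 32))
z_true[:, :16] = -1                               # hypothetical clean image: two halves
x = z_true + 2.0 * rng.normal(size=z_true.shape)  # noisy observation, sigma = 2
mu = ising_mean_field(x)
err_thresh = np.mean(np.sign(x) != z_true)        # threshold the noisy pixels
err_mf = np.mean(np.sign(mu) != z_true)           # sign of the mean field posterior mean
print(err_thresh, err_mf)
```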


10.2.3 Variational Bayes 


In Bayesian modeling, we treat the parameters θ as latent variables. Thus our goal is to approximate
the parameter posterior $p(\theta|\mathcal{D}) \propto p(\theta)p(\mathcal{D}|\theta)$. Applying VI to this problem is called variational
Bayes [Att00].

In this section, we assume there are no latent variables except for the shared global parameters, so 
the model has the form 


$$p(\theta, \mathcal{D}) = p(\theta)\prod_{n=1}^{N}p(\mathcal{D}_n|\theta) \tag{10.43}$$
Figure 10.4: Graphical models with (a) Global hidden variable $\theta_x$ and observed variables $x_{1:N}$. (b) Local
hidden variables $z_{1:N}$, global hidden variables $\theta_x, \theta_z$, and observed variables $x_{1:N}$.


These conditional independencies are illustrated in Figure 10.4a. 
We will fit the variational posterior by maximizing the ELBO 


$$\text{Ł}(\psi_\theta|\mathcal{D}) = \mathbb{E}_{q(\theta|\psi_\theta)}[\log p(\theta, \mathcal{D})] + \mathbb{H}(q(\theta|\psi_\theta)) \tag{10.44}$$


We will assume the variational posterior factorizes over the parameters: 


$$q(\theta|\psi_\theta) = \prod_j q(\theta_j|\psi_{\theta_j}) \tag{10.45}$$
We can then update each $\psi_{\theta_j}$ using CAVI (Section 10.2.1).


10.2.4 Example: VB for a univariate Gaussian 


Consider inferring the parameters of a 1d Gaussian. The likelihood is given by $p(\mathcal{D}|\theta) = \prod_{n=1}^{N}\mathcal{N}(x_n|\mu, \lambda^{-1})$,
where μ is the mean and λ is the precision. Suppose we use a conjugate prior of the form

$$p(\mu, \lambda) = \mathcal{N}(\mu|\mu_0, (\kappa_0\lambda)^{-1})\text{Ga}(\lambda|a_0, b_0) \tag{10.46}$$


It is possible to derive the posterior $p(\mu, \lambda|\mathcal{D})$ for this model exactly, as shown in Section 3.2.3.3.
However, here we use the VB method with the following factored approximate posterior:

$$q(\mu, \lambda) = q(\mu|\psi_\mu)q(\lambda|\psi_\lambda) \tag{10.47}$$


We do not need to specify the forms for the distributions $q(\mu|\psi_\mu)$ and $q(\lambda|\psi_\lambda)$; the optimal forms
will “fall out” automatically during the derivation (and conveniently, they turn out to be Gaussian
and Gamma respectively). Our presentation follows [Mac03, p429].
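10.2.4.1 Target distribution

The unnormalized log posterior has the form

$$\log\tilde{p}(\mu, \lambda) = \log p(\mu, \lambda, \mathcal{D}) = \log p(\mathcal{D}|\mu, \lambda) + \log p(\mu|\lambda) + \log p(\lambda) \tag{10.48}$$
$$= \frac{N}{2}\log\lambda - \frac{\lambda}{2}\sum_{n=1}^{N}(x_n - \mu)^2 - \frac{\kappa_0\lambda}{2}(\mu - \mu_0)^2 + \frac{1}{2}\log(\kappa_0\lambda) + (a_0 - 1)\log\lambda - b_0\lambda + \text{const} \tag{10.49}$$

10.2.4.2 Updating q(μ|ψ_μ)

The optimal form for $q(\mu|\psi_\mu)$ is obtained by averaging over λ:

$$\log q(\mu|\psi_\mu) = \mathbb{E}_{q(\lambda|\psi_\lambda)}[\log p(\mathcal{D}|\mu, \lambda) + \log p(\mu|\lambda)] + \text{const} \tag{10.50}$$
$$= -\frac{\mathbb{E}_{q(\lambda|\psi_\lambda)}[\lambda]}{2}\left\{\kappa_0(\mu - \mu_0)^2 + \sum_{n=1}^{N}(x_n - \mu)^2\right\} + \text{const} \tag{10.51}$$

By completing the square one can show that $q(\mu|\psi_\mu) = \mathcal{N}(\mu|\mu_N, \kappa_N^{-1})$, where $\bar{x} = \frac{1}{N}\sum_n x_n$ and

$$\mu_N = \frac{\kappa_0\mu_0 + N\bar{x}}{\kappa_0 + N}, \quad \kappa_N = (\kappa_0 + N)\mathbb{E}_{q(\lambda|\psi_\lambda)}[\lambda]$$

At this stage we don’t know what $q(\lambda|\psi_\lambda)$ is, and hence we cannot compute $\mathbb{E}[\lambda]$, but we will derive
this below.

10.2.4.3 Updating q(λ|ψ_λ)

The optimal form for $q(\lambda|\psi_\lambda)$ is given by

$$\log q(\lambda|\psi_\lambda) = \mathbb{E}_{q(\mu|\psi_\mu)}[\log p(\mathcal{D}|\mu, \lambda) + \log p(\mu|\lambda) + \log p(\lambda)] + \text{const} \tag{10.52}$$
$$= (a_0 - 1)\log\lambda - b_0\lambda + \frac{1}{2}\log\lambda + \frac{N}{2}\log\lambda \tag{10.53}$$
$$\quad - \frac{\lambda}{2}\mathbb{E}_{q(\mu|\psi_\mu)}\left[\kappa_0(\mu - \mu_0)^2 + \sum_{n=1}^{N}(x_n - \mu)^2\right] + \text{const} \tag{10.54}$$

We recognize this as the log of a Gamma distribution, hence $q(\lambda|\psi_\lambda) = \text{Ga}(\lambda|a_N, b_N)$, where

$$a_N = a_0 + \frac{N+1}{2} \tag{10.55}$$
$$b_N = b_0 + \frac{1}{2}\mathbb{E}_{q(\mu|\psi_\mu)}\left[\kappa_0(\mu - \mu_0)^2 + \sum_{n=1}^{N}(x_n - \mu)^2\right] \tag{10.56}$$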
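10.2.4.4 Computing the expectations

To implement the updates, we have to specify how to compute the various expectations. Since
$q(\mu) = \mathcal{N}(\mu|\mu_N, \kappa_N^{-1})$, we have

$$\mathbb{E}_{q(\mu)}[\mu] = \mu_N \tag{10.57}$$
$$\mathbb{E}_{q(\mu)}[\mu^2] = \frac{1}{\kappa_N} + \mu_N^2 \tag{10.58}$$

and since $q(\lambda) = \text{Ga}(\lambda|a_N, b_N)$, we have

$$\mathbb{E}_{q(\lambda)}[\lambda] = \frac{a_N}{b_N} \tag{10.59}$$

We can now give explicit forms for the update equations. For $q(\mu)$ we have

$$\mu_N = \frac{\kappa_0\mu_0 + N\bar{x}}{\kappa_0 + N} \tag{10.60}$$
$$\kappa_N = (\kappa_0 + N)\frac{a_N}{b_N} \tag{10.61}$$

and for $q(\lambda)$ we have

$$a_N = a_0 + \frac{N+1}{2} \tag{10.62}$$
$$b_N = b_0 + \frac{1}{2}\kappa_0\left(\mathbb{E}[\mu^2] + \mu_0^2 - 2\mathbb{E}[\mu]\mu_0\right) + \frac{1}{2}\sum_{n=1}^{N}\left(x_n^2 + \mathbb{E}[\mu^2] - 2\mathbb{E}[\mu]x_n\right) \tag{10.63}$$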


We see that $\mu_N$ and $a_N$ are in fact fixed constants, and only $\kappa_N$ and $b_N$ need to be updated
iteratively. (In fact, one can solve for the fixed points of $\kappa_N$ and $b_N$ analytically, but we don’t do
this here in order to illustrate the iterative updating scheme.)
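The iterative scheme can be sketched in plain Python as follows. The prior hyperparameters and the synthetic data are hypothetical, and $b_N$ is initialized arbitrarily at $b_0$; the loop then alternates Equations (10.61) and (10.63) using the expectations (10.57)-(10.59).

```python
import random

def vb_unigauss(x, mu0=0.0, kappa0=1.0, a0=1.0, b0=1.0, n_iters=50):
    """Iterate Eqs. (10.57)-(10.63) for the factored posterior q(mu) q(lambda).

    The prior hyperparameters are hypothetical defaults. mu_N and a_N are fixed
    constants; only kappa_N and b_N change inside the loop."""
    N = len(x)
    xbar = sum(x) / N
    mu_N = (kappa0 * mu0 + N * xbar) / (kappa0 + N)   # Eq. (10.60)
    a_N = a0 + (N + 1) / 2                            # Eq. (10.62)
    b_N = b0                                          # arbitrary initialization
    for _ in range(n_iters):
        kappa_N = (kappa0 + N) * a_N / b_N            # Eq. (10.61), using E[lambda] = a_N / b_N
        E_mu, E_mu2 = mu_N, 1.0 / kappa_N + mu_N**2   # Eqs. (10.57)-(10.58)
        b_N = (b0 + 0.5 * kappa0 * (E_mu2 + mu0**2 - 2 * E_mu * mu0)
               + 0.5 * sum(xn**2 + E_mu2 - 2 * E_mu * xn for xn in x))   # Eq. (10.63)
    kappa_N = (kappa0 + N) * a_N / b_N
    return mu_N, kappa_N, a_N, b_N

random.seed(0)
data = [random.gauss(2.0, 0.5) for _ in range(100)]   # hypothetical data: mean 2, precision 4
mu_N, kappa_N, a_N, b_N = vb_unigauss(data)
print(mu_N, a_N / b_N)   # posterior means E[mu] and E[lambda]
```

Substituting (10.61) into (10.63) shows that each update shrinks the error in $b_N$ by a factor of $1/(2a_N)$, so the iteration converges in a few steps for any reasonable N.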


10.2.4.5 Illustration

Figure 10.5 gives an example of this method in action. The green contours represent the exact
posterior, which is Gaussian-Gamma. The dotted red contours represent the variational approximation
over several iterations. We see that the final approximation is reasonably close to the exact solution.
However, it is more “compact” than the true distribution. It is often the case that mean field inference
underestimates the posterior uncertainty, for reasons explained in Section 5.1.3.3.


10.2.4.6 Lower bound 


In VB, we maximize a lower bound on the log marginal likelihood: 


$$\text{Ł}(\psi|\mathcal{D}) \leq \log p(\mathcal{D}) = \log\iint p(\mathcal{D}|\mu, \lambda)p(\mu, \lambda)\,d\mu\,d\lambda \tag{10.64}$$


Figure 10.5: Factored variational approximation (red) to the Gaussian-Gamma distribution (green). (a)
Initial guess. (b) After updating $q(\mu|\psi_\mu)$. (c) After updating $q(\lambda|\psi_\lambda)$. (d) At convergence (after 5 iterations).
Adapted from Fig. 10.4 of [Bis06]. Generated by unigauss_vb_demo.ipynb.

It is very useful to compute the lower bound itself, for three reasons. First, it can be used to assess
convergence of the algorithm. Second, it can be used to assess the correctness of one’s code: as with
EM, if we use CAVI to optimize the objective, the bound should increase monotonically at each
iteration, otherwise there must be a bug. Third, the bound can be used as an approximation to the
marginal likelihood, which can be used for Bayesian model selection. One can show that the lower
bound has the following form:


$$\text{Ł} = \text{const} + \frac{1}{2}\ln\frac{1}{\kappa_N} + \ln\Gamma(a_N) - a_N\ln b_N \tag{10.65}$$


10.2.5 Variational Bayes EM 


In Bayesian latent variable models, we have two forms of hidden variables: local (or per example)
hidden variables $z_n$, and global (shared) hidden variables θ, which represent the parameters of the
model. See Figure 10.4b for an illustration. (Note that the parameters, which are fixed in number,
are sometimes called intrinsic variables, whereas the local hidden variables are called extrinsic
variables.) If $h = (\theta, z_{1:N})$ represents all the hidden variables, then the joint distribution is given
by

$$p(h, \mathcal{D}) = p(\theta, z_{1:N}, \mathcal{D}) = p(\theta)\prod_{n=1}^{N}p(z_n|\theta)p(x_n|z_n, \theta) \tag{10.66}$$
We will make the following mean field assumption:

$$q(\theta, z_{1:N}|\psi_{1:N}, \psi_\theta) = q(\theta|\psi_\theta)\prod_{n=1}^{N}q(z_n|\psi_n) \tag{10.67}$$

Henceforth we will usually omit the variational parameters $(\psi_{1:N}, \psi_\theta)$ for brevity.
We will use VI to maximize the ELBO:

$$\text{Ł}(\psi_{1:N}, \psi_\theta|\mathcal{D}) = \mathbb{E}_{q(\theta, z_{1:N}|\psi_{1:N}, \psi_\theta)}[\log p(z_{1:N}, \theta, \mathcal{D}) - \log q(\theta, z_{1:N})] \tag{10.68}$$


If we use the mean field assumption, then we can apply the CAVI approach to optimize each set of 
variational parameters. In particular, we can alternate between optimizing the $q_n(z_n)$ in parallel,
independently of each other, with $q(\theta)$ held fixed, and then optimizing $q(\theta)$ with the $q_n$ held fixed.
This is known as variational Bayes EM [BG06]. It is similar to regular EM, except in the E step,
we infer an approximate posterior for z, averaging out the parameters (instead of plugging in a point
estimate), and in the M step, we update the hyperparameters of the parameter posterior using the expected
sufficient statistics.

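Now suppose we approximate q(θ) by a delta function, $q(\theta) = \delta(\theta - \hat{\theta})$. The Bayesian LVM ELBO
objective from Equation (10.68) simplifies to the “LVM ELBO”:

$$\text{Ł}(\psi_{1:N}, \hat{\theta}|\mathcal{D}) = \mathbb{E}_{q(z_{1:N}|\psi_{1:N})}[\log p(\hat{\theta}, \mathcal{D}, z_{1:N}) - \log q(z_{1:N}|\psi_{1:N})] \tag{10.69}$$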


We can optimize this using the variational EM algorithm, which is a CAVI algorithm which updates
the $\psi_n$ in parallel in the variational E step, and then updates θ in the M step.

VEM is simpler than VBEM since in the variational E step, we compute $q(z_n|x_n, \hat{\theta})$, instead
of $\mathbb{E}_{\theta}[q(z_n|x_n, \theta)]$; that is, we plug in a point estimate of the model parameters, rather than averaging
over the parameters. For more details on VEM, see Section 6.6.6.1.


10.2.6 Example: VBEM for a GMM

Consider a standard Gaussian mixture model (GMM):

$$p(z, x|\theta) = \prod_n\prod_k\pi_k^{z_{nk}}\mathcal{N}(x_n|\mu_k, \Lambda_k^{-1})^{z_{nk}} \tag{10.70}$$

where $z_{nk} = 1$ if data point n belongs to cluster k, and $z_{nk} = 0$ otherwise. Our goal is to approximate
the posterior $p(z, \theta|x)$ under the following conjugate prior

$$p(\theta) = \text{Dir}(\pi|\check{\alpha})\prod_k\mathcal{N}(\mu_k|\check{m}, (\check{\kappa}\Lambda_k)^{-1})\text{Wi}(\Lambda_k|\check{L}, \check{\nu}) \tag{10.71}$$

where $\Lambda_k$ is the precision matrix for cluster k. For the mixing weights, we usually use a symmetric
prior, $\check{\alpha} = \alpha_0\mathbf{1}$.

The exact posterior $p(z, \theta|\mathcal{D})$ is a mixture of $K^N$ distributions, corresponding to all possible
labelings z, which is intractable to compute. In this section, we derive a VBEM algorithm, which
will approximate the posterior around a local mode. We follow the presentation of [Bis06, Sec 10.2].
(See also Section 10.3.5.3, where we discuss a different variational approximation for this model.)


10.2.6.1 The variational posterior 


We will use the standard mean field approximation to the posterior: $q(\theta, z_{1:N}) = q(\theta)\prod_n q_n(z_n)$. At
this stage we have not specified the forms of the q functions; these will be determined by the form of
the likelihood and prior. Below we will show that the optimal forms are as follows:

$$q_n(z_n) = \text{Cat}(z_n|r_n) \tag{10.72}$$
$$q(\theta) = \text{Dir}(\pi|\hat{\alpha})\prod_k\mathcal{N}(\mu_k|\hat{m}_k, (\hat{\kappa}_k\Lambda_k)^{-1})\text{Wi}(\Lambda_k|\hat{L}_k, \hat{\nu}_k) \tag{10.73}$$

where $r_n$ are the posterior responsibilities, and the parameters with hats on them are the hyperparameters
from the prior updated with data.
10.2.6.2 Derivation of q(θ) (variational M step)

Using the mean field recipe in Algorithm 20, we write down the log joint, and take expectations over
all variables except θ, so we average out the $z_n$ wrt $q(z_n) = \text{Cat}(z_n|r_n)$:

$$\log q(\theta) = \underbrace{\log p(\pi) + \sum_n\mathbb{E}_{q(z_n)}[\log p(z_n|\pi)]}_{L_\pi} + \sum_k\underbrace{\left[\log p(\mu_k, \Lambda_k) + \sum_n\mathbb{E}_{q(z_n)}[z_{nk}]\log\mathcal{N}(x_n|\mu_k, \Lambda_k^{-1})\right]}_{L_{\mu_k, \Lambda_k}} + \text{const} \tag{10.74}$$


Since the expected log joint factorizes into a term involving π and terms involving $(\mu_k, \Lambda_k)$, we see
that the variational posterior also factorizes into the form

$$q(\theta) = q(\pi)\prod_k q(\mu_k, \Lambda_k) \tag{10.75}$$

For the π term, we have
$$\log q(\pi) = (\alpha_0 - 1)\sum_k\log\pi_k + \sum_k\sum_n r_{nk}\log\pi_k + \text{const} \tag{10.76}$$

Exponentiating, we recognize this as a Dirichlet distribution:
$$q(\pi) = \text{Dir}(\pi|\hat{\alpha}) \tag{10.77}$$
$$\hat{\alpha}_k = \alpha_0 + N_k \tag{10.78}$$
$$N_k = \sum_n r_{nk} \tag{10.79}$$
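For the $\mu_k$ and $\Lambda_k$ terms, we have

$$q(\mu_k, \Lambda_k) = \mathcal{N}(\mu_k|\hat{m}_k, (\hat{\kappa}_k\Lambda_k)^{-1})\text{Wi}(\Lambda_k|\hat{L}_k, \hat{\nu}_k) \tag{10.80}$$
$$\hat{\kappa}_k = \check{\kappa} + N_k \tag{10.81}$$
$$\hat{m}_k = (\check{\kappa}\check{m} + N_k\bar{x}_k)/\hat{\kappa}_k \tag{10.82}$$
$$\hat{L}_k^{-1} = \check{L}^{-1} + N_kS_k + \frac{\check{\kappa}N_k}{\check{\kappa} + N_k}(\bar{x}_k - \check{m})(\bar{x}_k - \check{m})^\mathsf{T} \tag{10.83}$$
$$\hat{\nu}_k = \check{\nu} + N_k \tag{10.84}$$
$$\bar{x}_k = \frac{1}{N_k}\sum_n r_{nk}x_n \tag{10.85}$$
$$S_k = \frac{1}{N_k}\sum_n r_{nk}(x_n - \bar{x}_k)(x_n - \bar{x}_k)^\mathsf{T} \tag{10.86}$$

This is very similar to the M step for MAP estimation for GMMs, except here we are computing
the parameters of the posterior for θ rather than a point estimate $\hat{\theta}$.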
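The variational M step, Equations (10.78)-(10.86), vectorizes naturally. In this NumPy sketch the function name `vb_gmm_m_step`, the argument names standing for the prior hyperparameters ($\alpha_0, \check{m}, \check{\kappa}, \check{L}, \check{\nu}$), the synthetic data and the random responsibilities are all illustrative assumptions; it also assumes every $N_k > 0$ so that $\bar{x}_k$ and $S_k$ are defined.

```python
import numpy as np

def vb_gmm_m_step(X, R, alpha0, m0, kappa0, L0, nu0):
    """Variational M step for a Bayesian GMM, Eqs. (10.78)-(10.86).

    X: (N, D) data; R: (N, K) responsibilities r_nk.
    (alpha0, m0, kappa0, L0, nu0) stand for the prior hyperparameters alpha_0,
    m-check, kappa-check, L-check, nu-check. Assumes every N_k > 0."""
    Nk = R.sum(axis=0)                                        # Eq. (10.79)
    xbar = (R.T @ X) / Nk[:, None]                            # Eq. (10.85)
    alpha = alpha0 + Nk                                       # Eq. (10.78)
    kappa = kappa0 + Nk                                       # Eq. (10.81)
    m = (kappa0 * m0 + Nk[:, None] * xbar) / kappa[:, None]   # Eq. (10.82)
    nu = nu0 + Nk                                             # Eq. (10.84)
    L0_inv = np.linalg.inv(L0)
    L = []
    for k in range(R.shape[1]):
        diff = X - xbar[k]
        S_k = (R[:, k, None] * diff).T @ diff / Nk[k]         # Eq. (10.86)
        dm = (xbar[k] - m0)[:, None]
        L_inv = L0_inv + Nk[k] * S_k + (kappa0 * Nk[k] / (kappa0 + Nk[k])) * (dm @ dm.T)
        L.append(np.linalg.inv(L_inv))                        # Eq. (10.83)
    return alpha, m, kappa, np.stack(L), nu

rng = np.random.default_rng(0)
X = np.concatenate([rng.normal(-2, 1, size=(50, 2)), rng.normal(3, 1, size=(50, 2))])
R = rng.dirichlet(np.ones(3), size=100)       # hypothetical responsibilities for K = 3
alpha, m, kappa, L, nu = vb_gmm_m_step(X, R, alpha0=0.001, m0=np.zeros(2),
                                       kappa0=1.0, L0=np.eye(2), nu0=3.0)
print(alpha.sum(), kappa - 1.0)   # sum_k alpha_k = K * alpha0 + N;  kappa_k - kappa0 = N_k
```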


10.2.6.3 Derivation of q(z) (variational E step)

The variational E step is more interesting, since it is quite different from the E step in regular EM,
because we need to average over the parameters, rather than condition on them. In particular, we
have

$$\log q(z) = \sum_n\sum_k z_{nk}\left(\mathbb{E}_{q(\pi)}[\log\pi_k] + \frac{1}{2}\mathbb{E}_{q(\Lambda_k)}[\log|\Lambda_k|] - \frac{D}{2}\log(2\pi) - \frac{1}{2}\mathbb{E}_{q(\theta)}\left[(x_n - \mu_k)^\mathsf{T}\Lambda_k(x_n - \mu_k)\right]\right) + \text{const} \tag{10.87}$$
Using the fact that $q(\pi) = \text{Dir}(\pi|\hat{\alpha})$, one can show that

$$\exp(\mathbb{E}_{q(\pi)}[\log\pi_k]) = \frac{\exp(\psi(\hat{\alpha}_k))}{\exp(\psi(\sum_{k'}\hat{\alpha}_{k'}))} \triangleq \tilde{\pi}_k \tag{10.88}$$

where ψ is the digamma function

$$\psi(x) = \frac{d}{dx}\log\Gamma(x) \tag{10.89}$$

This takes care of the first term.
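For the second term, one can show

$$\mathbb{E}_{q(\Lambda_k)}[\log|\Lambda_k|] = \sum_{d=1}^{D}\psi\left(\frac{\hat{\nu}_k + 1 - d}{2}\right) + D\log 2 + \log|\hat{L}_k| \triangleq \log\tilde{\Lambda}_k \tag{10.90}$$

Finally, for the expected value of the quadratic form, one can show

$$\mathbb{E}_{q(\mu_k, \Lambda_k)}\left[(x_n - \mu_k)^\mathsf{T}\Lambda_k(x_n - \mu_k)\right] = \frac{D}{\hat{\kappa}_k} + \hat{\nu}_k(x_n - \hat{m}_k)^\mathsf{T}\hat{L}_k(x_n - \hat{m}_k) \tag{10.91}$$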


Figure 10.6: (a) We plot exp(ψ(x)) vs x. We see that this function performs a form of shrinkage, so that
small values get set to zero. (b) We plot $N_k$ vs time for 4 different states (z values), starting from random
initial values. We perform a series of VBEM updates, ignoring the likelihood term. We see that states that
initially had higher counts get reinforced, and sparsely populated states get killed off. From [LK07]. Used with
kind permission of Percy Liang.


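Thus we get that the posterior responsibility of cluster k for datapoint n is

$$r_{nk} \propto \tilde{\pi}_k\tilde{\Lambda}_k^{1/2}\exp\left(-\frac{D}{2\hat{\kappa}_k} - \frac{\hat{\nu}_k}{2}(x_n - \hat{m}_k)^\mathsf{T}\hat{L}_k(x_n - \hat{m}_k)\right) \tag{10.92}$$

Compare this to the expression used in regular EM:
$$r_{nk}^{\text{EM}} \propto \hat{\pi}_k|\hat{\Lambda}_k|^{1/2}\exp\left(-\frac{1}{2}(x_n - \hat{\mu}_k)^\mathsf{T}\hat{\Lambda}_k(x_n - \hat{\mu}_k)\right) \tag{10.93}$$
where $\hat{\pi}_k$ is the MAP estimate for $\pi_k$. The significance of this difference is discussed in Section 10.2.6.4.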
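Correspondingly, the variational E step of Equations (10.88)-(10.92) can be sketched as below, working in log space and normalizing each row at the end. SciPy provides `scipy.special.digamma`; to keep the example dependency-light, the hypothetical helper `digamma` here uses the recurrence $\psi(x) = \psi(x+1) - 1/x$ followed by the standard asymptotic series, which is accurate to roughly $10^{-8}$ or better for $x > 0$. The posterior hyperparameters in the usage example are made up for two well-separated clusters.

```python
import math
import numpy as np

def digamma(x):
    """psi(x) for x > 0: upward recurrence, then the asymptotic series."""
    result = 0.0
    while x < 6.0:
        result -= 1.0 / x
        x += 1.0
    inv2 = 1.0 / (x * x)
    return result + math.log(x) - 0.5 / x - inv2 * (1/12 - inv2 * (1/120 - inv2 / 252))

def vb_gmm_e_step(X, alpha, m, kappa, L, nu):
    """Variational E step, Eqs. (10.88)-(10.92); returns the (N, K) responsibilities."""
    N, D = X.shape
    K = len(alpha)
    log_pi_tilde = np.array([digamma(a) for a in alpha]) - digamma(alpha.sum())  # Eq. (10.88)
    log_r = np.empty((N, K))
    for k in range(K):
        log_lam_tilde = (sum(digamma((nu[k] + 1 - d) / 2) for d in range(1, D + 1))
                         + D * math.log(2) + np.linalg.slogdet(L[k])[1])        # Eq. (10.90)
        diff = X - m[k]
        quad = np.einsum("nd,de,ne->n", diff, L[k], diff)
        log_r[:, k] = (log_pi_tilde[k] + 0.5 * log_lam_tilde
                       - D / (2 * kappa[k]) - 0.5 * nu[k] * quad)               # log of Eq. (10.92)
    log_r -= log_r.max(axis=1, keepdims=True)        # stabilize before exponentiating
    R = np.exp(log_r)
    return R / R.sum(axis=1, keepdims=True)

# Hypothetical posterior hyperparameters for two well-separated 2-d clusters
alpha = np.array([50.0, 50.0])
m = np.array([[-2.0, -2.0], [2.0, 2.0]])
kappa = np.array([51.0, 51.0])
nu = np.array([53.0, 53.0])
L = np.stack([np.eye(2) / 50.0, np.eye(2) / 50.0])   # so E[Lambda_k] = nu_k L_k is about I
X = np.array([[-2.0, -2.0], [2.0, 2.0], [0.0, 0.0]])
R = vb_gmm_e_step(X, alpha, m, kappa, L, nu)
print(R.round(3))
```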


10.2.6.4 Automatic sparsity inducing effects of VBEM 


In regular EM, the E step has the form given in Equation (10.93), whereas in VBEM, the E step has 
the form given in Equation (10.92). Although they look similar, they differ in an important way. To 
understand this, let us ignore the likelihood term, and just focus on the prior. From Equation (10.88) 
we have 


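$$r_{nk}^{\text{VB}} \propto \tilde{\pi}_k = \frac{\exp(\psi(\hat{\alpha}_k))}{\exp(\psi(\sum_{k'}\hat{\alpha}_{k'}))} \tag{10.94}$$

And from the usual EM MAP estimation equations for GMM mixing weights (see e.g., [Mur22, Sec
8.7.3.4]) we have

$$r_{nk}^{\text{EM}} \propto \hat{\pi}_k = \frac{\hat{\alpha}_k - 1}{\sum_{k'}(\hat{\alpha}_{k'} - 1)} \tag{10.95}$$

where $\hat{\alpha}_k = \alpha_0 + N_k$, and $N_k = \sum_n r_{nk}$ is the expected number of assignments to cluster k.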
We know from Figure 2.9 that using $\alpha_0 < 1$ causes π to be sparse, which will encourage $r_n$ to be
sparse, which will “kill off” unnecessary mixture components (i.e., ones for which $N_k \ll N$, meaning
very few data points are assigned to cluster k). To encourage this sparsity promoting effect, let us
set $\alpha_0 = 0$. In this case, the updated parameters for the mixture weights are given by the following:

$$\tilde{\pi}_k = \frac{\exp(\psi(N_k))}{\exp(\psi(\sum_{k'}N_{k'}))} \tag{10.96}$$
$$\hat{\pi}_k = \frac{N_k - 1}{\sum_{k'}(N_{k'} - 1)} \tag{10.97}$$

Now consider a cluster which has no assigned data, so $N_k = 0$. In regular EM, $\hat{\pi}_k$ might end up
negative, as pointed out in [FJ02]. (This will not occur if we use maximum likelihood training, which
corresponds to $\alpha_0 = 1$, but this will not induce any sparsity, either.) This problem does not arise in
VBEM, since we use $\exp(\psi(N_k))$, which is always non-negative, as shown in Figure 10.6(a).

More interestingly, let us consider the effect of these updates on clusters that have unequal, 
but non-zero, numbers of assignments. Suppose we start with a random assignment of counts to 4
clusters, and iterate the VBEM algorithm, ignoring the contribution from the likelihood for simplicity.
Figure 10.6(b) shows how the counts $N_k$ evolve over time. We notice that clusters that started out
with small counts end up with zero counts, and clusters that started out with large counts end up 
with even larger counts. In other words, the initially popular clusters get more and more members. 
This is called the rich get richer phenomenon; we will encounter it again in Section 31.3, when we 
discuss Dirichlet process mixture models. 

The reason for this effect is shown in Figure 10.6(a): we see that $\exp(\psi(N_k)) < N_k$, and is close to zero if
$N_k$ is sufficiently small, similar to the soft-thresholding behavior induced by $\ell_1$-regularization (see
Section 15.2.5). Importantly, this effect of reducing $N_k$ is proportionally greater on clusters with small counts.
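The shrinkage and “rich get richer” effects can be reproduced in a few lines of plain Python. This sketch (with a hypothetical `digamma` helper standing in for `scipy.special.digamma`) evaluates $\exp(\psi(n))$, shows that the MAP weights of Equation (10.97) go negative for an empty cluster, and iterates $N_k \leftarrow N\tilde{\pi}_k/\sum_{k'}\tilde{\pi}_{k'}$ with $\alpha_0 = 0$ while ignoring the likelihood, as in Figure 10.6(b). The initial counts are illustrative.

```python
import math

def digamma(x):
    """psi(x) for x > 0: upward recurrence, then the asymptotic series."""
    result = 0.0
    while x < 6.0:
        result -= 1.0 / x
        x += 1.0
    inv2 = 1.0 / (x * x)
    return result + math.log(x) - 0.5 / x - inv2 * (1/12 - inv2 * (1/120 - inv2 / 252))

def exp_digamma(n):
    """exp(psi(n)), using its limiting value 0 at n = 0."""
    return math.exp(digamma(n)) if n > 0 else 0.0

def map_weights(counts):
    """EM/MAP mixing weights with alpha_0 = 0, Eq. (10.97); these can go negative."""
    denom = sum(n - 1 for n in counts)
    return [(n - 1) / denom for n in counts]

def vb_count_dynamics(counts, n_iters):
    """Repeat N_k <- N * pi_tilde_k / sum_k' pi_tilde_k' (Eq. (10.96), alpha_0 = 0),
    ignoring the likelihood; the common denominator of Eq. (10.96) cancels."""
    total = sum(counts)
    for _ in range(n_iters):
        w = [exp_digamma(n) for n in counts]
        counts = [total * wk / sum(w) for wk in w]
    return counts

print([round(exp_digamma(n), 3) for n in [0.5, 1.0, 10.0]])     # each value is below n
print(map_weights([0.0, 5.0, 10.0, 15.0]))                      # first weight is negative
print(vb_count_dynamics([10.0, 20.0, 30.0, 40.0], n_iters=20))  # hypothetical start
```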

We now demonstrate this automatic pruning method on a real example. We fit a mixture of 6
Gaussians to the Old Faithful dataset, using $\alpha_0 = 0.001$. Since the data only really “needs” 2 clusters,
the remaining 4 get “killed off”, as shown in Figure 10.7. In Figure 10.8, we plot the initial and final
values of $\hat{\alpha}_k$; we see that $\hat{\alpha}_k \approx 0$ for all but two of the components k.

Thus we see that VBEM for GMMs with a sparse Dirichlet prior provides an efficient way to choose 
the number of clusters. Similar techniques can be used to choose the number of states in an HMM 
and other latent variable models. However, this variational pruning effect (also called posterior
collapse) is not always desirable, since it can cause the model to “ignore” the latent variables z if
the likelihood function p(x|z) is sufficiently powerful. We discuss this more in Section 21.4.


10.2.6.5 Lower bound on the marginal likelihood 


The VBEM algorithm is maximizing the following lower bound

$$\text{Ł} = \sum_z\int d\theta\, q(z, \theta)\log\frac{p(x, z, \theta)}{q(z, \theta)} \leq \log p(x) \tag{10.98}$$

This quantity increases monotonically with each iteration, as shown in Figure 10.9.
[Figure 10.7 panels: iter 1, iter 94]


Figure 10.7: We visualize the posterior mean parameters at various stages of the VBEM algorithm applied to a mixture of Gaussians model on the Old Faithful data. Shading intensity is proportional to the mixing weight. We initialize with K-means and use $\alpha_0 = 0.001$ as the Dirichlet hyper-parameter. (The red dot on the right panel represents all the unused mixture components, which collapse to the prior at 0.) Adapted from Figure 10.6 of [Bis06]. Generated by gmm_vb_em.ipynb.


[Figure 10.8 panels: iter 1, iter 94]


Figure 10.8: We visualize the posterior values of $\alpha_k$ for the model in Figure 10.7 after the first and last iteration of the algorithm. We see that unnecessary components get “killed off”. (Interestingly, the initially large cluster 6 gets “replaced” by cluster 5.) Generated by gmm_vb_em.ipynb.


[Figure 10.9 plot: "variational Bayes objective for GMM on old faithful data"; y-axis: lower bound on log marginal likelihood]


Figure 10.9: Lower bound vs iterations for the VB algorithm in Figure 10.7. The steep parts of the curve correspond to places where the algorithm figures out that it can increase the bound by “killing off” unnecessary mixture components, as described in Section 10.2.6.4. The plateaus correspond to slowly moving the clusters around. Generated by gmm_vb_em.ipynb.
10.2.6.6 Model selection using VBEM


Section 10.2.6.4 discusses a way to choose K automatically, during model fitting, by “killing off” unneeded clusters. An alternative approach is to fit several models, and then to use the variational lower bound to the log marginal likelihood, $\mathcal{L}(K) \le \log p(\mathcal{D}|K)$, to approximate $p(K|\mathcal{D})$. In particular, if we have a uniform prior, we get the posterior

$$p(K|\mathcal{D}) = \frac{p(\mathcal{D}|K)}{\sum_{K'} p(\mathcal{D}|K')} \approx \frac{e^{\mathcal{L}(K)}}{\sum_{K'} e^{\mathcal{L}(K')}} \tag{10.99}$$


It is shown in [BG06] that the VB approximation to the marginal likelihood is more accurate than BIC. However, the lower bound needs to be modified somewhat to take into account the lack of identifiability of the parameters. In particular, although VB will approximate the volume occupied by the parameter posterior, it will only do so around one of the local modes. With K components, there are K! equivalent modes, which differ merely by permuting the labels. Therefore a more accurate approximation to the log marginal likelihood is to use $\log p(\mathcal{D}|K) \approx \mathcal{L}(K) + \log(K!)$.
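Equation (10.99), together with the $\log(K!)$ correction, amounts to a softmax over the adjusted ELBO values. The helper below is a minimal sketch that assumes the per-K lower bounds $\mathcal{L}(K)$ have already been computed by fitting each model; the function name and the example ELBO values are hypothetical.

```python
import math


def posterior_over_k(elbos, correct_label_switching=True):
    """Approximate p(K | D) under a uniform prior over K.

    elbos: dict mapping K -> L(K), a lower bound on log p(D | K).
    If correct_label_switching is True, use log p(D | K) ~= L(K) + log(K!).
    """
    scores = {k: elbo + (math.lgamma(k + 1) if correct_label_switching else 0.0)
              for k, elbo in elbos.items()}
    top = max(scores.values())          # subtract the max for numerical stability
    unnorm = {k: math.exp(s - top) for k, s in scores.items()}
    total = sum(unnorm.values())
    return {k: u / total for k, u in unnorm.items()}


# Hypothetical ELBO values for K = 1..4.
print(posterior_over_k({1: -520.3, 2: -431.8, 3: -433.0, 4: -436.9}))
```

Because $\log(K!)$ grows with K, the correction shifts some mass toward larger K, compensating for the $K! - 1$ equivalent modes that a single VB fit ignores.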


10.2.7 Variational message passing (VMP) 


In this section, we describe the CAVI algorithm for a generic model in which each complete conditional, $p(z_j|z_{-j}, x)$, is in the exponential family, i.e.,

$$p(z_j|z_{-j}, x) = h(z_j) \exp\left[\eta_j(z_{-j}, x)^\top \mathcal{T}(z_j) - A_j(\eta_j(z_{-j}, x))\right] \tag{10.100}$$

where $\mathcal{T}(z_j)$ is the vector of sufficient statistics, $\eta_j$ are the natural parameters, $A_j$ is the log partition function, and $h(z_j)$ is the base distribution. This assumption holds if the prior $p(z_j)$ is conjugate to the likelihood, $p(z_{-j}, x|z_j)$.

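If Equation (10.100) holds, the mean field update for node j becomes

$$\begin{aligned}
q_j(z_j) &\propto \exp\left[\mathbb{E}\left[\log p(z_j|z_{-j}, x)\right]\right] && (10.101)\\
&= \exp\left[\log h(z_j) + \mathbb{E}\left[\eta_j(z_{-j}, x)\right]^\top \mathcal{T}(z_j) - \mathbb{E}\left[A_j(\eta_j(z_{-j}, x))\right]\right] && (10.102)\\
&\propto h(z_j) \exp\left[\mathbb{E}\left[\eta_j(z_{-j}, x)\right]^\top \mathcal{T}(z_j)\right] && (10.103)
\end{aligned}$$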


Thus we update the local natural parameters using the expected values of the other nodes. These become the new variational parameters:

$$\psi_j = \mathbb{E}\left[\eta_j(z_{-j}, x)\right] \tag{10.104}$$
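As a concrete instance of Equations (10.101) to (10.104), consider the classic conjugate model $x_n \sim \mathcal{N}(\mu, 1/\tau)$ with a normal-gamma prior, $\mu|\tau \sim \mathcal{N}(\mu_0, 1/(\lambda_0 \tau))$, $\tau \sim \mathrm{Ga}(a_0, b_0)$, and the mean field posterior $q(\mu)q(\tau)$. This model is our own illustrative choice, not one used in this section. Each update plugs in expectations under the other factor ($\mathbb{E}[\tau]$ for $q(\mu)$; $\mathbb{E}[\mu]$ and $\mathbb{E}[\mu^2]$ for $q(\tau)$), which is the "expected natural parameters" recipe. A NumPy sketch:

```python
import numpy as np


def cavi_normal_gamma(x, mu0=0.0, lam0=1.0, a0=1.0, b0=1.0, num_iters=50):
    """Mean field CAVI for x_n ~ N(mu, 1/tau), mu | tau ~ N(mu0, 1/(lam0 tau)), tau ~ Ga(a0, b0).

    Returns (mu_n, lam_n, a_n, b_n) for q(mu) = N(mu_n, 1/lam_n) and q(tau) = Ga(a_n, b_n).
    """
    x = np.asarray(x, dtype=float)
    n, xbar = x.size, x.mean()
    mu_n = (lam0 * mu0 + n * xbar) / (lam0 + n)   # mean of q(mu); needs no expectations
    a_n = a0 + 0.5 * (n + 1)                       # shape of q(tau); also fixed
    e_tau = a0 / b0                                # initial guess for E[tau]
    for _ in range(num_iters):
        # q(mu) update: plug in E[tau] under the current q(tau).
        lam_n = (lam0 + n) * e_tau
        # q(tau) update: plug in E[(x_n - mu)^2] and E[(mu - mu0)^2] under q(mu).
        e_sq_data = np.sum((x - mu_n) ** 2) + n / lam_n
        e_sq_prior = (mu_n - mu0) ** 2 + 1.0 / lam_n
        b_n = b0 + 0.5 * (e_sq_data + lam0 * e_sq_prior)
        e_tau = a_n / b_n
    return mu_n, lam_n, a_n, b_n


rng = np.random.default_rng(0)
x = rng.normal(3.0, 0.5, size=1000)
mu_n, lam_n, a_n, b_n = cavi_normal_gamma(x, lam0=1e-3, a0=1e-3, b0=1e-3)
print("E[mu] =", mu_n, " E[tau] =", a_n / b_n, " 1/var(x) =", 1.0 / x.var())
```

With a weak prior, $\mathbb{E}[\mu]$ approaches the sample mean and $\mathbb{E}[\tau]$ approaches the inverse sample variance.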


We can generalize the above approach to work with any model where each full conditional is conjugate. The resulting algorithm is known as variational message passing or VMP [WB05], and it works for any directed graphical model. VMP is similar to belief propagation (Section 9.2): at each iteration, each node collects all the messages from its parents, and all the messages from its children (which might require the children to get messages from their co-parents), and combines them to compute the expected value of the node’s sufficient statistics. The messages that are sent are the expected sufficient statistics of a node, rather than just a discrete or Gaussian distribution (as in BP). Several software libraries have implemented this framework (see e.g., [Win; Min+18; Lut16; Wan17]).
10.3. FIXED-FORM VI


VMP can be extended to the case where each full conditional is conditionally conjugate using the 
CVI framework in Supplementary Section 10.3.1. See also [ABV21], where they use local Laplace 
approximations to intractable factors inside of a message passing framework. 


10.2.8 Autoconj 


The VMP method requires the user to manually specify a graphical model; the corresponding node 
update equations are then computed for each node using a lookup table, for each possible combination 
of node types. It is possible to automatically derive these update equations for any conditionally 
conjugate directed graphical model using a technique called autoconj [HJT18]. This is analogous to 
the use of automatic differentiation (autodiff) to derive the gradient for any differentiable function. 
(Note that autoconj uses autodiff internally.) The resulting full conditionals can be used for CAVI, 
and also for Gibbs sampling (Section 12.3). 


10.3 Fixed-form VI 


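Recall that the goal of variational inference is to maximize the following lower bound wrt the variational parameters $\psi$:

$$\mathcal{L}(\psi|\mathcal{D}) = \mathbb{E}_{q_\psi(z)}\left[\log \frac{p(z, \mathcal{D})}{q_\psi(z)}\right] = \mathbb{E}_{q_\psi(z)}\left[\ell_\psi(z)\right] \tag{10.105}$$

where

$$\ell_\psi(z) = \log p(z, \mathcal{D}) - \log q_\psi(z) \tag{10.106}$$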


Here z are the unknown latent variables (or parameters), and $\psi$ are the variational parameters. In the mean field method of Section 10.2, we assumed that q(z) factorized across (groups of) variables, $q(z|\psi) = \prod_j q_j(z_j)$. We do not specify the form of $q_j$, so this approach is called “free-form” VI; however, the form of the optimal $q_j$ can be derived analytically. We then optimized each $q_j$ using coordinate ascent.

In this section, we take a different approach: we pick any convenient form we like for q(z), such as 
a multivariate Gaussian, and then we directly maximize the ELBO using gradient ascent. This is 
called fixed-form VI. 


10.3.1 Stochastic variational inference 


In many models, the likelihood factorizes into a product of terms, in which case we have

$$\ell_\psi(z) = \left[\sum_{n=1}^N \log p(x_n|z)\right] + \log p(z) - \log q_\psi(z) \tag{10.107}$$

If N is large, we can compute an unbiased approximation to this expression from a minibatch of B examples as follows:

$$\hat{\ell}_\psi(z) = \frac{N}{B}\left[\sum_{b=1}^B \log p(x_b|z)\right] + \log p(z) - \log q_\psi(z) \tag{10.108}$$
We can use this to create an unbiased Monte Carlo approximation to the ELBO:

$$\mathcal{L}(\psi|\mathcal{D}) = \mathbb{E}_{q_\psi(z)}\left[\hat{\ell}_\psi(z)\right] \tag{10.109}$$


This is called stochastic variational inference or SVI [Hof+13], and allows VI to scale to large 
datasets. 
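The unbiasedness of Equation (10.108) is easy to see numerically: if the N examples are split into N/B disjoint minibatches, the average of the rescaled minibatch estimates equals the full-data value exactly. The NumPy sketch below assumes a hypothetical Gaussian likelihood $p(x_n|z) = \mathcal{N}(x_n|z, 1)$ and treats $\log p(z)$ and $\log q_\psi(z)$ as given numbers for a single sample z.

```python
import numpy as np


def log_lik(x, z):
    """Hypothetical per-example likelihood log N(x_n | z, 1)."""
    return -0.5 * np.log(2 * np.pi) - 0.5 * (x - z) ** 2


def ell_full(x, z, log_prior, log_q):
    """Eq. (10.107): uses all N examples."""
    return np.sum(log_lik(x, z)) + log_prior - log_q


def ell_minibatch(x, z, idx, log_prior, log_q):
    """Eq. (10.108): rescales the minibatch log likelihood by N / B."""
    n, b = x.shape[0], len(idx)
    return (n / b) * np.sum(log_lik(x[idx], z)) + log_prior - log_q


rng = np.random.default_rng(0)
x = rng.normal(size=1000)
z, log_prior, log_q = 0.3, -1.2, -0.8            # illustrative values for one sample z
batches = np.split(rng.permutation(1000), 10)   # 10 disjoint minibatches of size B = 100
estimates = [ell_minibatch(x, z, idx, log_prior, log_q) for idx in batches]
print(np.mean(estimates), ell_full(x, z, log_prior, log_q))
```

Each individual minibatch estimate is noisy; only their average over a full partition (or their expectation over random minibatches) recovers the exact value.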


10.3.2 Black-box variational inference 


In this section, we assume that we can evaluate $\ell_\psi(z)$ pointwise, but we do not assume we can take gradients of this function. (For example, z may contain discrete variables.) We are thus treating the model as a “black box”. Hence this approach is called black box variational inference or BBVI [RGB14; ASD20].

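To estimate the gradient of the ELBO, we will use the score function estimator, also called the REINFORCE estimator (Section 6.5.3). To derive this, we use the fact that $\nabla \log q = \frac{\nabla q}{q}$ to conclude that $\nabla q = q \nabla \log q$. (This is called the log derivative trick.) We also exploit the fact that $\int q(z)dz = 1$. With this, the gradient of the ELBO can be derived as follows:

$$\begin{aligned}
\nabla_\psi \mathcal{L}(\psi) &= \nabla_\psi \int q_\psi(z) \log \frac{p(z)p(\mathcal{D}|z)}{q_\psi(z)}\, dz && (10.110)\\
&= \int \left[\nabla_\psi q_\psi(z)\right] \log \frac{p(z)p(\mathcal{D}|z)}{q_\psi(z)}\, dz + \int q_\psi(z) \left[\nabla_\psi \log \frac{p(z)p(\mathcal{D}|z)}{q_\psi(z)}\right] dz && (10.111)\\
&= \int \left[\nabla_\psi q_\psi(z)\right] \log \frac{p(z)p(\mathcal{D}|z)}{q_\psi(z)}\, dz - \int q_\psi(z) \nabla_\psi \log q_\psi(z)\, dz && (10.112)\\
&= \int \left[\nabla_\psi q_\psi(z)\right] \log \frac{p(z)p(\mathcal{D}|z)}{q_\psi(z)}\, dz - \int \nabla_\psi q_\psi(z)\, dz && (10.113)\\
&= \int \left[\nabla_\psi q_\psi(z)\right] \log \frac{p(z)p(\mathcal{D}|z)}{q_\psi(z)}\, dz - \nabla_\psi \int q_\psi(z)\, dz && (10.114)\\
&= \int \left[\nabla_\psi q_\psi(z)\right] \log \frac{p(z)p(\mathcal{D}|z)}{q_\psi(z)}\, dz && (10.115)\\
&= \int q_\psi(z) \left(\nabla_\psi \log q_\psi(z)\right) \ell_\psi(z)\, dz && (10.116)\\
&= \mathbb{E}_{q_\psi(z)}\left[\left(\nabla_\psi \log q_\psi(z)\right) \ell_\psi(z)\right] && (10.117)
\end{aligned}$$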

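We can compute a stochastic approximation to this gradient by sampling $z_s \sim q_\psi(z)$ and then computing

$$\widehat{\nabla_\psi \mathcal{L}(\psi)} = \frac{1}{S}\sum_{s=1}^S \left[\nabla_\psi \log q_\psi(z_s)\, \ell_\psi(z_s)\right]\Big|_{\psi=\psi_t} \tag{10.118}$$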


We can pass this to any kind of gradient optimizer, such as SGD or Adam. If we further approximate $\ell_\psi(z_s)$ using the SVI minibatch approximation, we get a “doubly stochastic” approximation [TLG14].
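Here is a minimal NumPy sketch of the score-function estimator in Equation (10.118), under assumptions of our own: q is a 1d Gaussian parameterized by $\psi = (\mu, \log\sigma)$, and the model is a hypothetical normalized target $p(z, \mathcal{D}) = \mathcal{N}(z|1, 1)$, for which the exact ELBO gradient at $\mu = 0, \sigma = 1$ is $(1, 0)$. The model enters only through pointwise evaluations of `log_joint`.

```python
import numpy as np


def gauss_logpdf(z, mean, std):
    return -0.5 * np.log(2 * np.pi * std ** 2) - 0.5 * ((z - mean) / std) ** 2


def score_function_grad(mu, log_sigma, log_joint, num_samples, rng):
    """Score-function (REINFORCE) estimate of Eq. (10.118) for q = N(mu, exp(log_sigma)^2).

    Only pointwise evaluations of log_joint are needed, never its gradient.
    Returns the estimate of [dL/dmu, dL/dlog_sigma].
    """
    sigma = np.exp(log_sigma)
    z = mu + sigma * rng.standard_normal(num_samples)    # z_s ~ q_psi
    ell = log_joint(z) - gauss_logpdf(z, mu, sigma)      # ell_psi(z_s), Eq. (10.106)
    eps = (z - mu) / sigma
    score_mu = eps / sigma                               # d log q / d mu
    score_log_sigma = eps ** 2 - 1.0                     # d log q / d log_sigma
    return np.array([np.mean(score_mu * ell), np.mean(score_log_sigma * ell)])


rng = np.random.default_rng(0)
grad = score_function_grad(0.0, 0.0, lambda z: gauss_logpdf(z, 1.0, 1.0), 200_000, rng)
print(grad)   # the exact ELBO gradient at this point is [1, 0]
```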


In practice, the variance of this estimator is quite large, so it is important to use methods such as control variates or CV (Section 6.5.3.1). To see how this works, consider the naive gradient estimator in Equation (10.118), which for the i'th component we can write as

$$\widehat{\nabla_{\psi_i} \mathcal{L}(\psi)} = \frac{1}{S}\sum_{s=1}^S f_i(z_s) \tag{10.119}$$

$$f_i(z_s) = g_i(z_s) \times \ell_\psi(z_s) \tag{10.120}$$

$$g_i(z_s) = \nabla_{\psi_i} \log q_\psi(z_s) \tag{10.121}$$


The control variate version of this can be obtained by replacing $f_i(z_s)$ with

$$\hat{f}_i(z) = f_i(z) + c_i\left(\mathbb{E}\left[b_i(z)\right] - b_i(z)\right) \tag{10.122}$$


where $b_i(z)$ is a baseline function and $c_i$ is some constant, to be specified below. A convenient baseline is the score function, $b_i(z) = \nabla_{\psi_i} \log q_\psi(z) = g_i(z)$, since this is correlated with $f_i(z)$, and has the property that $\mathbb{E}[b_i(z)] = 0$, since the expected value of the score function is zero, as we showed in Equation (2.245). Hence

$$\hat{f}_i(z) = f_i(z) - c_i g_i(z) = g_i(z)\left(\ell_\psi(z) - c_i\right) \tag{10.123}$$

so the CV estimator is given by

$$\widehat{\nabla_{\psi_i} \mathcal{L}(\psi)} = \frac{1}{S}\sum_{s=1}^S g_i(z_s) \times \left(\ell_\psi(z_s) - c_i\right) \tag{10.124}$$


One can show that the optimal $c_i$ that minimizes the variance of the CV estimator is

$$\hat{c}_i = \frac{\mathrm{Cov}\left[g_i(z)\ell_\psi(z),\, g_i(z)\right]}{\mathbb{V}\left[g_i(z)\right]} \tag{10.125}$$

which can be estimated by sampling $z \sim q_\psi(z)$. Thus the overall algorithm is as shown in Algorithm 21.

We can stop the algorithm when the lower bound stops increasing. We can compute a stochastic approximation to the lower bound using

$$\hat{\mathcal{L}}(\psi_t) = \frac{1}{S}\sum_{s=1}^S \ell_{\psi_t}(z_s) \tag{10.126}$$

where $z_s \sim q_{\psi_t}(z)$. To smooth out the noise, we can use a running average over the last w observations to get

$$\bar{\mathcal{L}}(\psi_t) = \frac{1}{w}\sum_{k=t-w+1}^{t} \hat{\mathcal{L}}(\psi_k) \tag{10.127}$$

If the moving average does not improve after P consecutive iterations, we declare convergence, where P is the patience parameter. Typical values are P = 20 and w = 20 [TND21].
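To see why the baseline matters, the sketch below compares the per-sample variance of the naive terms $f_i(z_s) = g_i(z_s)\ell_\psi(z_s)$ with the control-variate terms $g_i(z_s)(\ell_\psi(z_s) - \hat{c}_i)$ from Equations (10.124) and (10.125). The setup is hypothetical: $q = \mathcal{N}(0, 1)$ and an unnormalized target $\mathcal{N}(z|1, 1) \cdot e^{10}$, so $\ell_\psi(z) = z - 1/2 + 10$. The constant 10 plays the role of an unknown log normalizer: it does not change the expected gradient, but it inflates the variance of the naive estimator, and $\hat{c}_i$ absorbs it. For simplicity, $\hat{c}_i$ is estimated from the same samples.

```python
import numpy as np


def cv_gradient(score, ell):
    """Control-variate score-function gradient, Eqs. (10.124)-(10.125).

    score: (S, P) array with entries g_i(z_s) = d/dpsi_i log q_psi(z_s).
    ell:   (S,) array with entries ell_psi(z_s).
    Returns the gradient estimate and the estimated c_i.
    """
    f = score * ell[:, None]                       # naive terms f_i(z_s)
    f_c = f - f.mean(axis=0)
    g_c = score - score.mean(axis=0)
    c = (f_c * g_c).mean(axis=0) / (g_c ** 2).mean(axis=0)   # Cov[f_i, g_i] / Var[g_i]
    return (score * (ell[:, None] - c)).mean(axis=0), c


rng = np.random.default_rng(0)
z = rng.standard_normal(100_000)                   # z_s ~ q = N(0, 1)
score = np.stack([z, z ** 2 - 1.0], axis=1)        # d log q / d(mu, log_sigma) at mu=0, sigma=1
ell = (z - 0.5) + 10.0                             # log N(z|1,1) + 10 - log N(z|0,1)
grad, c = cv_gradient(score, ell)
naive_var = (score * ell[:, None]).var(axis=0)
cv_var = (score * (ell[:, None] - c)).var(axis=0)
print("gradient:", grad, "c:", c)
print("per-sample variance, naive:", naive_var, "with CV:", cv_var)
```

Both components of $\hat{c}$ come out close to 9.5, the mean of $\ell_\psi$, and the per-sample variance drops by more than an order of magnitude while the gradient estimate stays near the true value $(1, 0)$.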
Algorithm 21: Blackbox VI with control variates
1 Initialize $\psi_0$
2 $z_s \sim q_{\psi_0}(z)$, $s = 1:S$
3 $\ell_{\psi_0}(z_s) = \log p(z_s, \mathcal{D}) - \log q_{\psi_0}(z_s)$, $s = 1:S$
4 $g_0 = \frac{1}{S}\sum_{s=1}^S \left[\nabla_\psi \log q_{\psi_0}(z_s)\right] \ell_{\psi_0}(z_s)$
5 Compute $c_1$ using Equation (10.125) applied to $z_s$
6 for $t = 1:T$ do
7     $z_s \sim q_{\psi_{t-1}}(z)$, $s = 1:S$
8     $\ell_{\psi_{t-1}}(z_s) = \log p(z_s, \mathcal{D}) - \log q_{\psi_{t-1}}(z_s)$, $s = 1:S$
9     $g_t = \frac{1}{S}\sum_{s=1}^S \left[\nabla_\psi \log q_{\psi_{t-1}}(z_s)\right] \left(\ell_{\psi_{t-1}}(z_s) - c_t\right)$
10    Compute $c_{t+1}$ using Equation (10.125) applied to $z_s$
11    $\psi_t = \text{gradient-update}(\psi_{t-1}, g_t)$
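The following NumPy sketch assembles Algorithm 21 with the moving-average stopping rule of Equations (10.126) and (10.127). Several assumptions are ours, not the text's: q is $\mathcal{N}(\mu, e^{2\omega})$ with $\psi = (\mu, \omega)$, "gradient-update" is plain gradient ascent with a fixed step size, the unused $g_0$ of line 4 is skipped, and the target is a hypothetical density $\mathcal{N}(z|2, 0.5^2)$.

```python
import numpy as np


def gauss_logpdf(z, mean, std):
    return -0.5 * np.log(2 * np.pi * std ** 2) - 0.5 * ((z - mean) / std) ** 2


def bbvi_cv(log_joint, num_samples=100, lr=0.05, max_iters=5000,
            window=20, patience=20, seed=0):
    """Black-box VI with control variates (Algorithm 21) for q = N(mu, exp(omega)^2)."""
    rng = np.random.default_rng(seed)
    psi = np.zeros(2)                                   # psi = [mu, omega]

    def draw(psi):
        mu, sigma = psi[0], np.exp(psi[1])
        z = mu + sigma * rng.standard_normal(num_samples)
        eps = (z - mu) / sigma
        score = np.stack([eps / sigma, eps ** 2 - 1.0], axis=1)   # grad_psi log q(z_s)
        ell = log_joint(z) - gauss_logpdf(z, mu, sigma)           # ell_psi(z_s)
        return score, ell

    def estimate_c(score, ell):                         # Eq. (10.125)
        f_c = score * ell[:, None]
        f_c = f_c - f_c.mean(axis=0)
        g_c = score - score.mean(axis=0)
        return (f_c * g_c).mean(axis=0) / (g_c ** 2).mean(axis=0)

    c = estimate_c(*draw(psi))                          # lines 2-5
    elbo_hist, best, stale = [], -np.inf, 0
    for t in range(1, max_iters + 1):                   # line 6
        score, ell = draw(psi)                          # lines 7-8
        grad = (score * (ell[:, None] - c)).mean(axis=0)   # line 9, uses c_t
        c = estimate_c(score, ell)                      # line 10, gives c_{t+1}
        psi = psi + lr * grad                           # line 11
        elbo_hist.append(ell.mean())                    # Eq. (10.126)
        if len(elbo_hist) >= window:
            avg = np.mean(elbo_hist[-window:])          # Eq. (10.127)
            if avg > best:
                best, stale = avg, 0
            else:
                stale += 1
                if stale >= patience:                   # patience P
                    break
    return psi, t


psi, num_iters = bbvi_cv(lambda z: gauss_logpdf(z, 2.0, 0.5))
print("mu =", psi[0], " sigma =", np.exp(psi[1]), " iterations =", num_iters)
```

Because $c_t$ is computed from the previous iteration's samples, it is independent of the current samples, so the gradient estimate in line 9 stays unbiased.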
10.3.3 Reparameterization VI

In this section, we exploit the reparameterization trick from Section 6.5.4 to get a lower variance estimator for the gradient. This assumes that $\ell_\psi(z)$ is a differentiable function of z. It also assumes that we can sample $z \sim q_\psi(z)$ by first sampling a noise term $\epsilon \sim q_0(\epsilon)$, and then transforming it to compute the latent random variables $z = r(\psi, \epsilon)$. In this case, the ELBO becomes

$$\mathcal{L}(\psi) = \mathbb{E}_{q_0(\epsilon)}\left[\ell_\psi(r(\psi, \epsilon))\right] = \mathbb{E}_{q_0(\epsilon)}\left[\log p(r(\psi, \epsilon), \mathcal{D}) - \log q_\psi(r(\psi, \epsilon))\right] \tag{10.128}$$

Since the sampling distribution $q_0(\epsilon)$ is independent of the variational parameters $\psi$, we can push the gradient operator inside the expectation, and thus we can estimate the gradient using standard automatic differentiation methods, as shown in Algorithm 22. This is called reparameterized VI or RVI, and has provably lower variance than BBVI in certain cases [Xu+19].
Algorithm 22: Estimate of ELBO gradient
def elbo(ψ):
    ε ~ q_0(ε)
    z = r(ψ, ε)
    return log p(z, D) - log q_ψ(z|D)
def elbo-grad(ψ):
    return grad(elbo)(ψ)
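For the same 1d Gaussian setup as in the score-function sketch ($q = \mathcal{N}(0, 1)$, target $\mathcal{N}(z|1, 1)$), the reparameterized per-sample gradients can be written out by hand. This NumPy sketch stands in for the autodiff call `grad(elbo)` of Algorithm 22 (the derivatives are derived manually, not by a library) and compares per-sample variances with the score-function terms.

```python
import numpy as np


def reparam_grad_samples(mu, log_sigma, grad_log_joint, num_samples, rng):
    """Per-sample reparameterized ELBO gradients for q = N(mu, exp(log_sigma)^2).

    z = r(psi, eps) = mu + sigma * eps. Since log q(r(psi, eps)) = -log sigma - eps^2 / 2
    + const, the -log q term contributes +1 to the log_sigma gradient and 0 to the
    mu gradient. Derivatives are written by hand instead of using autodiff.
    """
    sigma = np.exp(log_sigma)
    eps = rng.standard_normal(num_samples)
    z = mu + sigma * eps
    dlogp_dz = grad_log_joint(z)
    g_mu = dlogp_dz                                # dz/dmu = 1
    g_log_sigma = dlogp_dz * sigma * eps + 1.0     # dz/dlog_sigma = sigma * eps
    return np.stack([g_mu, g_log_sigma], axis=1)


rng = np.random.default_rng(0)
g = reparam_grad_samples(0.0, 0.0, lambda z: -(z - 1.0), 100_000, rng)
# Score-function terms g_i(z) * ell(z) for the same setup, where ell(z) = z - 1/2.
eps = rng.standard_normal(100_000)
score_terms = np.stack([eps * (eps - 0.5), (eps ** 2 - 1.0) * (eps - 0.5)], axis=1)
print("mean:", g.mean(axis=0))
print("per-sample variance, reparam:", g.var(axis=0), "score function:", score_terms.var(axis=0))
```

Both estimators target the same gradient $(1, 0)$, but the reparameterized terms use $\nabla_z \log p$ and therefore fluctuate less per sample.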


10.3.3.1 “Sticking the landing” estimator 


Applying the results from Section 6.5.4.2, we can derive the gradient estimate of the reparameterized ELBO, for a single Monte Carlo sample, as follows:

$$\nabla_\psi \hat{\mathcal{L}}(\psi, \epsilon) = \nabla_\psi\left[\log p(z, \mathcal{D}) - \log q_\psi(z|\mathcal{D})\right] \tag{10.129}$$

$$= \underbrace{\nabla_z\left[\log p(z|\mathcal{D}) - \log q_\psi(z|\mathcal{D})\right] J}_{\text{path derivative}} - \underbrace{\nabla_\psi \log q_\psi(z|\mathcal{D})}_{\text{score function}} \tag{10.130}$$

where $z = r(\psi, \epsilon)$ and $J = \nabla_\psi r(\psi, \epsilon)$ is the Jacobian matrix of the noise transformation.

The first term is the indirect effect of $\psi$ on the objective via the generated samples z. The second term is the direct effect of $\psi$ on the objective. The second term is zero in expectation since it is the score function (see Equation (2.245)), but it may be non-zero for a finite number of samples, even if $q_\psi(z|\mathcal{D}) = p(z|\mathcal{D})$ is the true posterior. In a paper called “sticking the landing”, [RWD17] propose to drop the second term to create a lower variance estimator.² In practice, this means we compute the gradient using Algorithm 23 instead of Algorithm 22.³


Algorithm 23: “Sticking the landing” estimator of ELBO gradient
def elbo-pd(ψ):
    ε ~ q_0(ε)
    z = r(ψ, ε)
    ψ' = stop-gradient(ψ)
    return log p(z, D) - log q_ψ'(z|D)
def elbo-grad-pd(ψ):
    return grad(elbo-pd)(ψ)
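To see what "sticking the landing" buys, the sketch below evaluates both estimators at the variational optimum of a hypothetical 1d problem where q can match the posterior exactly, $q_\psi = p(z|\mathcal{D}) = \mathcal{N}(1, 0.5^2)$. The gradients are written by hand in NumPy in place of autodiff and `stop-gradient`: the path term uses $\nabla_z$ of $\log p - \log q$ times the Jacobian $J$, and the score term is $\nabla_\psi \log q_\psi(z)$ with z held fixed.

```python
import numpy as np


def rvi_and_stl_grads(mu, log_sigma, grad_log_joint, eps):
    """Per-sample gradients of the reparameterized ELBO for q = N(mu, exp(log_sigma)^2).

    Returns (standard, stl): the standard estimator (path derivative minus score
    function, Eq. 10.130, as in Algorithm 22) and the STL estimator (path derivative
    only, as in Algorithm 23). Columns correspond to (mu, log_sigma).
    """
    sigma = np.exp(log_sigma)
    z = mu + sigma * eps
    # d/dz [log p(z, D) - log q(z)], using d/dz log q(z) = -(z - mu) / sigma^2
    dz = grad_log_joint(z) + (z - mu) / sigma ** 2
    jac = np.stack([np.ones_like(eps), sigma * eps], axis=1)   # dz / d(mu, log_sigma)
    path = dz[:, None] * jac
    # d/dpsi log q(z) with z held fixed (the score function)
    score = np.stack([(z - mu) / sigma ** 2, ((z - mu) / sigma) ** 2 - 1.0], axis=1)
    return path - score, path


eps = np.random.default_rng(0).standard_normal(10_000)
# True posterior N(1, 0.5^2); q is set equal to it, i.e. the variational optimum.
standard, stl = rvi_and_stl_grads(1.0, np.log(0.5), lambda z: -(z - 1.0) / 0.25, eps)
print("std of standard gradient:", standard.std(axis=0))
print("max |STL gradient|:", np.abs(stl).max())
```

At the optimum the path derivative vanishes identically, so every STL sample is zero, while the standard estimator is only zero on average.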


Note that the STL estimator is not always better than the “standard” estimator. In [GD20], they 
propose to use a weighted combination of estimators, where the weights are optimized so as to reduce 
variance for a fixed amount of compute. 


10.3.3.2 Example: reparameterized SVI for GMMs 


In this section, we use reparameterized SVI to fit a Gaussian mixture model. We will marginalize out the discrete latent variables, so we just need to approximate $p(\theta|\mathcal{D})$. We choose a factored variational posterior that is conjugate to the likelihood, but is also reparameterizable, so we can fit the posterior with SGD instead of having to use coordinate ascent (Section 10.2.1).
For simplicity, we assume diagonal covariance matrices. Thus the likelihood for one data point, $x \in \mathbb{R}^D$, is

$$p(x|\theta) = \sum_{k=1}^K \pi_k \mathcal{N}(x|\mu_k, \mathrm{diag}(\lambda_k)^{-1}) \tag{10.131}$$


2. The expression “to stick a landing” means to land firmly on one’s feet after performing a gymnastics move. In the current context, the analogy is this: if the variational posterior is optimal, so $q_\psi(z|\mathcal{D}) = p(z|\mathcal{D})$, then we want our objective to be 0, and not to “wobble” with Monte Carlo noise.

3. The difference is that the path derivative version ignores the score function. This can be achieved by using $\log q_{\psi'}(z|\mathcal{D})$, where $\psi'$ is a “disconnected” copy of $\psi$ that does not affect the gradient.


Figure 10.10: SVI for fitting a mixture of 3 Gaussians in 2d. (a) 3000 training points. (b) Fitted density, plugging in the posterior mean parameters. (c) Kernel density estimate fit to 10,000 samples from $q(\mu_1|\psi_\theta)$. Generated by svi_gmm_demo_2d.ipynb.


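where $\mu_k = (\mu_{k1}, \ldots, \mu_{kD})$ are the means, $\lambda_k = (\lambda_{k1}, \ldots, \lambda_{kD})$ are the precisions, and $\pi = (\pi_1, \ldots, \pi_K)$ are the mixing weights. We use the following prior for these parameters:

$$p(\theta) = \left[\prod_{k=1}^K \prod_{d=1}^D \mathcal{N}(\mu_{kd}|0, 1)\, \mathrm{Ga}(\lambda_{kd}|5, 5)\right] \mathrm{Dir}(\pi|\mathbf{1}) \tag{10.132}$$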


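We assume the following mean field posterior:

$$q(\theta|\psi_\theta) = \left[\prod_{k=1}^K \prod_{d=1}^D \mathcal{N}(\mu_{kd}|m_{kd}, s_{kd})\, \mathrm{Ga}(\lambda_{kd}|\alpha_{kd}, \beta_{kd})\right] \mathrm{Dir}(\pi|c) \tag{10.133}$$

where $\psi_\theta = (m_{1:K,1:D}, s_{1:K,1:D}, \alpha_{1:K,1:D}, \beta_{1:K,1:D}, c)$ are the variational parameters for $\theta$.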
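We can compute the ELBO using

$$\mathcal{L}(\psi_\theta|\mathcal{D}) = \mathbb{E}_{q(\theta|\psi_\theta)}\left[\sum_{n=1}^N \log p(x_n|\theta)\right] - D_{\mathbb{KL}}\left(q(\theta|\psi_\theta) \,\|\, p(\theta)\right) \tag{10.134}$$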


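We can approximate the first term using minibatching. Since $q(\theta|\psi_\theta)$ is reparameterizable (see Section 10.3.3), we can sample from it and push gradients inside. If we use a single posterior sample per minibatch, $\theta^s \sim q(\theta|\psi_\theta)$, we get

$$\nabla_{\psi_\theta} \mathcal{L}(\psi_\theta|\mathcal{D}) \approx \frac{N}{B}\sum_{b=1}^B \nabla_{\psi_\theta} \log p(x_b|\theta^s) - \nabla_{\psi_\theta} D_{\mathbb{KL}}\left(q(\theta|\psi_\theta) \,\|\, p(\theta)\right) \tag{10.135}$$

We can now optimize this with SGD.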
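Every SGD step in Equation (10.135) evaluates $\log p(x_b|\theta^s)$ from Equation (10.131). A numerically stable NumPy version for diagonal precisions uses log-sum-exp over the components. The function below is a hypothetical helper, not code taken from svi_gmm_demo_2d.ipynb.

```python
import numpy as np


def gmm_diag_loglik(x, pi, mu, lam):
    """log p(x | theta) for Eq. (10.131): sum_k pi_k N(x | mu_k, diag(lam_k)^{-1}).

    x: (N, D) data; pi: (K,) mixing weights; mu, lam: (K, D) means and precisions.
    Returns an (N,) array. The log-sum-exp over k avoids underflow.
    """
    diff = x[:, None, :] - mu[None, :, :]                                   # (N, K, D)
    # sum_d log N(x_d | mu_kd, 1/lam_kd) = 0.5 * sum_d [log lam - log 2pi - lam * diff^2]
    log_comp = 0.5 * np.sum(np.log(lam) - np.log(2 * np.pi) - lam * diff ** 2, axis=-1)
    a = np.log(pi)[None, :] + log_comp                                      # (N, K)
    a_max = a.max(axis=1, keepdims=True)
    return (a_max + np.log(np.exp(a - a_max).sum(axis=1, keepdims=True)))[:, 0]


x = np.array([[0.0], [1.0]])
print(gmm_diag_loglik(x, np.array([0.5, 0.5]), np.array([[0.0], [3.0]]), np.ones((2, 1))))
```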
Figure 10.10 gives an example of this in practice. We generate a dataset from a mixture of 3 Gaussians in 2d, using $\mu_1^* = [2, 0]$, $\mu_2^* = [-2, -4]$, $\mu_3^* = [-2, 4]$, precisions $\lambda_{kd}^* = 1$, and uniform mixing weights, $\pi^* = [1/3, 1/3, 1/3]$. Figure 10.10a shows the training set of 3000 points. We fit this using SVI, with a batch size of 500, for 1000 epochs, using the Adam optimizer. Figure 10.10b shows the predictions of the fitted model. More precisely, it shows $p(x|\bar{\theta})$, where $\bar{\theta} = \mathbb{E}_{q(\theta|\psi_\theta)}[\theta]$. Figure 10.10c shows a kernel density estimate fit to 10,000 samples from $q(\mu_1|\psi_\theta)$. We see that the posterior mean is $\mathbb{E}[\mu_1] \approx [-2, -4]$. Due to label switching unidentifiability, we see this matches $\mu_2^*$ rather than $\mu_1^*$.


10.3.4 Gaussian VI 


The most widely used RVI approximation is when $q_\psi(z)$ is a Gaussian, where $\psi = (\mu, \Sigma)$. Following [TND21], we call this Gaussian VI. We give a brief summary below. For a theoretical analysis of the computational / statistical tradeoff with using this form of posterior approximation, see [Bha+22].


10.3.4.1 Full-rank Gaussian VI 


In this section, we represent the covariance using its Cholesky decomposition, $\Sigma = LL^\top$, where $L = \mathrm{tril}(l)$ is a lower triangular matrix, as in [TLG14; TN18]. The variational parameters are $\psi = (\mu, l)$. The noise transformation has the form

$$z \sim \mathcal{N}(\mu, \Sigma) \iff z = \mu + L\epsilon \tag{10.136}$$

where $\epsilon \sim \mathcal{N}(0, I)$, so $r(\psi, \epsilon) = \mu + L\epsilon$.
For any given sample z, we can use automatic differentiation to compute the derivative of the sampled ELBO:

$$\nabla_\psi \ell_\psi(z(\psi)) = \nabla_\psi \log p(z(\psi), \mathcal{D}) - \nabla_\psi \log q_\psi(z(\psi)) \tag{10.137}$$
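The full-rank noise transformation is straightforward to implement. In this NumPy sketch, the mapping from the unconstrained vector $l$ to $L$ exponentiates the diagonal so that $\Sigma = LL^\top$ is positive definite; that choice is our assumption, since the text only says $L = \mathrm{tril}(l)$. The entropy of $q_\psi$, which supplies the $-\mathbb{E}[\log q_\psi]$ part of the ELBO, follows in closed form from $\log\det\Sigma = 2\sum_i \log L_{ii}$.

```python
import numpy as np


def unpack_cholesky(l_flat, d):
    """Build lower-triangular L = tril(l) from an unconstrained vector of length d(d+1)/2.

    Assumption (not specified in the text): the diagonal is exponentiated so that
    L has a positive diagonal and Sigma = L L^T is positive definite.
    """
    L = np.zeros((d, d))
    L[np.tril_indices(d)] = l_flat
    L[np.diag_indices(d)] = np.exp(np.diag(L))
    return L


def sample_full_rank(mu, L, num_samples, rng):
    """z = mu + L eps with eps ~ N(0, I), Eq. (10.136); one sample per row."""
    eps = rng.standard_normal((num_samples, mu.shape[0]))
    return mu + eps @ L.T


def gaussian_entropy(L):
    """Entropy of N(mu, L L^T): 0.5 d log(2 pi e) + sum_i log L_ii."""
    d = L.shape[0]
    return 0.5 * d * np.log(2 * np.pi * np.e) + np.sum(np.log(np.diag(L)))


rng = np.random.default_rng(0)
L = unpack_cholesky(np.array([0.1, 0.5, -0.2, -0.3, 0.4, 0.0]), 3)
z = sample_full_rank(np.array([1.0, -1.0, 0.0]), L, 200_000, rng)
print(np.cov(z, rowvar=False).round(2))
print((L @ L.T).round(2))
```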


10.3.4.2 Low-rank Gaussian VI 


In high dimensions, an efficient alternative to using a Cholesky decomposition is the factor decomposition

$$\Sigma = BB^\top + C^2 \tag{10.138}$$

where B is the factor loading matrix of size $d \times f$, where $f < d$ is the number of factors, d is the dimensionality of z, and $C = \mathrm{diag}(c_1, \ldots, c_d)$. This reduces the total number of variational parameters from $d + d(d+1)/2$ to $(f + 2)d$. In [ONS18], they called this approach VAFC for Variational Approximation with Factor Covariance.
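The factor form is also cheap to sample from, since $z = \mu + B\epsilon_1 + c \odot \epsilon_2$ with $\epsilon_1 \sim \mathcal{N}(0, I_f)$ and $\epsilon_2 \sim \mathcal{N}(0, I_d)$ has covariance $BB^\top + C^2$ without ever forming a $d \times d$ matrix. The NumPy sketch below (function names are ours) checks this and the parameter counts quoted above.

```python
import numpy as np


def num_variational_params(d, f=None):
    """mu plus a Cholesky factor (f=None), or mu, B (d x f) and c (factor form)."""
    if f is None:
        return d + d * (d + 1) // 2
    return (f + 2) * d


def sample_factor(mu, B, c, num_samples, rng):
    """z = mu + B eps1 + c * eps2 with eps1 ~ N(0, I_f), eps2 ~ N(0, I_d).

    Then Cov[z] = B B^T + diag(c^2) = B B^T + C^2, matching Eq. (10.138).
    """
    d, f = B.shape
    eps1 = rng.standard_normal((num_samples, f))
    eps2 = rng.standard_normal((num_samples, d))
    return mu + eps1 @ B.T + eps2 * c


print(num_variational_params(1000), num_variational_params(1000, f=5))   # prints 501500 7000
rng = np.random.default_rng(0)
B = np.array([[1.0], [0.5], [-0.5], [0.2]])
c = np.array([0.5, 1.0, 0.3, 0.8])
z = sample_factor(np.zeros(4), B, c, 200_000, rng)
print(np.cov(z, rowvar=False).round(2))
```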

In the special case where f = 1, the covariance matrix becomes

$$\Sigma = bb^\top + \mathrm{diag}(c^2) \tag{10.139}$$


In this case, it is possible to compute the natural gradient (Section 6.4) of the ELBO in closed form in 
O(d) time, as shown in [Tra+20b; TND21], who call the approach NAGVAC-1 (Natural Gradient 
Gaussian variational approximation). This can result in much faster convergence than following the 
normal gradient. 

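In particular, let $g = (g_1, g_2, g_3)$ be the regular gradient of the ELBO wrt $\mu$, b and c. Let $v_1 = c^{-2} - 2b^2 \odot c^{-4}$, $v_2 = b^2 \odot c^{-3}$, $\kappa_1 = \sum_i b_i^2/c_i^2$, and $\kappa_2 = \frac{1}{2}\left(1 + \sum_i v_{2i}^2/v_{1i}\right)^{-1}$. Then the natural gradient is given by

$$g^{\text{nat}} = \begin{pmatrix} (g_1^\top b)b + c^2 \odot g_1 \\ \frac{1 + \kappa_1}{2\kappa_1}\left((g_2^\top b)b + c^2 \odot g_2\right) \\ \frac{1}{2} v_1^{-1} \odot g_3 + \kappa_2\left[(v_1^{-1} \odot v_2)^\top g_3\right](v_1^{-1} \odot v_2) \end{pmatrix} \tag{10.140}$$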
[Figure 10.11 legend: analytic, mean-field, low_rank, full_rank]


Figure 10.11: Gaussian variational approximation to a Gaussian posterior for the mean of a 2d Gaussian. Generated by gaussian_2d_vi.ipynb.


10.3.4.3 Example: GVI for a linear Gaussian system 


As a sanity check, we try to approximate the posterior mean of a multivariate Gaussian with fixed covariance. We use a Gaussian prior, so the posterior has the form

$$p(z|\mathcal{D}) \propto p(z) \prod_{n=1}^N \mathcal{N}(y_n|z, \Sigma) \tag{10.141}$$

where the measurement noise $\Sigma$ is a fixed, non-diagonal matrix. If we use a conjugate prior, $\mathcal{N}(z|\breve{m}, \breve{V})$, then the exact posterior, $p(z|\mathcal{D}) = \mathcal{N}(z|\hat{m}, \hat{V})$, can be computed analytically, as discussed in Section 3.3.1.

We also compute a Gaussian variational approximation, $q(z|\psi_\mu, \psi_\Sigma)$, where $\psi_\Sigma$ is either full-rank, diagonal, or rank-1 plus diagonal. The results are shown in Figure 10.11. We see that the full-rank approximation matches the true posterior, as expected. However, the diagonal approximation is overconfident, which is a well-known flaw of variational inference.
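The overconfidence of the diagonal approximation can be computed exactly in this linear Gaussian case, using a standard result for Gaussian targets: the optimal fully factorized Gaussian under $\mathrm{KL}(q\,\|\,p)$ keeps the exact mean but sets each variance to $1/(\hat{V}^{-1})_{ii}$, which is never larger than the true marginal variance $\hat{V}_{ii}$. This NumPy sketch uses made-up data and a broad prior $\mathcal{N}(0, 10 I)$; the helper names are ours.

```python
import numpy as np


def exact_posterior(y, Sigma, m0, V0):
    """Exact posterior N(m_hat, V_hat) for y_n ~ N(z, Sigma), z ~ N(m0, V0), Eq. (10.141)."""
    n = y.shape[0]
    Sigma_inv, V0_inv = np.linalg.inv(Sigma), np.linalg.inv(V0)
    V_hat = np.linalg.inv(V0_inv + n * Sigma_inv)
    m_hat = V_hat @ (V0_inv @ m0 + Sigma_inv @ y.sum(axis=0))
    return m_hat, V_hat


def optimal_diagonal_gaussian(m_hat, V_hat):
    """Best fully factorized Gaussian q under KL(q || p) when p = N(m_hat, V_hat).

    The optimum keeps the exact mean, but its variances are 1 / (V_hat^{-1})_ii.
    """
    return m_hat, 1.0 / np.diag(np.linalg.inv(V_hat))


rng = np.random.default_rng(0)
Sigma = np.array([[1.0, 0.8], [0.8, 1.0]])           # fixed, non-diagonal noise
y = rng.multivariate_normal([0.5, -0.5], Sigma, size=20)
m_hat, V_hat = exact_posterior(y, Sigma, np.zeros(2), 10.0 * np.eye(2))
m_mf, v_mf = optimal_diagonal_gaussian(m_hat, V_hat)
print("exact marginal variances:", np.diag(V_hat), " mean field variances:", v_mf)
```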


10.3.5 Automatic differentiation VI 


To apply Gaussian VI, we need to transform constrained parameters (such as variance terms) to unconstrained form, so they live in $\mathbb{R}^D$. This technique can be used for any distribution for which we can define a bijection to $\mathbb{R}^D$. This approach is called automatic differentiation variational inference or ADVI [Kuc+16]. We give the details below.


10.3.5.1 Basic idea 


Our goal is to approximate the posterior $p(\theta|\mathcal{D}) \propto p(\theta)p(\mathcal{D}|\theta)$, where $\theta \in \Theta$ lives in some D-dimensional constrained parameter space. Let $T: \Theta \to \mathbb{R}^D$ be a bijective mapping that maps from the constrained space to some unconstrained space, with inverse $T^{-1}: \mathbb{R}^D \to \Theta$ that maps to the constrained domain of the prior. Let $z = T(\theta)$ be the unconstrained latent variables. We will use a block-factored Gaussian variational approximation to the posterior for z, i.e.,

$$q(z|\psi) = \prod_{b=1}^B \mathcal{N}(z_b|\mu_b, \Sigma_b) \tag{10.142}$$


where $\psi = (\mu_{1:B}, \Sigma_{1:B})$.
By the change of variable formula Equation (2.286), we have

$$p(z) = p(T^{-1}(z))\,|\det(J_{T^{-1}}(z))| \tag{10.143}$$

where $J_{T^{-1}}$ is the Jacobian of the inverse mapping $z \to \theta$. Hence the ELBO becomes

$$\mathcal{L}(\psi) = \mathbb{E}_{z \sim q(z|\psi)}\left[\log p(\mathcal{D}|T^{-1}(z)) + \log p(T^{-1}(z)) + \log|\det(J_{T^{-1}}(z))|\right] + \mathbb{H}(\psi) \tag{10.144}$$

We can use a Monte Carlo approximation of the expectation over z, together with the reparameterization trick, which (in the fully diagonal case) replaces $z \sim q(z|\psi)$ with $z = \mu + \sigma \odot \epsilon$ where $\epsilon \sim \mathcal{N}(0, I)$ and $\sigma = \sigma_{1:D}$. We can compute the final entropy term in closed form. In the case of a fully diagonal approximation, it follows from Equation (5.96) that

$$\mathbb{H}(\psi) = \sum_{i=1}^D \log(\sigma_i) + \text{const} \tag{10.145}$$


Since the objective is stochastic, we can use SGD to optimize it. However, [Ing20] propose 
deterministic ADVI, in which the samples $\epsilon_s \sim \mathcal{N}(0, I)$ are held fixed during the optimization 
process. This is called the common random numbers trick (Section 11.6.1), and makes the objective 
a deterministic function; this allows for the use of more powerful second-order optimization methods, 
such as BFGS. (Of course, if the dataset is large, we might need to use minibatch subsampling, which 
reintroduces stochasticity.) 


10.3.5.2 Example: ADVI for Beta-Binomial model 


To illustrate ADVI, we consider the 1d beta-binomial model from Section 7.4.4. We want to approximate $p(\theta|\mathcal{D})$ using the prior $p(\theta) = \mathrm{Beta}(\theta|a, b)$ and likelihood $p(\mathcal{D}|\theta) = \prod_i \mathrm{Ber}(y_i|\theta)$, where the sufficient statistics are $N_1 = 10$, $N_0 = 1$, and the prior is uninformative, $a = b = 1$. We use the transformation $\theta = T^{-1}(z) = \sigma(z)$, and optimize the ELBO with SGD. The results of this method are shown in Figure 7.3 and show that the Gaussian fit is a good approximation, despite the skewed nature of the posterior.
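The beta-binomial example is small enough to write out by hand. In z-space, the log prior, log likelihood and log-Jacobian $\log[\sigma(z)(1 - \sigma(z))]$ combine into $(N_1 + a)\log\sigma(z) + (N_0 + b)\log(1 - \sigma(z))$, up to a constant. The NumPy sketch below fits $q(z) = \mathcal{N}(\mu, e^{2\omega})$ with hand-derived reparameterization gradients; unlike the text, which uses SGD, it holds the noise samples fixed (the deterministic ADVI / common random numbers variant described above), so that plain gradient ascent on a deterministic objective suffices. Step size and iteration count are our choices.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def advi_beta_bernoulli(n1=10, n0=1, a=1.0, b=1.0, num_samples=1000,
                        lr=0.05, num_iters=2000, seed=0):
    """Deterministic ADVI for p(theta | D) with theta = T^{-1}(z) = sigmoid(z).

    Log target in z-space: (n1 + a) log sigmoid(z) + (n0 + b) log(1 - sigmoid(z)) + const.
    q(z) = N(mu, exp(omega)^2). The noise eps is drawn once and reused (common random
    numbers), so the objective being maximized is deterministic.
    """
    eps = np.random.default_rng(seed).standard_normal(num_samples)
    mu, omega = 0.0, 0.0
    for _ in range(num_iters):
        s = np.exp(omega)
        z = mu + s * eps
        dlogp_dz = (n1 + a) - (n1 + n0 + a + b) * sigmoid(z)
        grad_mu = dlogp_dz.mean()
        grad_omega = (dlogp_dz * s * eps).mean() + 1.0   # +1 from the entropy, Eq. (10.145)
        mu, omega = mu + lr * grad_mu, omega + lr * grad_omega
    return mu, np.exp(omega), eps


mu, s, eps = advi_beta_bernoulli()
theta = sigmoid(mu + s * eps)
print("q(z) = N(%.3f, %.3f^2), E_q[theta] = %.4f, exact E[theta] = %.4f"
      % (mu, s, theta.mean(), 11 / 13))
```

At a stationary point the $\mu$-gradient is zero, which forces the sample average of $\sigma(z)$ to equal $(N_1 + a)/(N_1 + N_0 + a + b) = 11/13$, the exact posterior mean of $\mathrm{Beta}(11, 2)$.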


10.3.5.3 Example: ADVI for GMMs 


In this section, we use ADVI to approximate the posterior of the parameters of a mixture of Gaussians. The difference from the VBEM algorithm of Section 10.2.6 is that we use ADVI combined with a Gaussian variational posterior, rather than using a mean field approximation defined by a product of conjugate distributions.

To apply ADVI, we marginalize out the local discrete latents $m_n \in \{1, \ldots, K\}$ analytically, so the likelihood has the form

$$p(\mathcal{D}|\theta) = \prod_{n=1}^N \left[\sum_{k=1}^K \pi_k \mathcal{N}(y_n|\mu_k, \Sigma_k)\right] \tag{10.146}$$

We use an uninformative Gaussian prior for the $\mu_k$, a uniform LKJ prior for the $L_k$, a log-normal prior for the $\sigma_k$, and a uniform Dirichlet prior for the mixing weights $\pi$. (See [Kuc+16, Fig 21]
Figure 10.12: Posterior over the mixing weights (histogram) and the means and covariances of each Gaussian mixture component, using K = 10, when fitting the model to the Old Faithful dataset from Figure 10.7. (a) MAP approximation. (b-d) 3 samples from the Gaussian approximation. The intensity of the shading is proportional to the mixture weight. Generated by gmm_advi_bijax.ipynb.


for a definition of the model in Stan syntax.) The posterior approximation for the unconstrained parameters is a block-diagonal Gaussian, $q(z) = \mathcal{N}(z|\psi_\mu, \psi_\Sigma)$, where the unconstrained parameters are computed using suitable bijections (see code for details).


We apply this method to the Old Faithful dataset from Figure 10.7, using K = 10 mixture components. The results are shown in Figure 10.12. In panel (a), we show the special case where we constrain the posterior to be a MAP estimate, by setting $\psi_\Sigma = 0$. We see that there is no sparsity in the posterior, since there is no Bayesian “Occam factor” from marginalizing out the parameters. In panels (b-d), we show 3 samples from the posterior. We see that the Bayesian method strongly prefers just 2 mixture components, although there is a small amount of support for some other Gaussian components (shown by the faint ellipses).


10.3.5.4 More complex posteriors

We can combine ADVI with any of the improved posterior approximations that we discuss in Section 10.4, such as Gaussian mixtures [Mor+21b] or normalizing flows [ASD20].


10.3.6 Amortized inference 


Suppose we want to perform parameter estimation in a model with local latent variables:

$$\hat{\theta} = \arg\max_{\theta} \sum_{n=1}^N \log \int p_\theta(x_n, z_n)\, dz_n \tag{10.147}$$


To compute the marginal likelihood, we first need to compute the posteriors $p_\theta(z_n|x_n)$ for each example n; the normalization constant then gives us $\log p_\theta(x_n)$. When using BBVI or RVI for this, we therefore have to solve an optimization problem for each $q(z_n|\psi_n)$, which can be slow.

An alternative approach is to train a model, known as an inference network or recognition network, to predict $\psi_n$ from the observed data, $x_n$, using $\psi_n = f_\phi^{\text{inf}}(x_n)$. This technique is known as amortized inference [GG14], or inference compilation [LBW17], since we are reducing the per-example cost of inference by training a model that is shared across all examples. (See also [Amo22] for a general discussion of amortized optimization.) For brevity, we will write


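$$q(z_n|\psi_n) = q(z_n|f_\phi^{\text{inf}}(x_n)) = q_\phi(z_n|x_n) \tag{10.148}$$

The “amortized ELBO”, for a model with local latents and fixed global parameters, becomes

$$\mathcal{L}(\phi, \theta|\mathcal{D}) = \frac{1}{N}\sum_{n=1}^N \left[\mathbb{E}_{q_\phi(z_n|x_n)}\left[\log p_\theta(x_n, z_n) - \log q_\phi(z_n|x_n)\right]\right] \tag{10.149}$$

We can approximate this by sampling a single data point $x_n \sim p_\mathcal{D}$, and then sampling a single latent $z_n \sim q_\phi(z_n|x_n)$, as in doubly stochastic VI (DSVI), to get

$$\mathcal{L}(\phi, \theta|x_n, z_n) = \log p_\theta(x_n, z_n) - \log q_\phi(z_n|x_n) \tag{10.150}$$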


(We call this the “per-sample ELBO”, although [Ble17] call it the instantaneous ELBO.) If the 
posteriors are reparameterizable, we can push gradients inside and then apply SGD. See Algorithm 24 
for the resulting pseudocode. 


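Algorithm 24: Amortized SVI
1 Initialize $\theta$, $\phi$
2 repeat
3     Sample $x_n \sim p_\mathcal{D}$
4     Sample $z_n \sim q_\phi(z|x_n)$
5     $\theta := \theta + \eta \nabla_\theta \mathcal{L}(\phi, \theta|x_n, z_n)$
6     $\phi := \phi + \eta \nabla_\phi \mathcal{L}(\phi, \theta|x_n, z_n)$
7     Update learning rate $\eta$
8 until converged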
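A small end-to-end sketch of Algorithm 24 in NumPy, under simplifying assumptions of our own: the generative model is fixed (only $\phi$ is updated, i.e. line 5 is skipped), it is the hypothetical linear Gaussian model $z \sim \mathcal{N}(0, 1)$, $x|z \sim \mathcal{N}(z, 1)$, whose exact posterior is $\mathcal{N}(x/2, 1/2)$, and the inference network is linear, $q_\phi(z|x) = \mathcal{N}(ax + b, e^{2c})$. Reparameterized gradients of the per-sample ELBO (10.150) are derived by hand and averaged over a minibatch, with a constant learning rate in place of line 7.

```python
import numpy as np


def amortized_svi(num_steps=5000, batch_size=64, lr=0.01, seed=0):
    """Amortized SVI for z ~ N(0, 1), x | z ~ N(z, 1) with q_phi(z | x) = N(a x + b, exp(c)^2)."""
    rng = np.random.default_rng(seed)
    a, b, c = 0.0, 0.0, 0.0
    for _ in range(num_steps):
        x = rng.standard_normal(batch_size) + rng.standard_normal(batch_size)  # x_n ~ p_D
        eps = rng.standard_normal(batch_size)
        z = a * x + b + np.exp(c) * eps                 # z_n ~ q_phi(z | x_n), reparameterized
        dlogp_dz = x - 2.0 * z                          # d/dz [log N(z|0,1) + log N(x|z,1)]
        a += lr * np.mean(dlogp_dz * x)                 # dz/da = x
        b += lr * np.mean(dlogp_dz)                     # dz/db = 1
        c += lr * np.mean(dlogp_dz * np.exp(c) * eps + 1.0)   # +1 from -log q
    return a, b, np.exp(c)


a, b, s = amortized_svi()
print("q(z|x) = N(%.3f x + %.3f, %.3f^2); exact: N(0.5 x, %.3f^2)" % (a, b, s, np.sqrt(0.5)))
```

Because this inference network family contains the exact posterior for every x, there is no amortization gap here; with a less expressive network, the learned $\psi_n = f_\phi^{\text{inf}}(x_n)$ would generally be suboptimal for individual examples.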


This method is very widely used for fitting LVMs, e.g., for VAEs (see Section 21.2), for topic 
models [SS17a], for probabilistic programming [RHG16], for CRFs [TG18], etc. However, the use 
of an inference network can result in a suboptimal setting of the local variational parameters $\psi_n$. 
This is called the amortization gap [CLD18]. We can close this gap by using the inference network 
to warm-start an optimizer for $\psi_n$; this is known as semi-amortized VI [Kim+18c]. The key 
insight is that the local SVI procedure is itself differentiable, so the inference network and generative 
model can be trained end-to-end. (See also [MYM18], who propose a closely related method called 
iterative amortized inference.) 

An alternative approach is to use the inference network as a proposal distribution. If we combine 
this with importance sampling, we get the IWAE bound of Section 10.5.1. If we use this with 
Metropolis Hastings, we get a VI-MCMC hybrid (see Section 10.4.5). 


10.4 More accurate variational posteriors 


In general, we can improve the tightness of the ELBO lower bound, and hence reduce the KL 
divergence of our posterior approximation, if we use more flexible posterior families (although 
optimizing within more flexible families may be slower, and can incur statistical error if the sample 
size is low [Bha+21]). In this section, we give several examples of more accurate variational posteriors, 
going beyond fully factored mean field approximations. 


10.4.1 Structured mean field 


The mean field assumption is quite strong, and can sometimes give poor results. Fortunately, sometimes we can exploit tractable substructure in our problem, so that we can efficiently handle some kinds of dependencies between the variables in the posterior in an analytic way, rather than assuming they are all independent. This is called the structured mean field approach [SJ95].

A common example arises when applying VI to time series models, such as HMMs, where the latent variables within each sequence are usually highly correlated across time. Rather than assuming a fully factorized posterior, we can treat each sequence $z_{n,1:T}$ as a block, and just assume independence between blocks and the parameters: $q(z_{1:N,1:T}, \theta) = q(\theta)\prod_{n=1}^N q(z_{n,1:T})$, where $q(z_{n,1:T}) = \prod_t q(z_{n,t}|z_{n,t-1})$. We can compute the joint distribution $q(z_{n,1:T})$, taking into account the dependence between time steps, using the forwards-backwards algorithm. For details, see [JW14; Fot+14]. A similar approach was applied to the factorial HMM model, as we discuss in Supplementary Section 10.3.2.

An automatic way to derive a structured variational approximation to a probabilistic model, specified by a probabilistic programming language, is discussed in [AHG20].
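In the structured mean field setting, each block $q(z_{n,1:T})$ is a chain, and forwards-backwards gives its single-site marginals in $O(TK^2)$ time rather than $O(K^T)$. The sketch below is a generic, hypothetical NumPy helper: we assume the block is described by nonnegative initial, transition and per-step potentials (in structured mean field these would be built from expected log parameters, which we do not derive here). A brute-force enumerator is included only to check the result on a tiny example.

```python
import itertools

import numpy as np


def chain_marginals(init, trans, local):
    """Marginals of q(z_{1:T}) propto init(z_1) prod_t trans(z_{t-1}, z_t) prod_t local_t(z_t).

    init: (K,), trans: (K, K), local: (T, K); all entries nonnegative.
    Forward (alpha) and backward (beta) messages are normalized at each step to
    avoid underflow; the normalizers cancel in the final marginals.
    """
    T, K = local.shape
    alpha = np.zeros((T, K))
    beta = np.ones((T, K))
    alpha[0] = init * local[0]
    alpha[0] /= alpha[0].sum()
    for t in range(1, T):
        alpha[t] = (alpha[t - 1] @ trans) * local[t]
        alpha[t] /= alpha[t].sum()
    for t in range(T - 2, -1, -1):
        beta[t] = trans @ (local[t + 1] * beta[t + 1])
        beta[t] /= beta[t].sum()
    marg = alpha * beta
    return marg / marg.sum(axis=1, keepdims=True)


def brute_force_marginals(init, trans, local):
    """Same marginals by enumerating all K^T sequences (only for tiny problems)."""
    T, K = local.shape
    marg = np.zeros((T, K))
    for seq in itertools.product(range(K), repeat=T):
        w = init[seq[0]] * local[0, seq[0]]
        for t in range(1, T):
            w *= trans[seq[t - 1], seq[t]] * local[t, seq[t]]
        for t in range(T):
            marg[t, seq[t]] += w
    return marg / marg.sum(axis=1, keepdims=True)


rng = np.random.default_rng(0)
init, trans, local = rng.random(3), rng.random((3, 3)), rng.random((5, 3))
print(np.allclose(chain_marginals(init, trans, local), brute_force_marginals(init, trans, local)))
```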


10.4.2 Hierarchical (auxiliary variable) posteriors 


Suppose $q_\phi(z|x) = \prod_k q_\phi(z_k|x)$ is a factorized distribution, such as a diagonal Gaussian. This does not capture dependencies between the latent variables (components of z). We could of course use a full covariance matrix, but this might be too expensive.

An alternative approach is to use a hierarchical model, in which we add auxiliary latent variables a, which are used to increase the flexibility of the variational posterior. In particular, we can still assume $q_\phi(z|x, a)$ is conditionally factorized, but when we marginalize out a, we induce dependencies between the elements of z, i.e.,

$$q_\phi(z|x) = \int q_\phi(z|x, a)\, q_\phi(a|x)\, da \neq \prod_k q_\phi(z_k|x) \tag{10.151}$$


This is called a hierarchical variational model [Ran16], or an auxiliary variable deep generative model [Maa+16].

In [TRB16], they model $q_\phi(z|x, a)$ as a Gaussian process, which is a flexible nonparametric distribution (see Chapter 18), where a are the inducing points. This combination is called a variational GP.
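Equation (10.151) can be illustrated with a deliberately simple hierarchical q of our own choosing: a scalar auxiliary variable $a \sim \mathcal{N}(0, 1)$ and $z_k|a \sim \mathcal{N}(a, \tau^2)$ independently for each k. Each conditional is fully factorized, yet the marginal has covariance $\mathbf{1}\mathbf{1}^\top + \tau^2 I$, so every pair of components has correlation $1/(1 + \tau^2)$. A NumPy sketch:

```python
import numpy as np


def sample_hierarchical(num_samples, dim, tau=0.3, seed=0):
    """Samples from q(z) = integral of q(z | a) q(a) da, a ~ N(0, 1), z_k | a ~ N(a, tau^2).

    Each conditional q(z | a) is fully factorized, but the shared auxiliary a
    induces correlation: Cov[z] = 1 1^T + tau^2 I.
    """
    rng = np.random.default_rng(seed)
    a = rng.standard_normal((num_samples, 1))                  # a ~ q(a)
    return a + tau * rng.standard_normal((num_samples, dim))   # z | a, independent dims


z = sample_hierarchical(100_000, dim=3)
print(np.corrcoef(z, rowvar=False).round(3))   # off-diagonals should be near 1 / (1 + 0.3^2)
```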


10.4.3 Normalizing flow posteriors 


Normalizing flows are a class of probability models which work by passing a simple source distribution, 
such as a diagonal Gaussian, through a series of nonlinear, but invertible, mappings f to create a 
more complex distribution. This can be used to get more accurate posterior approximations than 
standard Gaussian VI, as we discuss in Section 23.1.2.2. 


10.4.4 Implicit posteriors 


In Chapter 26, we discuss implicit probability distributions, which are models which we can sample 
from, but which we cannot evaluate pointwise. For example, consider passing a Gaussian noise term, 
$z_0 \sim \mathcal{N}(0, I)$, through a nonlinear, non-invertible mapping f to create $z = f(z_0)$; it is easy to sample from q(z), but it is intractable to evaluate the density q(z) (unlike with flows). This makes it hard to evaluate the log density ratio $\log p_\theta(z)/q_\psi(z|x)$, which is needed to compute the ELBO. However, we can use the same method as is used in GANs (generative adversarial networks, Chapter 26), in which we train a classifier that discriminates prior samples from samples from the variational posterior by evaluating $T(x, z) = \log q_\psi(z|x) - \log p_\theta(z)$. See e.g., [TR19] for details.
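The classifier trick rests on the fact that, with equal numbers of samples from each distribution, the optimal logistic-regression logit equals the log density ratio. This NumPy sketch demonstrates that on a hypothetical 1d case, with samples from $\mathcal{N}(1, 1)$ standing in for draws from an implicit q and samples from $\mathcal{N}(0, 1)$ for the prior p. The exact ratio is then $\log q(z) - \log p(z) = z - 1/2$. The linear feature map is our choice, exact only for this equal-variance case, and the conditioning on x is omitted.

```python
import numpy as np


def fit_log_ratio(z_q, z_p, num_iters=2000, lr=0.5):
    """Estimate T(z) = log q(z) - log p(z) with logistic regression (density ratio trick).

    Samples from q get label 1 and samples from p get label 0. Returns weights w so
    that T(z) ~= w[0] + w[1] * z.
    """
    z = np.concatenate([z_q, z_p])
    y = np.concatenate([np.ones_like(z_q), np.zeros_like(z_p)])
    X = np.stack([np.ones_like(z), z], axis=1)
    w = np.zeros(2)
    for _ in range(num_iters):
        prob = 1.0 / (1.0 + np.exp(-X @ w))
        w += lr * X.T @ (y - prob) / len(y)      # gradient ascent on the log likelihood
    return w


rng = np.random.default_rng(0)
z_q = rng.normal(1.0, 1.0, size=20_000)   # stand-in for samples from an implicit q
z_p = rng.normal(0.0, 1.0, size=20_000)   # samples from the prior p
w = fit_log_ratio(z_q, z_p)
print(w)   # the exact log ratio here is z - 0.5, i.e. w = [-0.5, 1.0]
```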


10.4.5 Combining VI with MCMC inference 


There are various ways to combine variational inference with MCMC to get an improved approximate 
posterior. In [SKW15], they propose Hamiltonian Variational Inference, in which they train an 
inference network to initialize an HMC sampler (Section 12.5). The gradient of the log posterior (wrt 
the latents), which is needed by HMC, is given by 


$$
\nabla_z \log p_\theta(z|x) = \nabla_z \left[\log p_\theta(x, z) - \log p_\theta(x)\right] = \nabla_z \log p_\theta(x, z) \tag{10.152}
$$


This is easy to compute, since $\log p_\theta(x)$ does not depend on $z$. They use the final sample to approximate the posterior $q_\phi(z|x)$. To compute the entropy of this distribution, they also learn an auxiliary inverse inference network to reverse the HMC Markov chain.
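The point of Equation (10.152) is that HMC only needs the gradient of the log joint. The sketch below runs a basic HMC sampler (leapfrog integrator plus Metropolis correction) on an assumed conjugate toy model, $p(z) = \mathcal{N}(0, 1)$ and $p(x|z) = \mathcal{N}(z, 1)$ with $x = 1.5$, where the exact posterior is $\mathcal{N}(0.75, 0.5)$. It is not the HVI method of [SKW15]: there is no inference network and no learned reverse model, and the chain simply starts at a fixed point where HVI would start it at a sample from $q_\phi(z|x)$. The step size and number of leapfrog steps are arbitrary choices for this toy problem.

```python
import math
import random

X_OBS = 1.5  # hypothetical observed value of x

def log_joint(z, x=X_OBS):
    """log p(x, z) up to a constant, for p(z) = N(0, 1) and p(x|z) = N(z, 1)."""
    return -0.5 * z ** 2 - 0.5 * (x - z) ** 2

def grad_log_joint(z, x=X_OBS):
    """grad_z log p(x, z), which equals grad_z log p(z|x)."""
    return -z + (x - z)

def hmc(z_init, n_samples, step=0.15, n_leapfrog=10, seed=0):
    rng = random.Random(seed)
    z, samples = z_init, []
    for _ in range(n_samples):
        v = rng.gauss(0.0, 1.0)                          # fresh momentum
        z_new = z
        v_new = v + 0.5 * step * grad_log_joint(z_new)   # leapfrog: initial half step
        for i in range(n_leapfrog):
            z_new += step * v_new                        # full position step
            if i < n_leapfrog - 1:
                v_new += step * grad_log_joint(z_new)    # full momentum step
        v_new += 0.5 * step * grad_log_joint(z_new)      # final half step
        # Metropolis correction using H(z, v) = -log p(x, z) + v^2 / 2
        log_accept = (log_joint(z_new) - 0.5 * v_new ** 2) - (log_joint(z) - 0.5 * v ** 2)
        if math.log(1.0 - rng.random()) < log_accept:    # 1 - u lies in (0, 1]
            z = z_new
        samples.append(z)
    return samples

samples = hmc(z_init=0.0, n_samples=5000)
mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
print(round(mean, 2), round(var, 2))  # exact posterior: N(0.75, 0.5)
```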

A simpler approach is proposed in [Hof17]. Here they train an inference network to initialize an HMC sampler, using the standard ELBO for $\phi$, but they optimize the generative parameters $\theta$ using a stochastic approximation to the log marginal likelihood, given by $\log p_\theta(z, x)$ where $z$ is a sample from the HMC chain. This does not require learning a reverse inference network, and avoids problems with variational pruning, since it does not use the ELBO for training the generative model.
10.5 Tighter bounds


Another way to improve the quality of the posterior approximation is to optimize q wrt a bound that 
is a tighter approximation to the log marginal likelihood compared to the standard ELBO. We give 
some examples below. 


10.5.1 Multi-sample ELBO (IWAE bound) 


In this section, we discuss a method known as the importance weighted autoencoder or IWAE 
[BGS16], which is a way to tighten the variational lower bound by using self-normalized importance 
sampling (Section 11.5.2). (It can also be interpreted as standard ELBO maximization in an expanded 
model, where we add extra auxiliary variables [CMD17; DS18; Tuc+19].) 

Let the inference network $q_\phi(z|x)$ be viewed as a proposal distribution for the target posterior $p_\theta(z|x)$. Define $w_s = \frac{p_\theta(x, z_s)}{q_\phi(z_s|x)}$ as the unnormalized importance weight for a sample, and $\hat{w}_s = w_s / \left(\sum_{s'=1}^S w_{s'}\right)$ as the normalized importance weights. From Equation (11.43) we can compute an estimate of the marginal likelihood $p(x)$ using


$$
\hat{p}_S(x|z_{1:S}) \triangleq \frac{1}{S}\sum_{s=1}^S \frac{p_\theta(x, z_s)}{q_\phi(z_s|x)} = \frac{1}{S}\sum_{s=1}^S w_s \tag{10.153}
$$


This is unbiased, i.e., $\mathbb{E}_{q_\phi(z_{1:S}|x)}\left[\hat{p}_S(x|z_{1:S})\right] = p(x)$, where $q_\phi(z_{1:S}|x) = \prod_{s=1}^S q_\phi(z_s|x)$. In addition, since the estimator is always positive, we can take logarithms, and thus obtain a stochastic lower bound on the log likelihood:


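$$
\begin{align}
\mathcal{L}_S(\phi, \theta|x) &\triangleq \mathbb{E}_{q_\phi(z_{1:S}|x)}\left[\log \frac{1}{S}\sum_{s=1}^S w_s\right] = \mathbb{E}_{q_\phi(z_{1:S}|x)}\left[\log \hat{p}_S(x|z_{1:S})\right] \tag{10.154}\\
&\le \log \mathbb{E}_{q_\phi(z_{1:S}|x)}\left[\hat{p}_S(x|z_{1:S})\right] = \log p(x) \tag{10.155}
\end{align}
$$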


where we used Jensen’s inequality for the inequality, and the unbiasedness property for the final equality. This is called the multi-sample ELBO or IWAE bound [BGS16]. If $S = 1$, $\mathcal{L}_S$ reduces to the standard ELBO:


$$
\mathcal{L}_1(\phi, \theta|x) = \mathbb{E}_{q_\phi(z|x)}[\log w] = \int q_\phi(z|x) \log \frac{p_\theta(x, z)}{q_\phi(z|x)}\, dz \tag{10.156}
$$


One can show [BGS16] that increasing the number of samples $S$ can never make the bound looser, i.e., $\mathcal{L}_S \le \mathcal{L}_{S+1} \le \log p(x)$, thus making it a better proxy for the log likelihood. Intuitively, averaging the $S$ samples inside the log removes the need for every sample $z_s$ to explain the data $x$. This encourages the proposal distribution $q$ to be less concentrated than the single-sample variational posterior.
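The tightening with $S$ can be seen numerically. The sketch below estimates $\mathcal{L}_S$ by Monte Carlo for an assumed conjugate toy model, $p(z) = \mathcal{N}(0, 1)$ and $p(x|z) = \mathcal{N}(z, 1)$, so that $p(x) = \mathcal{N}(x|0, 2)$ is available in closed form, using the prior as a deliberately poor proposal $q(z|x)$. The log of the average weight is computed with the log-mean-exp trick for numerical stability.

```python
import math
import random

def log_normal(x, mean, var):
    return -0.5 * math.log(2 * math.pi * var) - 0.5 * (x - mean) ** 2 / var

def iwae_bound(x, S, n_rep=2000, seed=0):
    """Monte Carlo estimate of L_S, averaging log((1/S) sum_s w_s) over n_rep replicates."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_rep):
        log_w = []
        for _ in range(S):
            z = rng.gauss(0.0, 1.0)  # z_s ~ q(z|x), here the prior N(0, 1)
            # log w_s = log p(z_s) + log p(x|z_s) - log q(z_s|x)
            log_w.append(log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)
                         - log_normal(z, 0.0, 1.0))
        top = max(log_w)  # log-mean-exp
        total += top + math.log(sum(math.exp(lw - top) for lw in log_w) / S)
    return total / n_rep

x = 1.5
log_px = log_normal(x, 0.0, 2.0)  # exact log marginal likelihood
bounds = {S: iwae_bound(x, S) for S in (1, 10, 100)}
print({S: round(b, 3) for S, b in bounds.items()}, round(log_px, 3))
```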


10.5.1.1 Pathologies of optimizing the IWAE bound




Unfortunately, increasing the number of samples in the IWAE bound can decrease the signal-to-noise ratio of the gradient estimates, resulting in learning a worse model [Rai+18a]. Intuitively, the reason this happens is that increasing $S$ reduces the dependence of the bound on the quality of the inference network, which makes the gradient of the ELBO wrt $\phi$ less informative (its expected value shrinks faster than its standard deviation).


One solution to this is to use the doubly reparameterized gradient estimator [TL18b]. 
Another approach is to use alternative estimation methods that avoid ELBO maximization, such 
as using the thermodynamic variational objective (see Section 10.5.2) or the reweighted wake sleep 
algorithm (see Section 10.6). 


10.5.2 The thermodynamic variational objective (TVO) 


In [MLW19; Bre+20b], they present the thermodynamic variational objective or TVO. This
is an alternative to IWAE for creating tighter variational bounds, which has certain advantages, 
particularly for posteriors that are not reparameterizable (e.g., discrete latent variables). The 
framework also has close connections with the reweighted wake sleep algorithm from Section 10.6, as 
we will see in Section 10.5.3. 

The TVO technique uses thermodynamic integration (TI), also called path sampling, which is a technique used in physics and phylogenetics to approximate intractable normalization constants of high dimensional distributions (see e.g., [GM98; LP06; FP08]). This is based on the insight that it is easier to calculate the ratio of two unknown constants than to calculate the constants themselves. This is similar to the idea behind annealed importance sampling (Section 11.5.4), but TI is deterministic. For details, see [MLW19; Bre+20b].
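For intuition about thermodynamic integration, consider the path $\pi_\beta(z) \propto q_\phi(z|x)\, w(z)^\beta$ with $w(z) = p_\theta(x, z)/q_\phi(z|x)$, which moves from $q_\phi$ at $\beta = 0$ to the posterior at $\beta = 1$. Its normalizer $Z_\beta = \mathbb{E}_{q_\phi}[w^\beta]$ satisfies $Z_0 = 1$, $Z_1 = p_\theta(x)$, and $\frac{d}{d\beta}\log Z_\beta = \mathbb{E}_{\pi_\beta}[\log w]$, so

$$
\log p_\theta(x) = \int_0^1 \mathbb{E}_{\pi_\beta}[\log w(z)]\, d\beta
$$

The integrand is nondecreasing in $\beta$ (its derivative is $\mathbb{V}_{\pi_\beta}[\log w]$), so a left Riemann sum is a lower bound, and the single-term sum at $\beta = 0$ is the ELBO. The sketch below checks this on an assumed conjugate toy model in which every $\pi_\beta$ is Gaussian, so the integrand has a closed form; in practice the expectations must be estimated (e.g., by importance sampling) and the partition points chosen with care.

```python
import math

# Assumed toy model: p(z) = N(0, 1), p(x|z) = N(z, 1), and q(z|x) = p(z).
# Then log w(z) = log N(x | z, 1), and pi_beta(z) ∝ q(z) w(z)^beta is Gaussian
# with precision 1 + beta and mean beta * x / (1 + beta).

def expected_log_w(beta, x):
    """E_{pi_beta}[log w(z)] in closed form for the toy model."""
    m = beta * x / (1.0 + beta)
    v = 1.0 / (1.0 + beta)
    return -0.5 * math.log(2 * math.pi) - 0.5 * ((x - m) ** 2 + v)

def thermo_integral(x, n=1000):
    """Simpson's rule for int_0^1 E_{pi_beta}[log w] d beta (n must be even)."""
    h = 1.0 / n
    s = expected_log_w(0.0, x) + expected_log_w(1.0, x)
    for k in range(1, n):
        s += (4 if k % 2 else 2) * expected_log_w(k * h, x)
    return s * h / 3

def tvo_left_riemann(x, K):
    """Left Riemann sum with K equal partitions: a lower bound on log p(x)."""
    return sum(expected_log_w(k / K, x) for k in range(K)) / K

x = 1.5
log_px = -0.5 * math.log(2 * math.pi * 2.0) - x ** 2 / 4.0  # exact: log N(x | 0, 2)
print(round(thermo_integral(x), 6), round(log_px, 6))
print([round(tvo_left_riemann(x, K), 3) for K in (1, 2, 4, 20)])
```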


10.5.3 Minimizing the evidence upper bound 


Recall that the evidence lower bound or ELBO is given by 


$$
\mathcal{L}(\theta, \phi|x) = \log p_\theta(x) - D_{\mathbb{KL}}\left(q_\phi(z|x) \,\|\, p_\theta(z|x)\right) \le \log p_\theta(x) \tag{10.157}
$$


By analogy, we can define the evidence upper bound or EUBO as follows: 


$$
\text{EUBO}(\theta, \phi|x) = \log p_\theta(x) + D_{\mathbb{KL}}\left(p_\theta(z|x) \,\|\, q_\phi(z|x)\right) \ge \log p_\theta(x) \tag{10.158}
$$


Minimizing this wrt the variational parameters $\phi$, as an alternative to maximizing the ELBO, was proposed in [MLW19], where they showed that it can sometimes converge to the true $\log p_\theta(x)$ faster.

The above bound is for a specific input $x$. If we sample $x$ from the generative model, and minimize $\mathbb{E}_{p_\theta(x)}\left[\text{EUBO}(\theta, \phi|x)\right]$ wrt $\phi$, we recover the sleep phase of the wake-sleep algorithm (see Section 10.6.2).

Now suppose we sample $x$ from the empirical distribution, and minimize $\mathbb{E}_{p_\mathcal{D}(x)}\left[\text{EUBO}(\theta, \phi|x)\right]$ wrt $\phi$. To approximate the expectation, we can use self-normalized importance sampling, as in Equation (10.173), to get


$$
\nabla_\phi \text{EUBO}(\theta, \phi|x) \approx -\sum_{s=1}^S \hat{w}_s \nabla_\phi \log q_\phi(z^s|x) \tag{10.159}
$$


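where $\hat{w}_s = w_s / \left(\sum_{s'=1}^S w_{s'}\right)$, $w_s = \frac{p_\theta(x, z^s)}{q_\phi(z^s|x)}$, and $z^s \sim q_\phi(z|x)$. Descending this gradient is equivalent to the “daydream” update (aka “wake-phase $\phi$ update”) of the wake-sleep algorithm (see Section 10.6.3).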
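The sandwich $\mathcal{L} \le \log p_\theta(x) \le \text{EUBO}$ is easy to verify in closed form when everything is Gaussian. The sketch assumes the conjugate toy model $p(z) = \mathcal{N}(0, 1)$, $p(x|z) = \mathcal{N}(z, 1)$ (so $p(z|x) = \mathcal{N}(x/2, 1/2)$) and an arbitrary, hypothetical Gaussian $q$, and evaluates Equations (10.157) and (10.158) using the closed-form KL divergence between univariate Gaussians.

```python
import math

def kl_gauss(m1, v1, m2, v2):
    """D_KL(N(m1, v1) || N(m2, v2)) for univariate Gaussians (v = variance)."""
    return 0.5 * (math.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1.0)

x = 1.5
log_px = -0.5 * math.log(2 * math.pi * 2.0) - x ** 2 / 4.0  # log N(x | 0, 2)
post_m, post_v = x / 2.0, 0.5                                # exact posterior p(z|x)

q_m, q_v = 0.3, 1.2                                          # arbitrary variational q
elbo = log_px - kl_gauss(q_m, q_v, post_m, post_v)           # Equation (10.157)
eubo = log_px + kl_gauss(post_m, post_v, q_m, q_v)           # Equation (10.158)
print(round(elbo, 3), round(log_px, 3), round(eubo, 3))
```

When $q$ equals the exact posterior, both KL terms vanish and the two bounds meet at $\log p_\theta(x)$.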


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 




10.6 Wake-sleep algorithm 


So far in this chapter we have focused on fitting latent variable models by maximizing the ELBO. 
This has two main drawbacks. First, it does not work well when we have discrete latent variables, 
because in such cases we cannot use the reparameterization trick; thus we have to use higher variance
estimators, such as REINFORCE (see Section 10.3.2). Second, even in the case where we can use
the reparameterization trick, the lower bound may not be very tight. We can improve the tightness
by using the IWAE multi-sample bound (Section 10.5.1), but paradoxically this may not result in 
learning a better model, for reasons discussed in Section 10.5.1.1. 

In this section, we discuss a different way to jointly train generative and inference models, which 
avoids some of the problems with ELBO maximization. The method is known as the wake-sleep algorithm [Hin+95; BB15b; Le+19; FT19], so called because it alternates between two steps: in the wake phase, we optimize the generative model parameters $\theta$ to maximize the marginal likelihood of the observed data (we approximate $\log p_\theta(x)$ by drawing importance samples from the inference network); and in the sleep phase, we optimize the inference model parameters $\phi$ to learn to invert the generative model by training the inference network on labeled $(x, z)$ pairs, where $x$ are samples generated by the current model parameters. This can be viewed as a form of adaptive importance sampling,
which iteratively improves its proposal, while simultaneously optimizing the model. We give further 
details below. 


10.6.1 Wake phase 


In the wake phase, we minimize the KL divergence from the empirical distribution to the model’s 
distribution: 


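$$
\mathcal{L}(\theta) = D_{\mathbb{KL}}\left(p_\mathcal{D}(x) \,\|\, p_\theta(x)\right) = \mathbb{E}_{p_\mathcal{D}(x)}\left[-\log p_\theta(x)\right] + \text{const} \tag{10.160}
$$

where $p_\theta(x) = \int p_\theta(z) p_\theta(x|z)\, dz$. This is equivalent to maximizing the likelihood of the observed data:

$$
\ell(\theta) = \mathbb{E}_{p_\mathcal{D}(x)}\left[\log p_\theta(x)\right] \tag{10.161}
$$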


Since the log marginal likelihood $\log p_\theta(x)$ cannot be computed exactly, we will approximate it.


In the original wake-sleep paper, they proposed to use the ELBO lower bound. In the reweighted wake-sleep (RWS) algorithm of [BB15b; Le+19], they propose to use the IWAE bound from Section 10.5.1 instead. In particular, if we draw $S$ samples from the inference network, $z_s \sim q_\phi(z|x)$, we get the following estimator:


$$
\ell(\theta|\phi, x) = \log\left(\frac{1}{S}\sum_{s=1}^S w_s\right) \tag{10.162}
$$

where $w_s = \frac{p_\theta(x, z_s)}{q_\phi(z_s|x)}$.
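We now discuss how to compute the gradient of this objective. Using the log-derivative trick, and the fact that $q_\phi$ does not depend on $\theta$, we have that

$$
\nabla_\theta \log w_s = \frac{1}{w_s}\nabla_\theta w_s = \nabla_\theta \log p_\theta(x, z_s) \tag{10.163}
$$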


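Hence

$$
\begin{align}
\nabla_\theta \ell(\theta|\phi, x) &= \frac{1}{\sum_{s=1}^S w_s} \sum_{s=1}^S \nabla_\theta w_s = \frac{1}{\sum_{s=1}^S w_s} \sum_{s=1}^S w_s \nabla_\theta \log w_s \tag{10.164}\\
&= \frac{1}{\sum_{s=1}^S w_s} \sum_{s=1}^S \frac{p_\theta(z_s, x)}{q_\phi(z_s|x)} \nabla_\theta \log p_\theta(z_s, x) \tag{10.165}\\
&= \sum_{s=1}^S \hat{w}_s \nabla_\theta \log p_\theta(z_s, x) \tag{10.166}
\end{align}
$$

where $\hat{w}_s = w_s / \left(\sum_{s'=1}^S w_{s'}\right)$.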


10.6.2 Sleep phase 


In the sleep phase, we try to minimize the KL divergence between the true posterior (under the 
current model) and the inference network’s approximation to that posterior: 


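$$
\mathcal{L}(\phi) = \mathbb{E}_{p_\theta(x)}\left[D_{\mathbb{KL}}\left(p_\theta(z|x) \,\|\, q_\phi(z|x)\right)\right] = \mathbb{E}_{p_\theta(z, x)}\left[-\log q_\phi(z|x)\right] + \text{const} \tag{10.167}
$$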


Equivalently, we can maximize the following log-likelihood objective:


$$
\ell(\phi|\theta) = \mathbb{E}_{(z, x) \sim p_\theta(z, x)}\left[\log q_\phi(z|x)\right] \tag{10.168}
$$


where $p_\theta(z, x) = p_\theta(z) p_\theta(x|z)$. We see that the sleep phase amounts to maximum likelihood training of the inference network based on samples from the generative model. These “fantasy samples”, created while the network “dreams”, can be easily generated using ancestral sampling (Section 4.2.5). If we use $S$ such samples, the objective becomes


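$$
\ell(\phi|\theta) \approx \frac{1}{S}\sum_{s=1}^S \log q_\phi(z_s|x_s) \tag{10.169}
$$

where $(z_s, x_s) \sim p_\theta(z, x)$. The gradient of this is given by

$$
\nabla_\phi \ell(\phi|\theta) \approx \frac{1}{S}\sum_{s=1}^S \nabla_\phi \log q_\phi(z_s|x_s) \tag{10.170}
$$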


We do not require $q_\phi(z|x)$ to be reparameterizable, since the samples are drawn from a distribution that is independent of $\phi$. This means it is easy to apply this method to models with discrete latent variables.
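To illustrate, the sketch below runs the sleep phase for an assumed linear-Gaussian model, $p(z) = \mathcal{N}(0, 1)$ and $p(x|z) = \mathcal{N}(\theta z, \sigma^2)$, with a hypothetical inference network $q_\phi(z|x) = \mathcal{N}(a x + b, s^2)$. For this particular choice, maximizing Equation (10.169) over $(a, b, s^2)$ has a closed-form solution (least-squares regression of the fantasy $z_s$ on the fantasy $x_s$, plus the mean squared residual), so no gradient steps are needed. With $\theta = 2$ and $\sigma = 1$ the exact posterior is $\mathcal{N}(0.4x, 0.2)$.

```python
import random

def sleep_phase_linear_gaussian(theta=2.0, sigma=1.0, n=20_000, seed=0):
    """Fit q(z|x) = N(a*x + b, s2) by maximum likelihood on fantasy pairs (z_s, x_s)."""
    rng = random.Random(seed)
    zs, xs = [], []
    for _ in range(n):
        z = rng.gauss(0.0, 1.0)                    # ancestral sampling: z ~ p(z)
        xs.append(rng.gauss(theta * z, sigma))     # then x ~ p(x|z)
        zs.append(z)
    mx, mz = sum(xs) / n, sum(zs) / n
    cov = sum((x - mx) * (z - mz) for x, z in zip(xs, zs)) / n
    var_x = sum((x - mx) ** 2 for x in xs) / n
    a = cov / var_x                                # MLE of the slope
    b = mz - a * mx                                # MLE of the intercept
    s2 = sum((z - a * x - b) ** 2 for x, z in zip(xs, zs)) / n  # MLE of the variance
    return a, b, s2

a, b, s2 = sleep_phase_linear_gaussian()
# Exact posterior: N(theta*x/(theta^2 + sigma^2), sigma^2/(theta^2 + sigma^2))
print(round(a, 3), round(b, 3), round(s2, 3))
```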


10.6.3 Daydream phase 


The disadvantage of the sleep phase is that the inference network, $q_\phi(z|x)$, is trying to follow a moving target, $p_\theta(z|x)$. Furthermore, it is only being trained on synthetic data from the model, not on real data. The reweighted wake-sleep algorithm of [BB15b] proposed to learn the inference network by using real data from the empirical distribution, in addition to fantasy data. They call the case where you use real data the “wake-phase q update”, but we will call it the “daydream phase”, since, unlike sleeping, the system uses real data $x$ to update the inference model, instead of fantasies.⁴ [Le+19] went further, and proposed to only use the wake and daydream phases, and to skip the sleep phase entirely.

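In more detail, the new objective which we want to minimize becomes

$$
\mathcal{L}(\phi|\theta) = \mathbb{E}_{p_\mathcal{D}(x)}\left[D_{\mathbb{KL}}\left(p_\theta(z|x) \,\|\, q_\phi(z|x)\right)\right] \tag{10.171}
$$

Up to an additive constant that does not depend on $\phi$, the negative of the above expression can be approximated using a single data sample as follows:

$$
\ell(\phi|\theta, x) = \mathbb{E}_{p_\theta(z|x)}\left[\log q_\phi(z|x)\right] \tag{10.172}
$$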


where $x \sim p_\mathcal{D}$. We can approximate this expectation using importance sampling, with $q_\phi$ as the proposal. This results in the following estimator of the gradient for each datapoint:


$$
\nabla_\phi \ell(\phi|\theta, x) = \int p_\theta(z|x) \nabla_\phi \log q_\phi(z|x)\, dz \approx \sum_{s=1}^S \hat{w}_s \nabla_\phi \log q_\phi(z_s|x) \tag{10.173}
$$


where $z_s \sim q_\phi(z|x)$ and $\hat{w}_s$ are the normalized weights.

We see that Equation (10.173) is very similar to Equation (10.170). The key difference is that in the daydream phase, we sample from $(x, z_s) \sim p_\mathcal{D}(x) q_\phi(z|x)$, where $x$ is a real data point (and we reweight the samples), whereas in the sleep phase, we sample from $(x_s, z_s) \sim p_\theta(z, x)$, where $x_s$ is a generated data point.


10.6.4 Summary of algorithm 


Algorithm 25: One SGD update using wake-sleep algorithm.

1. Sample $x_n$ from dataset
2. Draw $S$ samples from inference network: $z_s \sim q(z|x_n)$
3. Compute unnormalized weights: $w_s = \frac{p(x_n, z_s)}{q(z_s|x_n)}$
4. Compute normalized weights: $\hat{w}_s = \frac{w_s}{\sum_{s'=1}^S w_{s'}}$
5. Optional: Compute estimate of log likelihood: $\log p(x_n) \approx \log\left(\frac{1}{S}\sum_{s=1}^S w_s\right)$
6. Wake phase: Update $\theta$ using $\sum_{s=1}^S \hat{w}_s \nabla_\theta \log p_\theta(z_s, x_n)$
7. Daydream phase: Update $\phi$ using $\sum_{s=1}^S \hat{w}_s \nabla_\phi \log q_\phi(z_s|x_n)$
8. Optional sleep phase: Draw $S$ samples from model, $(x'_s, z'_s) \sim p_\theta(x, z)$, and update $\phi$ using $\frac{1}{S}\sum_{s=1}^S \nabla_\phi \log q_\phi(z'_s|x'_s)$
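The following sketch implements the wake and daydream phases of Algorithm 25 with minibatch SGD (the optional sleep phase is omitted). Everything is an assumed toy setup: $p_\theta(z) = \mathcal{N}(0, 1)$, $p_\theta(x|z) = \mathcal{N}(z + \theta, 1)$, and a hypothetical inference network $q_\phi(z|x) = \mathcal{N}(ax + b, e^{2c})$ with $\phi = (a, b, c)$, trained on synthetic data drawn from $\mathcal{N}(2, 2)$. The gradients are written out by hand instead of using automatic differentiation, the learning rate and batch sizes were picked by hand for this problem, and the iterates are averaged over the second half of training to reduce SGD noise. At the fixed point, $\theta$ equals the data mean (the maximum likelihood estimate) and $q_\phi$ equals the exact posterior $\mathcal{N}((x - \theta)/2, 1/2)$.

```python
import math
import random

def rws_toy(n_steps=2000, batch=10, S=10, lr=0.02, seed=0):
    """Wake + daydream updates on the assumed toy model described above."""
    rng = random.Random(seed)
    data = [rng.gauss(2.0, math.sqrt(2.0)) for _ in range(500)]  # synthetic data
    theta, a, b, c = 0.0, 0.0, 0.0, 0.0
    history = []
    for _ in range(n_steps):
        g_theta = g_a = g_b = g_c = 0.0
        for x in rng.sample(data, batch):
            m, s = a * x + b, math.exp(c)
            zs = [rng.gauss(m, s) for _ in range(S)]        # z_s ~ q_phi(z|x)
            # log w_s = log p(z_s) + log p(x|z_s) - log q(z_s|x), dropping constants
            log_w = [-0.5 * z * z - 0.5 * (x - z - theta) ** 2
                     + c + 0.5 * ((z - m) / s) ** 2 for z in zs]
            top = max(log_w)
            w = [math.exp(lw - top) for lw in log_w]
            total = sum(w)
            for z, wi in zip(zs, w):
                wn = wi / total                              # normalized weight
                r = (z - m) / (s * s)
                g_theta += wn * (x - z - theta)              # wake: grad_theta log p(x, z_s)
                g_a += wn * r * x                            # daydream: grad_a log q(z_s|x)
                g_b += wn * r                                # grad_b log q(z_s|x)
                g_c += wn * (r * (z - m) - 1.0)              # grad_c log q(z_s|x)
        theta += lr * g_theta / batch                        # gradient ascent steps
        a += lr * g_a / batch
        b += lr * g_b / batch
        c += lr * g_c / batch
        history.append((theta, a, b, c))
    tail = history[n_steps // 2:]                            # average the second half
    averaged = [sum(h[k] for h in tail) / len(tail) for k in range(4)]
    return averaged, sum(data) / len(data)

(theta, a, b, c), xbar = rws_toy()
print(round(theta, 2), round(xbar, 2), round(a, 2), round(math.exp(c), 2))
```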


We summarize the RWS algorithm in Algorithm 25. The disadvantage of the RWS algorithm is that it does not optimize a single well-defined objective, so it is not clear if the method will converge, in contrast to ELBO maximization. On the other hand, the method is fairly simple, since it consists of two alternating weighted maximum likelihood problems. It can also be shown to “sandwich” a lower and upper bound of the log marginal likelihood. We can think of this in terms of the two joint distributions $p_\theta(x, z) = p_\theta(z) p_\theta(x|z)$ and $q_{\mathcal{D},\phi}(x, z) = p_\mathcal{D}(x) q_\phi(z|x)$:

$$
\begin{align}
\text{wake phase:} \quad & \min_\theta D_{\mathbb{KL}}\left(q_{\mathcal{D},\phi}(x, z) \,\|\, p_\theta(x, z)\right) \tag{10.174}\\
\text{daydream phase:} \quad & \min_\phi D_{\mathbb{KL}}\left(p_\theta(x, z) \,\|\, q_{\mathcal{D},\phi}(x, z)\right) \tag{10.175}
\end{align}
$$

4. We thank Rif Saurous for suggesting this term.


10.7 Expectation propagation (EP) 


One problem with lower bound maximization (i.e., standard VI) is that we are minimizing $D_{\mathbb{KL}}(q \,\|\, p)$, which induces zero-forcing behavior, as we discussed in Section 5.1.3.3. This means that $q(z|x)$ tends to be too compact (over-confident), to avoid the situation in which $q(z|x) > 0$ but $p(z|x) = 0$, which would incur an infinite KL penalty.

Although zero-forcing can be desirable behavior for some multi-modal posteriors (e.g., mixture models), it is not so reasonable for many unimodal posteriors (e.g., Bayesian logistic regression, or GPs with log-concave likelihoods). One way to avoid this problem is to minimize $D_{\mathbb{KL}}(p \,\|\, q)$, which is zero-avoiding, as we discussed in Section 5.1.3.3. This tends to result in broad posteriors, which avoids overconfidence. In this section, we discuss expectation propagation or EP [Min01b], which can be seen as a local approximation to $D_{\mathbb{KL}}(p \,\|\, q)$.
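The contrast between the two KL directions can be checked numerically. The sketch below fits a single Gaussian to an assumed bimodal target, $0.5\,\mathcal{N}(-2, 0.5^2) + 0.5\,\mathcal{N}(2, 0.5^2)$. Minimizing $D_{\mathbb{KL}}(p \,\|\, q)$ over Gaussians is solved by moment matching (mean 0, variance 4.25), which covers both modes; minimizing $D_{\mathbb{KL}}(q \,\|\, p)$, done here by brute-force grid search with quadrature rather than by any particular VI algorithm, locks onto a single mode.

```python
import math

def log_p(z):
    """Assumed bimodal target 0.5 N(-2, 0.5^2) + 0.5 N(2, 0.5^2), via log-sum-exp."""
    comps = [math.log(0.5) - 0.5 * math.log(2 * math.pi * 0.25) - 0.5 * (z - mu) ** 2 / 0.25
             for mu in (-2.0, 2.0)]
    top = max(comps)
    return top + math.log(sum(math.exp(c - top) for c in comps))

GRID = [-8 + 16 * k / 800 for k in range(801)]
DZ = GRID[1] - GRID[0]
LOG_P = [log_p(z) for z in GRID]  # precompute the target on the quadrature grid

def reverse_kl(m, s):
    """D_KL(q || p) for q = N(m, s^2), by a Riemann sum on GRID."""
    total = 0.0
    for z, lp in zip(GRID, LOG_P):
        log_q = -0.5 * math.log(2 * math.pi * s * s) - 0.5 * ((z - m) / s) ** 2
        total += math.exp(log_q) * (log_q - lp) * DZ
    return total

# Forward KL over Gaussians: match moments (within-mode variance + spread of the modes).
fwd_mean, fwd_var = 0.0, 0.25 + 4.0
# Reverse KL: brute-force search over m in [-3, 3] and s in [0.2, 3].
best_kl, best_m, best_s = min(
    (reverse_kl(i / 4, j / 10), i / 4, j / 10) for i in range(-12, 13) for j in range(2, 31))
print("forward KL fit: mean", fwd_mean, "var", fwd_var)
print("reverse KL fit: mean", best_m, "std", best_s)  # sits on one of the two modes
```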


10.7.1 Algorithm 


We assume the exact posterior can be written as follows: 


$$
p(\theta|\mathcal{D}) = \frac{1}{Z_p}\hat{p}(\theta), \quad \hat{p}(\theta) = p_0(\theta) \prod_{k=1}^K f_k(\theta) \tag{10.176}
$$

where $\hat{p}(\theta)$ is the unnormalized posterior, $p_0$ is the prior, and $f_k$ corresponds to the $k$'th likelihood term or local factor (also called a site potential). Here $Z_p = p(\mathcal{D}) Z_0$ is the normalization constant for the posterior, where $Z_0$ is the normalization constant for the prior. To simplify notation, we let $f_0(\theta) = p_0(\theta)$ be the prior.

We will approximate the posterior as follows: 


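$$
q(\theta) = \frac{1}{Z_q}\tilde{q}(\theta), \quad \tilde{q}(\theta) = p_0(\theta) \prod_{k=1}^K \tilde{f}_k(\theta) \tag{10.177}
$$

where $\tilde{f}_k \in \mathcal{Q}$ is the approximate local factor, and $\mathcal{Q}$ is some tractable family in the exponential family, usually a Gaussian [Gel+14b].

We will optimize each $\tilde{f}_i$ in turn, keeping the others fixed. We initialize each $\tilde{f}_i$ using an uninformative distribution from the family $\mathcal{Q}$, so that initially $q(\theta) = p_0(\theta)$.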

To compute the new local factor $\tilde{f}_i^{\text{new}}$, we proceed as follows. First we compute the cavity distribution by deleting $\tilde{f}_i$ from the approximate posterior by dividing it out:

$$
q_{-i}^{\text{cavity}}(\theta) = \frac{q(\theta)}{\tilde{f}_i(\theta)} \propto \prod_{k \neq i} \tilde{f}_k(\theta) \tag{10.178}
$$

This division operation can be implemented by subtracting the natural parameters, as explained in Section 2.2.7.2. The cavity distribution represents the effect of all the factors except for $f_i$ (which is approximated by $\tilde{f}_i$).

Next we (conceptually) compute the tilted distribution by multiplying the exact factor $f_i$ onto the cavity distribution:

$$
q_i^{\text{tilted}}(\theta) = \frac{1}{Z_i} f_i(\theta) q_{-i}^{\text{cavity}}(\theta) \tag{10.179}
$$

where $Z_i = \int q_{-i}^{\text{cavity}}(\theta) f_i(\theta)\, d\theta$ is the normalization constant for the tilted distribution. This is the result of combining the current approximation, excluding factor $i$, with the exact $f_i$ term.

Unfortunately, the resulting tilted distribution may be outside of our model family (e.g., if we combine a Gaussian prior with a non-Gaussian likelihood). So we will approximate the tilted distribution as follows:

$$
q_i^{\text{proj}}(\theta) = \text{proj}(q_i^{\text{tilted}}) \triangleq \arg\min_{q \in \mathcal{Q}} D(q_i^{\text{tilted}} \,\|\, q) \tag{10.180}
$$

This can be thought of as projecting the tilted distribution into the approximation family. If $D(q_i^{\text{tilted}} \,\|\, q) = D_{\mathbb{KL}}(q_i^{\text{tilted}} \,\|\, q)$, this can be done by moment matching, as shown in Section 5.1.3.4. For example, suppose the cavity distribution is Gaussian, $q_{-i}^{\text{cavity}}(\theta) = \mathcal{N}_c(\theta|r_{-i}, Q_{-i})$, using the canonical parameterization. Then the log of the tilted distribution is given by

$$
\log q_i^{\text{tilted}}(\theta) = \log f_i(\theta) - \frac{1}{2}\theta^\top Q_{-i} \theta + r_{-i}^\top \theta + \text{const} \tag{10.181}
$$


Let $\hat{\theta}$ be a local maximum of this objective. If $\mathcal{Q}$ is the set of Gaussians, we can compute the projected tilted distribution as a Gaussian with the following parameters:

$$
Q_i = -\nabla_\theta^2 \log q_i^{\text{tilted}}(\theta)\big|_{\theta=\hat{\theta}}, \quad r_i = Q_i \hat{\theta} \tag{10.182}
$$

This is called Laplace propagation [SVE04]. For more general distributions, we can use Monte Carlo approximations; this is known as blackbox EP [HL+16a; Li+18c].


Finally, we compute a local factor that, if combined with the cavity distribution, would give the same results as this projected distribution:

$$
\tilde{f}_i^{\text{new}}(\theta) = \frac{q_i^{\text{proj}}(\theta)}{q_{-i}^{\text{cavity}}(\theta)} \tag{10.183}
$$

We see that $q_{-i}^{\text{cavity}}(\theta) \tilde{f}_i^{\text{new}}(\theta) = q_i^{\text{proj}}(\theta)$, so combining this approximate factor with the cavity distribution results in a distribution which is the best possible approximation (within $\mathcal{Q}$) to the results of using the exact factor.
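Putting the steps together, the sketch below runs EP for an assumed one-dimensional Bayesian logistic regression: a scalar weight $\theta$ with prior $\mathcal{N}(0, 1)$ and factors $f_i(\theta) = \sigma(y_i x_i \theta)$ on a few hypothetical data points. Gaussian sites are stored as natural parameters (precision and precision-times-mean), so the cavity in Equation (10.178) is a subtraction. The projection uses moment matching (the $D_{\mathbb{KL}}(q_i^{\text{tilted}} \,\|\, q)$ choice), with the tilted moments computed by quadrature on a grid, which is only practical in low dimensions. The result is compared with the exact posterior moments, also computed by quadrature.

```python
import math

def log_sigmoid(a):
    """Numerically stable log(1 / (1 + exp(-a)))."""
    return -math.log1p(math.exp(-a)) if a >= 0 else a - math.log1p(math.exp(a))

GRID = [-10 + 20 * k / 4000 for k in range(4001)]

def tilted_moments(log_f, m, v):
    """Mean and variance of the density ∝ exp(log_f(t)) N(t | m, v), by quadrature."""
    logs = [log_f(t) - 0.5 * (t - m) ** 2 / v for t in GRID]
    top = max(logs)
    ws = [math.exp(l - top) for l in logs]
    z = sum(ws)
    mean = sum(w * t for w, t in zip(ws, GRID)) / z
    var = sum(w * (t - mean) ** 2 for w, t in zip(ws, GRID)) / z
    return mean, var

def ep_logistic_1d(xs, ys, prior_var=1.0, n_sweeps=10):
    tau = [0.0] * len(xs)  # site precisions
    nu = [0.0] * len(xs)   # site precision-times-mean
    for _ in range(n_sweeps):
        for i in range(len(xs)):
            T, N = 1.0 / prior_var + sum(tau), sum(nu)   # current q(theta)
            tc, nc = T - tau[i], N - nu[i]               # cavity: divide out site i
            log_f = lambda t, c=ys[i] * xs[i]: log_sigmoid(c * t)
            mp, vp = tilted_moments(log_f, nc / tc, 1.0 / tc)  # project the tilted dist.
            tau[i], nu[i] = 1.0 / vp - tc, mp / vp - nc  # new site = projected / cavity
    T, N = 1.0 / prior_var + sum(tau), sum(nu)
    return N / T, 1.0 / T

xs, ys = [1.0, -0.5, 2.0, 0.3, -1.5], [1, 1, -1, 1, -1]  # hypothetical data
m_ep, v_ep = ep_logistic_1d(xs, ys)
m_ex, v_ex = tilted_moments(  # exact posterior: prior N(0, 1) times all factors
    lambda t: sum(log_sigmoid(y * x * t) for x, y in zip(xs, ys)), 0.0, 1.0)
print(round(m_ep, 4), round(m_ex, 4), round(v_ep, 4), round(v_ex, 4))
```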


10.7.2 Example


Figure 10.13 illustrates the process of combining a very non-Gaussian likelihood $f_i$ with a Gaussian cavity prior $q_{-i}^{\text{cavity}}$ to yield a nearly Gaussian tilted distribution $q_i^{\text{tilted}}$, which can then be approximated by a Gaussian using projection.


Draft of “Probabilistic Machine Learning: Advanced Topics”. August 15, 2022 




[Figure 10.13 plot: curves labeled “Tilted distribution, $p(y_i|\theta)q_{-i}(\theta)$”, “Cavity distribution, $q_{-i}(\theta)$”, and “Likelihood factor, $p(y_i|\theta)$”.]


Figure 10.13: Combining a logistic likelihood factor $f_i = p(y_i|\theta)$ with the cavity prior, $q_{-i}^{\text{cavity}} = q_{-i}(\theta)$, to get the tilted distribution, $q_i^{\text{tilted}} = p(y_i|\theta) q_{-i}(\theta)$. Adapted from Figure 2 of [Gel+14b].


Thus instead of trying to “Gaussianize” each likelihood term $f_i$ in isolation (as is done, e.g., in the EKF), we try to find the best local factor $\tilde{f}_i$ (within some family) that achieves approximately the same effect, when combined with all the other terms (represented by the cavity distribution, $q_{-i}$), as using the exact factor $f_i$. That is, we choose a local factor that works well in the context of all the other factors.


10.7.3 EP as generalized ADF 


We can view EP as a generalization of the ADF algorithm discussed in Section 8.9. ADF is a form of sequential Bayesian inference. At each step, it maintains a tractable approximation to the posterior, $q_t(z) \in \mathcal{Q}$, updates it with the likelihood from the next observation, $\hat{p}_{t+1}(z) \propto q_t(z) p(x_{t+1}|z)$, and then projects the resulting updated posterior back to the tractable family using $q_{t+1} = \arg\min_{q \in \mathcal{Q}} D_{\mathbb{KL}}(\hat{p}_{t+1} \,\|\, q)$. ADF minimizes KL in the desired direction. However, it is a sequential algorithm, designed for the online setting. In the batch setting, the method can give different results depending on the order in which the updates are performed. In addition, if we perform multiple passes over the data, we will include the same likelihood terms multiple times, resulting in an overconfident posterior. EP overcomes this problem.
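The order dependence of ADF is easy to observe. The sketch below runs a single ADF pass with Gaussian moment-matching projections on an assumed one-dimensional logistic model (prior $\mathcal{N}(0, 1)$, factors $\sigma(c_t \theta)$ for a few hypothetical coefficients $c_t$), computing the moments by grid quadrature. Processing the same factors forward and backward gives slightly different Gaussians; EP removes this dependence by revisiting each factor through its cavity distribution.

```python
import math

def log_sigmoid(a):
    """Numerically stable log(1 / (1 + exp(-a)))."""
    return -math.log1p(math.exp(-a)) if a >= 0 else a - math.log1p(math.exp(a))

GRID = [-10 + 20 * k / 4000 for k in range(4001)]

def project(log_f, m, v):
    """Moment-match the density ∝ exp(log_f(t)) N(t | m, v) with a Gaussian."""
    logs = [log_f(t) - 0.5 * (t - m) ** 2 / v for t in GRID]
    top = max(logs)
    ws = [math.exp(l - top) for l in logs]
    z = sum(ws)
    mean = sum(w * t for w, t in zip(ws, GRID)) / z
    return mean, sum(w * (t - mean) ** 2 for w, t in zip(ws, GRID)) / z

def adf(coeffs, prior_var=1.0):
    """One ADF pass: q_t = proj(q_{t-1} * f_t) with f_t(theta) = sigmoid(c_t * theta)."""
    m, v = 0.0, prior_var
    for c in coeffs:
        m, v = project(lambda t, c=c: log_sigmoid(c * t), m, v)
    return m, v

coeffs = [1.0, -0.5, -2.0, 0.3, 1.5]  # hypothetical values of y_t * x_t
forward = adf(coeffs)
backward = adf(coeffs[::-1])
print(forward, backward)  # same factors, different order, different answers
```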


10.7.4 Optimization issues 


In practice, EP can be numerically unstable. For example, if we use Gaussians as our local factors, 
we might end up with negative variance when we subtract the natural parameters. To reduce the 
chance of this, it is common to use damping, in which we perform a partial update of each factor 
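with a step size of $\delta$. More precisely, we change the final step to be the following:

$$
\tilde{f}_i^{\text{new}}(\theta) = \left(\frac{q_i^{\text{proj}}(\theta)}{q_{-i}^{\text{cavity}}(\theta)}\right)^{\delta} \tilde{f}_i^{\text{old}}(\theta)^{1-\delta} \tag{10.184}
$$

This can be implemented by taking a convex combination of the old and new natural parameters of the factor, with weight $\delta$ on the new ones. [ML02] suggest $\delta = 1/K$ as a safe strategy (where $K$ is the number of factors), but this results in very slow convergence. [Gel+14b] suggest starting with $\delta = 0.5$, and then reducing to $\delta = 1/K$ over $K$ iterations.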
In addition to these numerical stability issues, there is no guarantee that EP will converge in its vanilla form, although empirically it can work well, especially with log-concave factors $f_i$ (e.g., as in GP classifiers).


10.7.5 Power EP and α-divergence


We also have a choice about what divergence measure $D(q_i^{\text{tilted}} \,\|\, q)$ to use when we approximate the tilted distribution. If we use $D_{\mathbb{KL}}(q_i^{\text{tilted}} \,\|\, q)$, we recover classic EP, as described above. If we use $D_{\mathbb{KL}}(q \,\|\, q_i^{\text{tilted}})$, we recover the reverse KL used in standard variational inference. We can generalize the above results by using α-divergences (Section 2.7.1.2), which allow us to interpolate between mode seeking and mode covering behavior, as shown in Figure 2.20. We can optimize the α-divergence by using the power EP method of [Min04].

Algorithmically, this is a fairly small modification to regular EP. In particular, we first compute the cavity distribution, $q_{-i}^{\text{cavity}} \propto q / \tilde{f}_i^{\alpha}$; we then approximate the tilted distribution, $q_i^{\text{proj}} = \text{proj}(q_{-i}^{\text{cavity}} f_i^{\alpha})$; and finally we compute the new factor $\tilde{f}_i^{\text{new}} \propto \left(q_i^{\text{proj}} / q_{-i}^{\text{cavity}}\right)^{1/\alpha}$.


10.7.6 Stochastic EP 


The main disadvantage of EP in the big data setting is that we need to store the $\tilde{f}_n(\theta)$ terms for each datapoint $n$, so we can compute the cavity distribution. If $\theta$ has $D$ dimensions, and we use full covariance Gaussians, this requires $O(N D^2)$ memory.

The idea behind stochastic EP [LHLT15] is to approximate the local factors with a shared factor 
that acts like an aggregated likelihood, i.e., 


$$
\prod_{n=1}^N \tilde{f}_n(\theta) \approx \tilde{f}(\theta)^N \tag{10.185}
$$

where typically $f_n(\theta) = p(x_n|\theta)$. This exploits the fact that the posterior only cares about approximating the product of the likelihoods, rather than each likelihood separately. Hence it suffices for $\tilde{f}(\theta)$ to approximate the average likelihood.


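We can modify EP to this setting as follows. First, when computing the cavity distribution, we use

$$
q_{-1}(\theta) \propto q(\theta) / \tilde{f}(\theta) \tag{10.186}
$$

We then compute the tilted distribution

$$
q_{\backslash n}(\theta) \propto f_n(\theta) q_{-1}(\theta) \tag{10.187}
$$

Next we derive the new local factor for this datapoint using moment matching:

$$
\tilde{f}_n(\theta) = \text{proj}(q_{\backslash n}(\theta)) / q_{-1}(\theta) \tag{10.188}
$$

Finally we perform a damped update of the average likelihood $\tilde{f}(\theta)$ using this new local factor:

$$
\tilde{f}_{\text{new}}(\theta) = \tilde{f}_{\text{old}}(\theta)^{1 - 1/N} \tilde{f}_n(\theta)^{1/N} \tag{10.189}
$$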


The ADF algorithm is similar to SEP, in that we compute the tilted distribution $\hat{q}_t \propto f_t q_{t-1}$ and then project it, without needing to keep the $\tilde{f}_n$ factors. The difference is that instead of using the cavity distribution $q_{-1}(\theta)$ as prior, it uses the posterior from the previous time step, $q_{t-1}$. This avoids the need to compute and store $\tilde{f}$, but results in overconfidence in the batch setting.




11 Monte Carlo inference


11.1 Introduction 


In this chapter, we discuss Monte Carlo methods, which are a stochastic approach to solving 
numerical integration problems. The name refers to the “Monte Carlo” casino in Monaco; this was 
used as a codename by von Neumann and Ulam, who invented the technique while working on 
the atomic bomb during WWII. Since then, the technique has become widely adopted in physics, 
statistics, machine learning, and many areas of science and engineering. 

In this chapter, we give a brief introduction to some key concepts. In Chapter 12, we discuss 
MCMC, which is the most widely used MC method for high-dimensional problems. In Chapter 13, 
we discuss SMC, which is widely used for MC inference in state space models, but can also be applied 
more generally. For more details on MC methods, see e.g., [Liu01; RC04; KTB11; BZ20]. 


11.2 Monte Carlo integration 


We often want to compute the expected value of some function of a random variable, $\mathbb{E}[f(X)]$. This requires computing the following integral:

$$
\mathbb{E}[f(X)] = \int f(x) p(x)\, dx \tag{11.1}
$$

where $x \in \mathbb{R}^n$, $f: \mathbb{R}^n \to \mathbb{R}^m$, and $p(x)$ is the target distribution of $X$.¹ In low dimensions (up to, say, 3), we can compute the above integral efficiently using numerical integration, which (adaptively) computes a grid, and then evaluates the function at each point on the grid.² But this does not scale to higher dimensions.

An alternative approach is to draw multiple random samples, $x_n \sim p(x)$, and then to compute

$$
\mathbb{E}[f(X)] \approx \frac{1}{N_s}\sum_{n=1}^{N_s} f(x_n) \tag{11.2}
$$


This is called Monte Carlo integration. It has the advantage over numerical integration that the function is only evaluated in places where there is non-negligible probability, so it does not need to uniformly cover the entire space. In particular, it can be shown that the accuracy is in principle independent of the dimensionality of $x$, and only depends on the number of samples $N_s$ (see Section 11.2.2 for details). The catch is that we need a way to generate the samples $x_n \sim p(x)$ in the first place. In addition, the estimator may have high variance. We will discuss this topic at length in the sections below.

1. In many cases, the target distribution may be the posterior $p(x|y)$, which can be hard to compute; in such problems, we often work with the unnormalized distribution, $\tilde{p}(x) = p(x, y)$, instead, and then normalize the results using $Z = \int p(x, y)\, dx = p(y)$.
2. In 1d, numerical integration is called quadrature; in higher dimensions, it is called cubature [Sar13].

Figure 11.1: Estimating $\pi$ by Monte Carlo integration using 5000 samples. Blue points are inside the circle, red points are outside. Generated by mc_estimate_pi.ipynb.


11.2.1 Example: estimating π by Monte Carlo integration


MC integration can be used for many applications, not just in ML and statistics. For example, suppose we want to estimate $\pi$. We know that the area of a circle with radius $r$ is $\pi r^2$, but it is also equal to the following definite integral:


$$
I = \int_{-r}^{r}\int_{-r}^{r} \mathbb{I}(x^2 + y^2 \le r^2)\, dx\, dy \tag{11.3}
$$

Hence $\pi = I/r^2$. Let us approximate this by Monte Carlo integration. Let $f(x, y) = \mathbb{I}(x^2 + y^2 \le r^2)$ be an indicator function that is 1 for points inside the circle, and 0 outside, and let $p(x)$ and $p(y)$ be uniform distributions on $[-r, r]$, so $p(x) = p(y) = 1/(2r)$. Then

$$
\begin{align}
I &= (2r)(2r) \int\int f(x, y) p(x) p(y)\, dx\, dy \tag{11.4}\\
&= 4r^2 \int\int f(x, y) p(x) p(y)\, dx\, dy \tag{11.5}\\
&\approx 4r^2 \frac{1}{N_s}\sum_{n=1}^{N_s} f(x_n, y_n) \tag{11.6}
\end{align}
$$

Using 5000 samples, we find $\hat{\pi} = 3.10$ with standard error 0.09 compared to the true value of $\pi = 3.14$.


We can plot the points that are accepted or rejected as in Figure 11.1. 
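A minimal version of this experiment is sketched below. It follows Equations (11.3)-(11.6) with $r = 1$ and also reports the standard error of $\hat{\pi}$ (see Section 11.2.2); since it uses its own random number generator and seed, its numbers will not match the ones quoted above exactly.

```python
import math
import random

def estimate_pi(n_samples=5000, r=1.0, seed=0):
    """Sample (x, y) uniformly on [-r, r]^2 and average the inside-circle indicator."""
    rng = random.Random(seed)
    hits = [1.0 if rng.uniform(-r, r) ** 2 + rng.uniform(-r, r) ** 2 <= r * r else 0.0
            for _ in range(n_samples)]
    mean = sum(hits) / n_samples
    var = sum((h - mean) ** 2 for h in hits) / n_samples
    # I ≈ 4 r^2 * mean, so pi = I / r^2 ≈ 4 * mean; the standard error scales the same way
    pi_hat = 4.0 * mean
    std_err = 4.0 * math.sqrt(var / n_samples)
    return pi_hat, std_err

pi_hat, se = estimate_pi()
print(round(pi_hat, 3), round(se, 3))
```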


11.2.2 Accuracy of Monte Carlo integration


The accuracy of an MC approximation increases with sample size. This is illustrated in Figure 11.2. On the top line, we plot a histogram of samples from a Gaussian distribution. On the bottom line, we plot a smoothed version of these samples, created using a kernel density estimate. This smoothed distribution is then evaluated on a dense grid of points and plotted. Note that this smoothing is just for the purposes of plotting; it is not used for the Monte Carlo estimate itself.

[Figure 11.2 plot: panels (a) n_samples = 10 and (b) n_samples = 100, each showing the samples, the true pdf, and the estimated pdf.]

Figure 11.2: 10 and 100 samples from a Gaussian distribution, $\mathcal{N}(\mu = 1.5, \sigma^2 = 0.25)$. A dotted red line denotes the kernel density estimate derived from the samples. Generated by mc_accuracy_demo.ipynb.

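If we denote the exact mean by $\mu = \mathbb{E}[f(X)]$, and the MC approximation by $\hat{\mu}$, one can show that, with independent samples,

$$
(\hat{\mu} - \mu) \to \mathcal{N}\left(0, \frac{\sigma^2}{N_s}\right) \tag{11.7}
$$

where

$$
\sigma^2 = \mathbb{V}[f(X)] = \mathbb{E}[f(X)^2] - \mathbb{E}[f(X)]^2 \tag{11.8}
$$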


This is a consequence of the central limit theorem. Of course, $\sigma^2$ is unknown in the above expression, but it can be estimated by MC:

$$
\hat{\sigma}^2 = \frac{1}{N_s}\sum_{n=1}^{N_s} \left(f(x_n) - \hat{\mu}\right)^2 \tag{11.9}
$$


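Thus for large enough $N_s$, we have

$$
P\left\{\hat{\mu} - 1.96\frac{\hat{\sigma}}{\sqrt{N_s}} \le \mu \le \hat{\mu} + 1.96\frac{\hat{\sigma}}{\sqrt{N_s}}\right\} \approx 0.95 \tag{11.10}
$$

The term $\sqrt{\hat{\sigma}^2 / N_s}$ is called the (numerical or empirical) standard error, and is an estimate of our uncertainty about our estimate of $\mu$.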

If we want to report an answer which is accurate to within $\pm\epsilon$ with probability at least 95%, we need to use a number of samples $N_s$ which satisfies $1.96\sqrt{\hat{\sigma}^2/N_s} \le \epsilon$. We can approximate the 1.96 factor by 2, yielding $N_s \ge \frac{4\hat{\sigma}^2}{\epsilon^2}$.
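The recipe above can be sketched as follows: estimate $\hat{\mu}$ and $\hat{\sigma}^2$ from samples, report the 95% interval of Equation (11.10), and compute the sample size needed for a target accuracy $\epsilon$ with the rule $N_s \ge 4\hat{\sigma}^2/\epsilon^2$. The test integrand, $\mathbb{E}[X^2]$ for $X \sim \mathcal{N}(1.5, 0.25)$ (exact value 2.5), is an arbitrary choice for illustration.

```python
import math
import random

def mc_mean_with_ci(f, sampler, n_samples, seed=0):
    """Return (mu_hat, standard error, 95% interval) following Equations (11.9)-(11.10)."""
    rng = random.Random(seed)
    vals = [f(sampler(rng)) for _ in range(n_samples)]
    mu_hat = sum(vals) / n_samples
    sigma2_hat = sum((v - mu_hat) ** 2 for v in vals) / n_samples
    se = math.sqrt(sigma2_hat / n_samples)
    return mu_hat, se, (mu_hat - 1.96 * se, mu_hat + 1.96 * se)

def samples_needed(sigma2_hat, eps):
    """Smallest N_s with 2 * sqrt(sigma2_hat / N_s) <= eps, i.e. N_s >= 4 sigma2_hat / eps^2."""
    return math.ceil(4.0 * sigma2_hat / eps ** 2)

mu_hat, se, ci = mc_mean_with_ci(lambda x: x * x, lambda rng: rng.gauss(1.5, 0.5), 10_000)
print(round(mu_hat, 3), round(se, 4), [round(c, 3) for c in ci])
print(samples_needed(sigma2_hat=1.0, eps=0.25))  # 4 * 1.0 / 0.0625 = 64
```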

The remarkable thing to note about the above results is that the variance of the estimate, $\sigma^2/N_s$, is theoretically independent of the dimensionality of the integral. The catch is that sampling from high dimensional distributions can be hard. We turn to that topic next.
[Figure 11.3 plot: “Inverse cdf sampling”, showing the cdf and pdf of $\mathcal{N}(3, 1)$, samples from Unif(0, 1), and the corresponding $\mathcal{N}(3, 1)$ samples.]

Figure 11.3: Sampling from $\mathcal{N}(3, 1)$ using an inverse CDF.


11.3 Generating random samples from simple distributions 


We saw in Section 11.2 how we can evaluate $\mathbb{E}[f(X)]$ for different functions $f$ of a random variable $X$ using Monte Carlo integration. The main computational challenge is to efficiently generate samples from the probability distribution $p^*(x)$ (which may be a posterior, $p^*(x) \propto p(x|\mathcal{D})$). In this section, we discuss sampling methods that are suitable for parametric univariate distributions. These can be used as building blocks for sampling from more complex multivariate distributions.


11.3.1 Sampling using the inverse cdf 


The simplest method for sampling from a univariate distribution is based on the inverse probability transform. Let $F$ be a cdf of some distribution we want to sample from, and let $F^{-1}$ be its inverse. Then we have the following result.

Theorem 11.3.1. If $U \sim U(0, 1)$ is a uniform rv, then $F^{-1}(U) \sim F$.

Proof.

$$
\begin{align}
\Pr(F^{-1}(U) \le x) &= \Pr(U \le F(x)) \quad \text{(applying } F \text{ to both sides)} \tag{11.11}\\
&= F(x) \quad \text{(because } \Pr(U \le y) = y\text{)} \tag{11.12}
\end{align}
$$

where the first line follows since $F$ is a monotonic function, and the second line follows since $U$ is uniform on the unit interval.


Hence we can sample from any univariate distribution for which we can evaluate its inverse cdf, as follows: generate a random number $u \sim U(0, 1)$ using a pseudo random number generator (see e.g., [Pre+88] for details). Let $u$ represent the height up the $y$ axis. Then “slide along” the $x$ axis until you intersect the $F$ curve, and then “drop down” and return the corresponding $x$ value. This corresponds to computing $x = F^{-1}(u)$. See Figure 11.3 for an illustration.


For example, consider the exponential distribution

$$
\text{Expon}(x|\lambda) \triangleq \lambda e^{-\lambda x}\, \mathbb{I}(x \ge 0) \tag{11.13}
$$

The cdf is

$$
F(x) = \left(1 - e^{-\lambda x}\right) \mathbb{I}(x \ge 0) \tag{11.14}
$$

whose inverse is the quantile function

$$
F^{-1}(p) = -\frac{\ln(1 - p)}{\lambda} \tag{11.15}
$$

By the above theorem, if $U \sim \text{Unif}(0, 1)$, we know that $F^{-1}(U) \sim \text{Expon}(\lambda)$. So we can sample from the exponential distribution by first sampling from the uniform and then transforming the results using $-\ln(1 - u)/\lambda$. (In fact, since $1 - U \sim \text{Unif}(0, 1)$, we can just use $-\ln(u)/\lambda$.)
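A minimal sketch of inverse-cdf sampling for the exponential distribution follows. It uses $-\ln(1 - u)/\lambda$ rather than $-\ln(u)/\lambda$ because Python's `random.random()` returns values in $[0, 1)$, so $u = 0$ is possible while $1 - u$ is never zero.

```python
import math
import random

def sample_exponential(lam, n, seed=0):
    """Inverse-cdf sampling: x = F^{-1}(u) = -ln(1 - u) / lam with u ~ Unif(0, 1)."""
    rng = random.Random(seed)
    # rng.random() lies in [0, 1), so 1 - u lies in (0, 1] and the log is finite
    return [-math.log(1.0 - rng.random()) / lam for _ in range(n)]

xs = sample_exponential(lam=2.0, n=100_000)
print(round(sum(xs) / len(xs), 3))  # close to the true mean 1 / lam = 0.5
```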


11.3.2 Sampling from a Gaussian (Box-Muller method) 


In this section, we describe a method to sample from a Gaussian. The idea is we sample uniformly 
from a unit radius circle, and then use the change of variables formula to derive samples from a 
spherical 2d Gaussian. This can be thought of as two samples from a 1d Gaussian. 

In more detail, sample $z_1, z_2 \in (-1, 1)$ uniformly, and then discard pairs that do not satisfy $z_1^2 + z_2^2 \le 1$. The result will be points uniformly distributed inside the unit circle, so $p(z) = \frac{1}{\pi}\mathbb{I}(z \text{ inside circle})$. Now define

$$x_i = z_i \left(\frac{-2 \ln r^2}{r^2}\right)^{1/2} \qquad (11.16)$$

for $i = 1:2$, where $r^2 = z_1^2 + z_2^2$. Using the multivariate change of variables formula, we have

$$p(x_1, x_2) = p(z_1, z_2)\left|\frac{\partial(z_1, z_2)}{\partial(x_1, x_2)}\right| = \left[\frac{1}{\sqrt{2\pi}}\exp\left(-\frac{1}{2}x_1^2\right)\right]\left[\frac{1}{\sqrt{2\pi}}\exp\left(-\frac{1}{2}x_2^2\right)\right] \qquad (11.17)$$

Hence $x_1$ and $x_2$ are two independent samples from a univariate Gaussian. This is known as the Box-Muller method.
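A minimal sketch of the sampler described by Equations (11.16) and (11.17), assuming only Python's standard library (`box_muller_pair` is a hypothetical helper name). This rejection-based variant of Box-Muller is also known as the Marsaglia polar method.

```python
import math
import random


def box_muller_pair(rng):
    """Return two independent N(0, 1) draws via Eqs. (11.16)-(11.17)."""
    while True:
        # Sample (z1, z2) uniformly on the square (-1, 1)^2 ...
        z1 = 2.0 * rng.random() - 1.0
        z2 = 2.0 * rng.random() - 1.0
        r2 = z1 * z1 + z2 * z2
        # ... and keep the pair only if it lies inside the unit circle.
        # r2 > 0 guards against log(0) and division by zero.
        if 0.0 < r2 <= 1.0:
            break
    scale = math.sqrt(-2.0 * math.log(r2) / r2)
    return z1 * scale, z2 * scale


rng = random.Random(1)
pairs = [box_muller_pair(rng) for _ in range(50_000)]
x1 = [a for a, _ in pairs]
mean1 = sum(x1) / len(x1)
var1 = sum(a * a for a in x1) / len(x1) - mean1 ** 2
cross = sum(a * b for a, b in pairs) / len(pairs)  # estimates E[x1 x2], which is 0
print(mean1, var1, cross)
```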

To sample from a multivariate Gaussian, we first compute the Cholesky decomposition of its covariance matrix, $\Sigma = LL^\top$, where $L$ is lower triangular. Next we sample $x \sim \mathcal{N}(0, I)$ using the Box-Muller method. Finally we set $y = Lx + \mu$. This is valid since

$$\mathrm{Cov}[y] = L\,\mathrm{Cov}[x]\,L^\top = L I L^\top = \Sigma \qquad (11.18)$$
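The three steps (Cholesky factor, standard normal draw, affine map) can be sketched for a 2d example as follows. This is an illustrative sketch only: the helper names are ours, the 2x2 Cholesky factor is written out by hand to stay within the standard library, and `random.Random.gauss` stands in for the Box-Muller step.

```python
import math
import random


def cholesky_2x2(S):
    """Lower-triangular L with S = L L^T, for a 2x2 positive definite S."""
    l11 = math.sqrt(S[0][0])
    l21 = S[1][0] / l11
    l22 = math.sqrt(S[1][1] - l21 * l21)
    return [[l11, 0.0], [l21, l22]]


def sample_mvn_2d(mu, S, n, rng):
    L = cholesky_2x2(S)
    samples = []
    for _ in range(n):
        # x ~ N(0, I); rng.gauss stands in for the Box-Muller step here
        x0, x1 = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
        # y = L x + mu
        samples.append((L[0][0] * x0 + mu[0],
                        L[1][0] * x0 + L[1][1] * x1 + mu[1]))
    return samples


mu = (1.0, -2.0)
S = [[4.0, 1.2], [1.2, 1.0]]
ys = sample_mvn_2d(mu, S, 50_000, random.Random(2))
m0 = sum(y[0] for y in ys) / len(ys)
m1 = sum(y[1] for y in ys) / len(ys)
c00 = sum((y[0] - m0) ** 2 for y in ys) / len(ys)
c01 = sum((y[0] - m0) * (y[1] - m1) for y in ys) / len(ys)
c11 = sum((y[1] - m1) ** 2 for y in ys) / len(ys)
print((m0, m1), (c00, c01, c11))  # compare with mu and S
```

The empirical mean and covariance should be close to `mu` and `S`, as Equation (11.18) predicts.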


11.4 Rejection sampling 
Suppose we want to sample from the target distribution 

$$p(x) = \tilde{p}(x)/Z_p \qquad (11.19)$$

where $\tilde{p}(x)$ is the unnormalized version, and

$$Z_p = \int \tilde{p}(x)\, dx \qquad (11.20)$$

is the (possibly unknown) normalization constant. One of the simplest approaches to this problem is rejection sampling, which we now explain.
[Figure 11.4 panels: (a) target p(x), comparison function Mq(x), accept and reject regions; (b) target and scaled proposal over x from 0 to 10.]

Figure 11.4: (a) Schematic illustration of rejection sampling. From Figure 2 of [And+03]. Used with kind permission of Nando de Freitas. (b) Rejection sampling from a $\mathrm{Ga}(\alpha = 5.7, \lambda = 2)$ distribution (solid blue) using a proposal of the form $M\,\mathrm{Ga}(k, \lambda - 1)$ (dotted red), where $k = \lfloor 5.7 \rfloor = 5$. The curves touch at $\alpha - k = 0.7$. Generated by rejection_sampling_demo.ipynb.


11.4.1 Basic idea 


In rejection sampling, we require access to a proposal distribution $q(x)$ which satisfies $Cq(x) \ge \tilde{p}(x)$, for some constant $C$. The function $Cq(x)$ provides an upper envelope for $\tilde{p}$.

We can use the proposal distribution to generate samples from the target distribution as follows. We first sample $x_0 \sim q(x)$, which corresponds to picking a random x location, and then we sample $u_0 \sim \mathrm{Unif}(0, Cq(x_0))$, which corresponds to picking a random height (y location) under the envelope. If $u_0 > \tilde{p}(x_0)$, we reject the sample, otherwise we accept it. This process is illustrated in 1d in Figure 11.4(a): the acceptance region is shown shaded, and the rejection region is the white region between the shaded zone and the upper envelope.

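We now prove this procedure is correct. First note that the probability of any given sample $x_0$ being accepted equals the probability of a sample $u_0 \sim \mathrm{Unif}(0, Cq(x_0))$ being less than or equal to $\tilde{p}(x_0)$, i.e.,

$$q(\text{accept}|x_0) = \int_0^{\tilde{p}(x_0)} \frac{1}{Cq(x_0)}\, du = \frac{\tilde{p}(x_0)}{Cq(x_0)} \qquad (11.21)$$

Therefore

$$q(\text{propose and accept } x_0) = q(x_0)\, q(\text{accept}|x_0) = q(x_0)\frac{\tilde{p}(x_0)}{Cq(x_0)} = \frac{\tilde{p}(x_0)}{C} \qquad (11.22)$$

Integrating both sides gives

$$\int q(x_0)\, q(\text{accept}|x_0)\, dx_0 = q(\text{accept}) = \frac{\int \tilde{p}(x_0)\, dx_0}{C} = \frac{Z_p}{C} \qquad (11.23)$$

Hence we see that the distribution of accepted points is given by the target distribution:

$$q(x_0|\text{accept}) = \frac{q(x_0, \text{accept})}{q(\text{accept})} = \frac{\tilde{p}(x_0)}{C}\,\frac{C}{Z_p} = \frac{\tilde{p}(x_0)}{Z_p} = p(x_0) \qquad (11.24)$$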


How efficient is this method? If $p$ is a normalized target distribution, the acceptance probability is $1/C$ (more generally, it is $Z_p/C$, by Equation (11.23)). Hence we want to choose $C$ as small as possible while still satisfying $Cq(x) \ge p(x)$.
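To make the procedure concrete, here is a minimal sketch under illustrative assumptions of our own: the unnormalized target is $\tilde{p}(x) = x(1-x)$ on $[0, 1]$ (a Beta(2,2) density without its constant, so $Z_p = 1/6$), and the proposal is $q = \mathrm{Unif}(0, 1)$. The smallest valid constant is $C = \max_x \tilde{p}(x) = 1/4$, so Equation (11.23) predicts an acceptance rate of $Z_p / C = 2/3$.

```python
import random


def p_tilde(x):
    # Unnormalized Beta(2, 2) density on [0, 1]; its normalizer is Z_p = 1/6
    return x * (1.0 - x)


def rejection_sample(n_proposals, rng, C=0.25):
    """Basic rejection sampling with q = Unif(0, 1), so q(x) = 1 on [0, 1]."""
    accepted = []
    for _ in range(n_proposals):
        x0 = rng.random()              # x0 ~ q(x)
        u0 = rng.uniform(0.0, C)       # u0 ~ Unif(0, C q(x0)) with q(x0) = 1
        if u0 <= p_tilde(x0):          # accept if the point lies under p_tilde
            accepted.append(x0)
    return accepted


samples = rejection_sample(60_000, random.Random(3))
acc_rate = len(samples) / 60_000                           # Eq. (11.23): Z_p / C = 2/3
mean = sum(samples) / len(samples)                        # Beta(2, 2) mean is 0.5
var = sum((s - mean) ** 2 for s in samples) / len(samples)  # Beta(2, 2) variance is 0.05
print(acc_rate, mean, var)
```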


11.4.2 Example 


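For example, suppose we want to sample from a Gamma distribution:³

$$\mathrm{Ga}(x|\alpha, \lambda) = \frac{1}{\Gamma(\alpha)} x^{\alpha - 1} \lambda^{\alpha} \exp(-\lambda x) \qquad (11.25)$$

where $\Gamma(\alpha)$ is the gamma function. One can show that if $X_i \overset{\text{iid}}{\sim} \mathrm{Expon}(\lambda)$, and $Y = X_1 + \cdots + X_k$, then $Y \sim \mathrm{Ga}(k, \lambda)$. For non-integer shape parameters $\alpha$, we cannot use this trick. However, we can use rejection sampling using a $\mathrm{Ga}(k, \lambda - 1)$ distribution as a proposal, where $k = \lfloor \alpha \rfloor$. The ratio has the form

$$\frac{p(x)}{q(x)} = \frac{\mathrm{Ga}(x|\alpha, \lambda)}{\mathrm{Ga}(x|k, \lambda - 1)} = \frac{x^{\alpha - 1}\lambda^{\alpha}\exp(-\lambda x)/\Gamma(\alpha)}{x^{k - 1}(\lambda - 1)^{k}\exp(-(\lambda - 1)x)/\Gamma(k)} \qquad (11.26)$$
$$= \frac{\Gamma(k)\lambda^{\alpha}}{\Gamma(\alpha)(\lambda - 1)^{k}}\, x^{\alpha - k}\exp(-x) \qquad (11.27)$$

This ratio attains its maximum when $x = \alpha - k$. Hence

$$C = \frac{\mathrm{Ga}(\alpha - k|\alpha, \lambda)}{\mathrm{Ga}(\alpha - k|k, \lambda - 1)} \qquad (11.28)$$

See Figure 11.4(b) for a plot.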
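The gamma example can be sketched as follows, assuming $\alpha \ge 1$ and $\lambda > 1$ so that the $\mathrm{Ga}(k, \lambda - 1)$ proposal exists. The proposal itself is drawn with the sum-of-exponentials trick mentioned above, and the computation is done in log space to avoid overflow in the gamma function; `sample_gamma_rs` is a hypothetical name. Since both densities are normalized, the acceptance rate should be close to $1/C$.

```python
import math
import random


def log_gamma_pdf(x, shape, rate):
    # log Ga(x | shape, rate), Eq. (11.25)
    return (shape * math.log(rate) + (shape - 1.0) * math.log(x)
            - rate * x - math.lgamma(shape))


def sample_gamma_rs(alpha, lam, n, rng):
    """Rejection sampler for Ga(alpha, lam) with proposal Ga(k, lam - 1), k = floor(alpha).

    Assumes alpha >= 1 and lam > 1 so that the proposal is well defined.
    """
    k = math.floor(alpha)
    # log C from Eq. (11.28): the density ratio at its maximizer x = alpha - k
    log_C = (log_gamma_pdf(alpha - k, alpha, lam)
             - log_gamma_pdf(alpha - k, k, lam - 1.0))
    samples, n_tries = [], 0
    while len(samples) < n:
        n_tries += 1
        # Ga(k, lam - 1) draw as a sum of k iid Expon(lam - 1) variables
        x = sum(-math.log(1.0 - rng.random()) for _ in range(k)) / (lam - 1.0)
        log_accept = (log_gamma_pdf(x, alpha, lam)
                      - log_gamma_pdf(x, k, lam - 1.0) - log_C)
        if rng.random() < math.exp(log_accept):   # accept w.p. p(x) / (C q(x))
            samples.append(x)
    return samples, n_tries, log_C


xs, n_tries, log_C = sample_gamma_rs(5.7, 2.0, 10_000, random.Random(4))
mean = sum(xs) / len(xs)                          # Ga(5.7, 2) mean is 2.85
print(mean, len(xs) / n_tries, math.exp(-log_C))  # acceptance rate should be near 1/C
```

With $\alpha = 5.7$ and $\lambda = 2$ (the values of Figure 11.4(b)), the proposal mean is 5 while the target mean is 2.85, so most proposals are rejected; this illustrates why a tight envelope matters.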


11.4.3 Adaptive rejection sampling 


We now describe a method that can automatically come up with a tight upper envelope $q(x)$ for any log-concave 1d density $p(x)$. The idea is to upper bound the log density with a piecewise linear function, as illustrated in Figure 11.5(a). We choose the initial locations for the pieces based on a fixed grid over the support of the distribution. We then evaluate the gradient of the log density at these locations, and make the lines tangent at these points.

Since the log of the envelope is piecewise linear, the envelope itself is piecewise exponential:

$$q(x) = C_i \lambda_i \exp(-\lambda_i (x - x_{i-1})), \quad x_{i-1} < x \le x_i \qquad (11.29)$$

where $x_i$ are the grid points. It is relatively straightforward to sample from this distribution. If the sample $x$ is rejected, we create a new grid point at $x$, and thereby refine the envelope. As the number of grid points is increased, the tightness of the envelope improves, and the rejection rate goes down. This is known as adaptive rejection sampling (ARS) [GW92]. Figure 11.5(b-c) gives an example of the method in action. As with standard rejection sampling, it can be applied to unnormalized distributions.


3. This section is based on notes by Ioana A. Cosma, available at http://users.aims.ac.za/~ioana/cp2.pdf. 
[Figure 11.5 panels: (a) f(x) with its envelope; (b) half-gaussian; (c) samples from f(x) (by ARS).]

Figure 11.5: (a) Idea behind adaptive rejection sampling. We place piecewise linear upper (and lower) bounds on the log-concave density. Adapted from Figure 1 of [GW92]. Generated by ars_envelope.ipynb. (b-c) Using ARS to sample from a half-Gaussian. Generated by ars_demo.ipynb.


11.4.4 Rejection sampling in high dimensions 


It is clear that we want to make our proposal $q(x)$ as close as possible to the target distribution $p(x)$, while still being an upper bound. But this is quite hard to achieve, especially in high dimensions. To see this, consider sampling from $p(x) = \mathcal{N}(0, \sigma_p^2 I)$ using as a proposal $q(x) = \mathcal{N}(0, \sigma_q^2 I)$. Obviously we must have $\sigma_q^2 \ge \sigma_p^2$ in order to be an upper bound. In $D$ dimensions, the optimum value is given by $C = (\sigma_q/\sigma_p)^D$, since the ratio $p(x)/q(x)$ is maximized at $x = 0$. The acceptance rate is $1/C$ (since both $p$ and $q$ are normalized), which decreases exponentially fast with dimension. For example, if $\sigma_q$ exceeds $\sigma_p$ by just 1%, then in 1000 dimensions the acceptance ratio will be about 1/20,000 (since $1.01^{1000} \approx 2.1 \times 10^4$). This is a fundamental weakness of rejection sampling.
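The arithmetic behind the 1% example, plus a small simulation that checks $1/C = (\sigma_q/\sigma_p)^{-D}$ in a lower-dimensional setting ($D = 10$, $\sigma_q/\sigma_p = 1.1$, our own choice), can be sketched as follows. For these two Gaussians the acceptance probability $p(x)/(Cq(x))$ simplifies to $\exp\left(-\tfrac{1}{2}\|x\|^2(\sigma_p^{-2} - \sigma_q^{-2})\right)$.

```python
import math
import random


def acceptance_rate(ratio, D):
    """1/C with C = (sigma_q / sigma_p)^D, valid when both densities are normalized."""
    return ratio ** (-D)


for D in (1, 10, 100, 1000):
    print(D, acceptance_rate(1.01, D))   # D = 1000 gives roughly 1/21,000


def empirical_acceptance(D, sigma_p, sigma_q, n, rng):
    """Monte Carlo estimate of the acceptance rate for the Gaussian example."""
    a = 0.5 * (1.0 / sigma_p ** 2 - 1.0 / sigma_q ** 2)
    accepted = 0
    for _ in range(n):
        sq = sum(rng.gauss(0.0, sigma_q) ** 2 for _ in range(D))  # |x|^2 with x ~ q
        # p(x) / (C q(x)) simplifies to exp(-a |x|^2) for these two Gaussians
        if rng.random() < math.exp(-a * sq):
            accepted += 1
    return accepted / n


est = empirical_acceptance(10, 1.0, 1.1, 20_000, random.Random(5))
print(est, 1.1 ** -10)   # the exact rate is 1.1^-10, about 0.386
```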


11.5 Importance sampling

In this section, we describe a Monte Carlo method known as importance sampling for approximating integrals of the form

$$\mathbb{E}_{\pi}[\varphi(x)] = \int \varphi(x)\pi(x)\, dx \qquad (11.30)$$

where $\varphi$ is called a target function, and $\pi(x)$ is the target distribution, often a conditional distribution of the form $\pi(x) = p(x|y)$. Since in general it is difficult to draw from the target distribution, we will instead draw from some proposal distribution $q(x)$ (which will usually depend on $y$). We then adjust for the inaccuracies of this by associating weights with each sample, so we end up with a weighted MC approximation:

$$\mathbb{E}_{\pi}[\varphi(x)] \approx \sum_{n=1}^{N} W_n \varphi(x_n) \qquad (11.31)$$

We discuss two cases, first when the target is normalized, and then when it is unnormalized. This will affect the ways the weights are computed, as well as the statistical properties of the estimator.


11.5.1 Direct importance sampling 


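In this section, we assume that we can evaluate the normalized target distribution $\pi(x)$, but we cannot sample from it. So instead we will sample from the proposal $q(x)$. We can then write

$$\int \varphi(x)\pi(x)\, dx = \int \varphi(x)\frac{\pi(x)}{q(x)}\, q(x)\, dx \qquad (11.32)$$

We require that the proposal be non-zero whenever the target is non-zero, i.e., the support of $q(x)$ needs to be greater than or equal to the support of $\pi(x)$. If we draw $N_s$ samples $x_n \sim q(x)$, we can write

$$\mathbb{E}_{\pi}[\varphi(x)] \approx \frac{1}{N_s}\sum_{n=1}^{N_s} \frac{\pi(x_n)}{q(x_n)}\varphi(x_n) \qquad (11.33)$$
$$= \frac{1}{N_s}\sum_{n=1}^{N_s} w_n \varphi(x_n) \qquad (11.34)$$

where $w_n = \pi(x_n)/q(x_n)$. The result is an unbiased estimate of the true mean $\mathbb{E}_{\pi}[\varphi(x)]$.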
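A minimal sketch of Equations (11.33)-(11.34), assuming (for illustration only) a normalized target $\pi = \mathcal{N}(0, 1)$, a wider proposal $q = \mathcal{N}(0, 2^2)$ and the target function $\varphi(x) = x^2$, whose exact expectation is 1. The proposal is deliberately wider than the target so that its support covers the target's and the weights stay bounded (here $w_n \le 2$).

```python
import math
import random


def normal_pdf(x, mu, sigma):
    return math.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * math.sqrt(2.0 * math.pi))


def direct_is(phi, n, rng, sigma_q=2.0):
    """Eqs. (11.33)-(11.34) with target pi = N(0, 1) and proposal q = N(0, sigma_q^2)."""
    total = 0.0
    for _ in range(n):
        x = rng.gauss(0.0, sigma_q)                                # x_n ~ q
        w = normal_pdf(x, 0.0, 1.0) / normal_pdf(x, 0.0, sigma_q)  # w_n = pi(x_n) / q(x_n)
        total += w * phi(x)
    return total / n


est = direct_is(lambda x: x * x, 50_000, random.Random(6))         # E_pi[x^2] = 1
norm_check = direct_is(lambda x: 1.0, 50_000, random.Random(7))    # E_q[w] = 1
print(est, norm_check)
```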


11.5.2 Self-normalized importance sampling 


The disadvantage of direct importance sampling is that we need a way to evaluate the normalized target distribution $\pi$ in order to compute the weights. It is often much easier to evaluate the unnormalized target distribution

$$\tilde{\gamma}(x) = Z\pi(x) \qquad (11.35)$$

where

$$Z = \int \tilde{\gamma}(x)\, dx \qquad (11.36)$$

is the normalization constant. (For example, if $\pi(x) = p(x|y)$, then $\tilde{\gamma}(x) = p(x, y)$ and $Z = p(y)$.) The key idea is to also approximate the normalization constant $Z$ with importance sampling. This method is called self-normalized importance sampling (SNIS). The resulting estimate is a ratio of two estimates, and hence is biased. However as $N_s \to \infty$, the bias goes to zero, under some weak assumptions (see e.g., [RC04] for details).

In more detail, SNIS is based on this approximation:

$$\mathbb{E}_{\pi}[\varphi(x)] = \int \varphi(x)\pi(x)\, dx = \frac{\int \varphi(x)\tilde{\gamma}(x)\, dx}{\int \tilde{\gamma}(x)\, dx} = \frac{\int \left[\frac{\tilde{\gamma}(x)}{q(x)}\varphi(x)\right] q(x)\, dx}{\int \left[\frac{\tilde{\gamma}(x)}{q(x)}\right] q(x)\, dx} \qquad (11.37)$$

$$\approx \frac{\frac{1}{N_s}\sum_{n=1}^{N_s} \tilde{w}_n \varphi(x_n)}{\frac{1}{N_s}\sum_{n=1}^{N_s} \tilde{w}_n} \qquad (11.38)$$
where we have defined the unnormalized weights

$$\tilde{w}_n = \frac{\tilde{\gamma}(x_n)}{q(x_n)} \qquad (11.39)$$

We can write Equation (11.38) more compactly as

$$\mathbb{E}_{\pi}[\varphi(x)] \approx \sum_{n=1}^{N_s} W_n \varphi(x_n) \qquad (11.40)$$

where we have defined the normalized weights by

$$W_n = \frac{\tilde{w}_n}{\sum_{n'=1}^{N_s} \tilde{w}_{n'}} \qquad (11.41)$$

This is equivalent to approximating the target distribution using a weighted sum of delta functions:

$$\pi(x) \approx \sum_{n=1}^{N_s} W_n \delta(x - x_n) \triangleq \hat{\pi}(x) \qquad (11.42)$$

As a byproduct of this algorithm we get the following approximation to the normalization constant:

$$Z \approx \frac{1}{N_s}\sum_{n=1}^{N_s} \tilde{w}_n \triangleq \hat{Z} \qquad (11.43)$$
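Self-normalized IS only needs $\tilde{\gamma}$, as in the sketch below. The target $\tilde{\gamma}(x) = \exp(-(x-1)^2/2)$ and the proposal $q = \mathcal{N}(0, 2^2)$ are illustrative assumptions: the target corresponds to $\pi = \mathcal{N}(1, 1)$ with $Z = \sqrt{2\pi} \approx 2.507$, so both the estimate of $\mathbb{E}_\pi[x] = 1$ and $\hat{Z}$ from Equation (11.43) can be checked.

```python
import math
import random


def snis(log_gamma_tilde, phi, n, rng, sigma_q=2.0):
    """Self-normalized IS with proposal q = N(0, sigma_q^2); returns (estimate, Z_hat)."""
    log_norm_q = math.log(sigma_q * math.sqrt(2.0 * math.pi))
    xs = [rng.gauss(0.0, sigma_q) for _ in range(n)]
    # Unnormalized weights w~_n = gamma~(x_n) / q(x_n), Eq. (11.39)
    w_tilde = [math.exp(log_gamma_tilde(x) + 0.5 * (x / sigma_q) ** 2 + log_norm_q)
               for x in xs]
    total = sum(w_tilde)
    W = [w / total for w in w_tilde]                      # normalized weights, Eq. (11.41)
    estimate = sum(Wn * phi(x) for Wn, x in zip(W, xs))   # Eq. (11.40)
    Z_hat = total / n                                     # Eq. (11.43)
    return estimate, Z_hat


# Unnormalized target gamma~(x) = exp(-(x - 1)^2 / 2): pi = N(1, 1), Z = sqrt(2 pi)
est, Z_hat = snis(lambda x: -0.5 * (x - 1.0) ** 2, lambda x: x, 50_000, random.Random(8))
print(est, Z_hat, math.sqrt(2.0 * math.pi))
```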


11.5.3 Choosing the proposal 


The performance of importance sampling depends crucially on the quality of the proposal distribution. As we mentioned, we require that the support of $q$ cover the support of the target (i.e., $\pi(x) > 0 \implies q(x) > 0$). However, we also want the proposal to not be too “loose” of a “covering”. Ideally it should also take into account properties of the target function $\varphi$ as well, as shown in Figure 11.6. This can yield substantial benefits, as shown in the “target aware Bayesian inference” scheme of [Rai+20]. However, usually the target function $\varphi$ is unknown or ignored, so we just try to find a “generally useful” approximation to the target.


One way to come up with a good proposal is to learn one, by optimizing the variational lower bound or ELBO (see Section 10.1.2). Indeed, if we fix the parameters of the generative model, we can think of importance weighted autoencoders (Section 10.5.1) as learning a good IS proposal. More details on this connection can be found in [DS18].


11.5.4 Annealed importance sampling (AIS)

In this section, we describe a method known as annealed importance sampling [Nea01] for sampling from complex, possibly multimodal distributions. Assume we want to sample from some target distribution $p_0(x) \propto f_0(x)$ (where $f_0(x)$ is the unnormalized version), but we cannot easily do so, because $p_0$ is complicated in some way (e.g., high dimensional and/or multi-modal). However, suppose that there is an easier distribution which we can sample from, call it $p_n(x) \propto f_n(x)$; for


Figure 11.6: In importance sampling, we should sample from a distribution that takes into account regions where $\pi(x)$ has high probability and where $\varphi(x)$ is large. Here the function to be evaluated is an indicator function of a set, corresponding to a set of rare events in the tail of the distribution. From Figure 3 of [And+03]. Used with kind permission of Nando de Freitas.


example, this might be the prior. We now construct a sequence of intermediate distributions that move slowly from $p_n$ to $p_0$ as follows:

$$f_j(x) = f_0(x)^{\beta_j} f_n(x)^{1 - \beta_j} \qquad (11.44)$$

where $1 = \beta_0 > \beta_1 > \cdots > \beta_n = 0$, where $\beta_j$ is an inverse temperature. We will sample a set of points from $f_n$, and then from $f_{n-1}$, and so on, until we eventually sample from $f_0$.

To sample from each $f_j$, suppose we can define a Markov chain $T_j(x, x') = p_j(x'|x)$, which leaves $p_j$ invariant (i.e., $\int p_j(x'|x)p_j(x)\, dx = p_j(x')$). (See Chapter 12 for details on how to construct such chains.) Given this, we can sample $x$ from $p_0$ as follows: sample $v_n \sim p_n$; sample $v_{n-1} \sim T_{n-1}(v_n, \cdot)$; and continue in this way until we sample $v_0 \sim T_0(v_1, \cdot)$; finally we set $x = v_0$ and give it weight

$$w = \frac{f_{n-1}(v_{n-1})}{f_n(v_{n-1})}\,\frac{f_{n-2}(v_{n-2})}{f_{n-1}(v_{n-2})} \cdots \frac{f_1(v_1)}{f_2(v_1)}\,\frac{f_0(v_0)}{f_1(v_0)} \qquad (11.45)$$

This can be shown to be correct by viewing the algorithm as a form of importance sampling in an extended state space $v = (v_0, \ldots, v_n)$. Consider the following distribution on this state space:

$$p(v) \propto \varphi(v) = f_0(v_0)\tilde{T}_0(v_0, v_1)\tilde{T}_1(v_1, v_2) \cdots \tilde{T}_{n-1}(v_{n-1}, v_n) \qquad (11.46)$$
$$\propto p(v_0)p(v_1|v_0) \cdots p(v_n|v_{n-1}) \qquad (11.47)$$

where $\tilde{T}_j$ is the reversal of $T_j$:

$$\tilde{T}_j(v, v') = T_j(v', v)p_j(v')/p_j(v) = T_j(v', v)f_j(v')/f_j(v) \qquad (11.48)$$

It is clear that $\sum_{v_1, \ldots, v_n} \varphi(v) = f_0(v_0)$, so by sampling from $p(v)$, we can effectively sample from $p_0(x)$.
We can sample on this extended state space using the above algorithm, which corresponds to the following proposal:

$$q(v) \propto g(v) = f_n(v_n)T_{n-1}(v_n, v_{n-1}) \cdots T_1(v_2, v_1)T_0(v_1, v_0) \qquad (11.49)$$
$$\propto p(v_n)p(v_{n-1}|v_n) \cdots p(v_0|v_1) \qquad (11.50)$$

One can show that the importance weights $w = \frac{\varphi(v_0, \ldots, v_n)}{g(v_0, \ldots, v_n)}$ are given by Equation (11.45). Since marginals of the sampled sequences from this extended model are equivalent to samples from $p_0(x)$, we see that we are using the correct weights.


11.5.4.1 Estimating normalizing constants using AIS 


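An important application of AIS is to evaluate a ratio of partition functions. Notice that $Z_0 = \int f_0(x)\, dx = \int \varphi(v)\, dv$, and $Z_n = \int f_n(x)\, dx = \int g(v)\, dv$. Hence

$$\frac{Z_0}{Z_n} = \frac{\int \varphi(v)\, dv}{\int g(v)\, dv} = \frac{\int \frac{\varphi(v)}{g(v)}\, g(v)\, dv}{\int g(v)\, dv} = \mathbb{E}_{q}\left[\frac{\varphi(v)}{g(v)}\right] \approx \frac{1}{S}\sum_{s=1}^{S} w_s \qquad (11.51)$$

where $w_s = \varphi(v_s)/g(v_s)$. If $f_n$ is a prior and $f_0$ is the (unnormalized) posterior, we can estimate $Z_0 = p(\mathcal{D})$ using the above equation, provided the prior has a known normalization constant $Z_n$. This is generally considered the method of choice for evaluating difficult partition functions.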
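The following is a small 1d sketch of AIS used to estimate $Z_0/Z_n$ via Equation (11.51). Everything about the setup is an illustrative assumption: $f_n$ is a normalized $\mathcal{N}(0, 3^2)$ (so $Z_n = 1$), $f_0$ is an unnormalized $\mathcal{N}(2, 0.5^2)$ bump (so $Z_0 = 0.5\sqrt{2\pi} \approx 1.253$), the transitions are a few random walk MH steps, and the quadratic inverse-temperature schedule is our own choice. Each weight is accumulated in log space as a product of ratios $f_{j-1}(x)/f_j(x)$, in the spirit of Equation (11.45).

```python
import math
import random

LOG_NORM_N = math.log(3.0) + 0.5 * math.log(2.0 * math.pi)


def log_f0(x):
    # Unnormalized target: N(2, 0.5^2) without its constant, so Z_0 = 0.5 * sqrt(2 pi)
    return -0.5 * ((x - 2.0) / 0.5) ** 2


def log_fn(x):
    # Easy, normalized distribution p_n = N(0, 3^2), so Z_n = 1
    return -0.5 * (x / 3.0) ** 2 - LOG_NORM_N


def log_fj(x, beta):
    # Geometric path of Eq. (11.44): f_j = f_0^beta * f_n^(1 - beta)
    return beta * log_f0(x) + (1.0 - beta) * log_fn(x)


def ais_log_weight(betas, rng, n_mh=3, step=1.0):
    """One AIS run; betas increase from 0 (p_n) to 1 (p_0). Returns log w."""
    x = rng.gauss(0.0, 3.0)                  # initial draw from p_n
    log_w = 0.0
    for b_prev, b_next in zip(betas[:-1], betas[1:]):
        # Multiply in one ratio f_{j-1}(x) / f_j(x) of the importance weight
        log_w += log_fj(x, b_next) - log_fj(x, b_prev)
        # Random walk MH moves that leave the next intermediate distribution invariant
        cur = log_fj(x, b_next)
        for _ in range(n_mh):
            prop = x + step * rng.gauss(0.0, 1.0)
            new = log_fj(prop, b_next)
            if rng.random() < math.exp(min(0.0, new - cur)):
                x, cur = prop, new
    return log_w


rng = random.Random(9)
n_temps = 200
betas = [(j / n_temps) ** 2 for j in range(n_temps + 1)]   # denser spacing near beta = 0
log_ws = [ais_log_weight(betas, rng) for _ in range(400)]
m = max(log_ws)
ratio_hat = math.exp(m) * sum(math.exp(lw - m) for lw in log_ws) / len(log_ws)  # Eq. (11.51)
true_ratio = 0.5 * math.sqrt(2.0 * math.pi)
print(ratio_hat, true_ratio)
```

Averaging the weights with a max-shift (log-sum-exp style) avoids overflow when some log weights are large.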


11.6 Controlling Monte Carlo variance 


As we mentioned in Section 11.2.2, the standard error in a Monte Carlo estimate is $O(1/\sqrt{S})$, where $S$ is the number of (independent) samples. Consequently it may take many samples to reduce the variance to a sufficiently small value. In this section, we discuss some ways to reduce the variance of sampling methods. For more details, see e.g., [KTB11].


11.6.1 Common random numbers 


When performing Monte Carlo optimization, we often want to compare $\mathbb{E}_{p(z)}[f(\theta, z)]$ to $\mathbb{E}_{p(z)}[f(\theta', z)]$ for different values of the parameters $\theta$ and $\theta'$. To reduce the variance of this comparison, we can use the same random samples $z_s$ for evaluating both functions. In this way, differences in the outcome can be ascribed to differences in the parameters $\theta$, rather than to the noise terms. This is called the common random numbers (CRN) trick, and is widely used in ML (see e.g., [GBJ18; NJ00]), since it can often convert a stochastic optimization problem into a deterministic one, enabling the use of more powerful optimization methods. For more details on CRN, see e.g., https://en.wikipedia.org/wiki/Variance_reduction#Common_Random_Numbers_(CRN).
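A toy illustration of common random numbers, with an objective of our own choosing, $f(\theta, z) = (z - \theta)^2$ with $z \sim \mathcal{N}(0, 1)$, for which $\mathbb{E}[f(\theta, z)] = 1 + \theta^2$. The sketch compares the spread of many estimates of the difference between $\theta = 1$ and $\theta' = 1.1$ (exact value 0.21) with and without sharing the draws $z_s$.

```python
import random


def f(theta, z):
    # Toy objective with E_z[f(theta, z)] = 1 + theta^2 for z ~ N(0, 1)
    return (z - theta) ** 2


def diff_estimate(theta, theta_new, S, rng, common):
    """MC estimate of E[f(theta_new, z)] - E[f(theta, z)] from S draws per term."""
    z = [rng.gauss(0.0, 1.0) for _ in range(S)]
    # With CRN we reuse the same draws z_s for both parameter values
    z_new = z if common else [rng.gauss(0.0, 1.0) for _ in range(S)]
    return (sum(f(theta_new, zs) for zs in z_new) - sum(f(theta, zs) for zs in z)) / S


def sample_var(vals):
    m = sum(vals) / len(vals)
    return sum((v - m) ** 2 for v in vals) / (len(vals) - 1)


rng = random.Random(10)
crn = [diff_estimate(1.0, 1.1, 100, rng, common=True) for _ in range(500)]
ind = [diff_estimate(1.0, 1.1, 100, rng, common=False) for _ in range(500)]
# The exact difference is (1 + 1.1^2) - (1 + 1^2) = 0.21
print(sum(crn) / len(crn), sample_var(crn), sample_var(ind))
```

Sharing the draws makes the per-sample difference $-0.2 z_s + 0.21$, whose variance is only 0.04, whereas independent draws leave the full variance of both terms.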


11.6.2 Rao-Blackwellisation

In this section, we discuss a useful technique for reducing the variance of MC estimators known as Rao-Blackwellisation. To explain the method, suppose we have two rv’s, X and Y, and we want


to estimate $\bar{f} = \mathbb{E}[f(X, Y)]$. The naive approach is to use an MC approximation

$$\hat{f}_{MC} = \frac{1}{S}\sum_{s=1}^{S} f(X_s, Y_s) \qquad (11.52)$$

where $(X_s, Y_s) \sim p(X, Y)$. This is an unbiased estimator of $\bar{f}$. However, it may have high variance. Now suppose we can analytically marginalize out $Y$, provided we know $X$, i.e., we can tractably compute

$$f_X(X_s) = \int dY\, p(Y|X_s) f(X_s, Y) = \mathbb{E}[f(X, Y)|X = X_s] \qquad (11.53)$$

Let us define the Rao-Blackwellised estimator

$$\hat{f}_{RB} = \frac{1}{S}\sum_{s=1}^{S} f_X(X_s) \qquad (11.54)$$

where $X_s \sim p(X)$. This is an unbiased estimator, since $\mathbb{E}[\hat{f}_{RB}] = \mathbb{E}[\mathbb{E}[f(X, Y)|X]] = \bar{f}$. However, this estimate can have lower variance than the naive estimator, and never has higher variance. The intuitive reason is that we are now sampling in a reduced dimensional space. Formally we can see this by using the law of iterated variance to get

$$\mathbb{V}[\mathbb{E}[f(X, Y)|X]] = \mathbb{V}[f(X, Y)] - \mathbb{E}[\mathbb{V}[f(X, Y)|X]] \le \mathbb{V}[f(X, Y)] \qquad (11.55)$$

For some examples of this in practice, see Section 6.5.3.2, Section 13.4, and Section 12.3.8.
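A minimal sketch of Equations (11.52)-(11.54) for a model where the conditional expectation is available in closed form. The choice $X \sim \mathcal{N}(0, 1)$, $Y|X \sim \mathcal{N}(X, 1)$ and $f(X, Y) = Y^2$ is ours, for illustration: then $f_X(X) = \mathbb{E}[Y^2|X] = X^2 + 1$ and $\bar{f} = 2$. By Equation (11.55) the per-sample variance drops from $\mathbb{V}[Y^2] = 8$ to $\mathbb{V}[X^2 + 1] = 2$.

```python
import random


def mean_var(vals):
    m = sum(vals) / len(vals)
    return m, sum((v - m) ** 2 for v in vals) / (len(vals) - 1)


rng = random.Random(11)
naive_terms, rb_terms = [], []
for _ in range(20_000):
    x = rng.gauss(0.0, 1.0)       # X ~ N(0, 1)
    y = rng.gauss(x, 1.0)         # Y | X ~ N(X, 1)
    naive_terms.append(y * y)     # f(X, Y) = Y^2, used by Eq. (11.52)
    rb_terms.append(x * x + 1.0)  # f_X(X) = E[Y^2 | X] = X^2 + 1, Eq. (11.53)

(m_mc, v_mc), (m_rb, v_rb) = mean_var(naive_terms), mean_var(rb_terms)
print(m_mc, m_rb, v_mc, v_rb)     # both means near 2; variances near 8 and 2
```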


11.6.3 Control variates 


Suppose we want to estimate $\mu = \mathbb{E}[f(X)]$ using an unbiased estimator $m(\mathcal{X}) = \frac{1}{S}\sum_{s=1}^{S} m(x_s)$, where $x_s \sim p(X)$ and $\mathbb{E}[m(X)] = \mu$. (We abuse notation slightly and use $m$ to refer to a function of a single random variable as well as a set of samples.) Now consider the alternative estimator

$$m^*(\mathcal{X}) = m(\mathcal{X}) + c\left(b(\mathcal{X}) - \mathbb{E}[b(X)]\right) \qquad (11.56)$$

This is called a control variate, and $b$ is called a baseline. (Once again we abuse notation and use $b(\mathcal{X}) = \frac{1}{S}\sum_{s=1}^{S} b(x_s)$ and $m^*(\mathcal{X}) = \frac{1}{S}\sum_{s=1}^{S} m^*(x_s)$.)

It is easy to see that $m^*(\mathcal{X})$ is an unbiased estimator, since $\mathbb{E}[m^*(X)] = \mathbb{E}[m(X)] = \mu$. However, it can have lower variance, provided $b$ is correlated with $m$. To see this, note that

$$\mathbb{V}[m^*(X)] = \mathbb{V}[m(X)] + c^2\,\mathbb{V}[b(X)] + 2c\,\mathrm{Cov}[m(X), b(X)] \qquad (11.57)$$

By taking the derivative of $\mathbb{V}[m^*(X)]$ wrt $c$ and setting to 0, we find that the optimal value is

$$c^* = -\frac{\mathrm{Cov}[m(X), b(X)]}{\mathbb{V}[b(X)]} \qquad (11.58)$$
The corresponding variance of the new estimator is now

$$\mathbb{V}[m^*(X)] = \mathbb{V}[m(X)] - \frac{\mathrm{Cov}[m(X), b(X)]^2}{\mathbb{V}[b(X)]} \qquad (11.59)$$
$$= (1 - \rho_{m,b}^2)\,\mathbb{V}[m(X)] \le \mathbb{V}[m(X)]$$

where $\rho_{m,b}$ is the correlation of the basic estimator and the baseline function. If we can ensure this correlation is high, we can reduce the variance. Intuitively, the CV estimator is exploiting information about the errors in the estimate of a known quantity, namely $\mathbb{E}[b(X)]$, to reduce the errors in estimating the unknown quantity, namely $\mu$.

We give a simple worked example in Section 11.6.3.1. See Section 10.3.2 for an example of this technique applied to blackbox variational inference.


11.6.3.1 Example 


We now give a simple worked example of control variates.⁴ Consider estimating $\mu = \mathbb{E}[f(X)]$ where $f(X) = 1/(1 + X)$ and $X \sim \mathrm{Unif}(0, 1)$. The exact value is

$$\mu = \int_0^1 \frac{1}{1 + x}\, dx = \ln 2 \approx 0.693 \qquad (11.60)$$

The naive MC estimate, using $S$ samples, is $m(\mathcal{X}) = \frac{1}{S}\sum_{s=1}^{S} f(x_s)$. Using $S = 1500$, we find $\mathbb{E}[m(\mathcal{X})] = 0.6935$ with standard error se = 0.0037.

Now let us use $b(X) = 1 + X$ as a baseline, so $b(\mathcal{X}) = \frac{1}{S}\sum_{s=1}^{S}(1 + x_s)$. This has expectation $\mathbb{E}[b(X)] = \int_0^1 (1 + x)\, dx = \frac{3}{2}$. The control variate estimator is given by

$$m^*(\mathcal{X}) = \frac{1}{S}\sum_{s=1}^{S} f(x_s) + c\left(\frac{1}{S}\sum_{s=1}^{S} b(x_s) - \frac{3}{2}\right) \qquad (11.61)$$

The optimal value can be estimated from the samples of $m(x_s)$ and $b(x_s)$, and plugging into Equation (11.58) to get $c^* \approx 0.4773$. Using $S = 1500$, we find $\mathbb{E}[m^*(\mathcal{X})] = 0.6941$ and se = 0.0007.
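The worked example can be reproduced with a short sketch like the one below (standard library only; the seed is arbitrary, so the printed numbers will differ somewhat from those quoted above). It estimates $c^*$ from the same samples via Equation (11.58), as the text describes; the exact value is $c^* = 12(1.5\ln 2 - 1) \approx 0.477$.

```python
import math
import random


def mean(v):
    return sum(v) / len(v)


def cov(u, v):
    mu, mv = mean(u), mean(v)
    return sum((a - mu) * (b - mv) for a, b in zip(u, v)) / (len(u) - 1)


rng = random.Random(12)
S = 1500
xs = [rng.random() for _ in range(S)]          # x_s ~ Unif(0, 1)
m_terms = [1.0 / (1.0 + x) for x in xs]        # f(x_s)
b_terms = [1.0 + x for x in xs]                # baseline b(x_s); E[b(X)] = 3/2

c_star = -cov(m_terms, b_terms) / cov(b_terms, b_terms)   # Eq. (11.58), from samples
cv_terms = [m + c_star * (b - 1.5) for m, b in zip(m_terms, b_terms)]  # Eq. (11.61) per sample

naive_est, naive_se = mean(m_terms), math.sqrt(cov(m_terms, m_terms) / S)
cv_est, cv_se = mean(cv_terms), math.sqrt(cov(cv_terms, cv_terms) / S)
print(c_star, naive_est, naive_se, cv_est, cv_se)
```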


See also Section 11.6.4.1, where we analyze this example using antithetic sampling. 


11.6.4 Antithetic sampling

In this section, we discuss antithetic sampling, which is a simple way to reduce variance.⁵ Suppose we want to estimate $\theta = \mathbb{E}[Y]$. Let $Y_1$ and $Y_2$ be two samples. An unbiased estimate of $\theta$ is given by $\hat{\theta} = (Y_1 + Y_2)/2$. The variance of this estimate is

$$\mathbb{V}[\hat{\theta}] = \frac{\mathbb{V}[Y_1] + \mathbb{V}[Y_2] + 2\,\mathrm{Cov}[Y_1, Y_2]}{4} \qquad (11.62)$$

so the variance is reduced (relative to independent samples) if $\mathrm{Cov}[Y_1, Y_2] < 0$. So whenever we sample $Y_1$, we should set $Y_2$ to be its “opposite”, but with the same mean.

For example, suppose $Y \sim \mathrm{Unif}(0, 1)$. If we let $y_1, \ldots, y_n$ be iid samples from $\mathrm{Unif}(0, 1)$, then we can define $y_i' = 1 - y_i$. The distribution of $y_i'$ is still $\mathrm{Unif}(0, 1)$, but $\mathrm{Cov}[y_i, y_i'] = -\mathbb{V}[y_i] = -1/12 < 0$.


4. The example is from https://en.wikipedia.org/wiki/Control_variates, with modified notation. See control_variates.ipynb for some code.
5. Our presentation is based on https://en.wikipedia.org/wiki/Antithetic_variates. See antithetic_sampling.ipynb for the code.
[Figure 11.7 panels: 512 MC, 512 QMC, and 512 RQMC points, each in the unit square.]

Figure 11.7: Illustration of Monte Carlo (MC), Quasi MC (QMC) from a Sobol sequence, and Randomized QMC using a scrambling method. Adapted from Figure 1 of [OR20]. Used with kind permission of Art Owen.


11.6.4.1 Example 


To see why this can be useful, consider the example from Section 11.6.3.1. Let $\hat{f}_{MC}$ be the classic MC estimate using $2N$ samples from $\mathrm{Unif}(0, 1)$, and let $\hat{f}_{anti}$ be the MC estimate using the above antithetic sampling scheme applied to $N$ base samples from $\mathrm{Unif}(0, 1)$. The exact value is $\mu = \ln 2 \approx 0.6931$. For the classical method, with $N = 750$, we find $\mathbb{E}[\hat{f}_{MC}] = 0.69365$ with a standard error of 0.0037. For the antithetic method, we find $\mathbb{E}[\hat{f}_{anti}] = 0.6939$ with a standard error of 0.0007, which matches the control variate method of Section 11.6.3.1.
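A sketch of the comparison above, with $N = 750$ as in the text (the seed is arbitrary, so the numbers will not exactly match those quoted). Both estimators use $2N$ evaluations of $f$; the antithetic standard error is computed from the $N$ pair averages $\frac{1}{2}(f(u) + f(1-u))$.

```python
import math
import random


def f(x):
    return 1.0 / (1.0 + x)


def mean_se(vals):
    m = sum(vals) / len(vals)
    v = sum((t - m) ** 2 for t in vals) / (len(vals) - 1)
    return m, math.sqrt(v / len(vals))


rng = random.Random(13)
N = 750
# Classic MC: 2N independent uniform draws
mc_est, mc_se = mean_se([f(rng.random()) for _ in range(2 * N)])
# Antithetic: N base draws u, each averaged with its mirror image 1 - u
anti_terms = []
for _ in range(N):
    u = rng.random()
    anti_terms.append(0.5 * (f(u) + f(1.0 - u)))
anti_est, anti_se = mean_se(anti_terms)
print(mc_est, mc_se, anti_est, anti_se, math.log(2.0))
```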


11.6.5 Quasi Monte Carlo (QMC) 


Quasi Monte Carlo (see e.g., [Lem09; Owe13]) is an approach to numerical integration that replaces random samples with low discrepancy sequences, such as the Halton sequence (see e.g., [Owe17]) or Sobol sequence. Intuitively, these are space filling sequences of points, constructed to reduce the unwanted gaps and clusters that would arise among randomly chosen inputs. See Figure 11.7 for an example.⁶

More precisely, consider the problem of evaluating the following D-dimensional integral:

$$\bar{f} = \int_{[0,1]^D} f(x)\, dx \approx \hat{f}_N = \frac{1}{N}\sum_{n=1}^{N} f(x_n) \qquad (11.63)$$

Let $e_N = |\bar{f} - \hat{f}_N|$ be the error. In standard Monte Carlo, if we draw $N$ independent samples, then we have $e_N \sim O\left(\frac{1}{\sqrt{N}}\right)$. In QMC, it can be shown that $e_N \sim O\left(\frac{(\log N)^D}{N}\right)$. For sufficiently large $N$, the latter is smaller than the former.

One disadvantage of QMC is that it just provides a point estimate of $\bar{f}$, and does not give an uncertainty estimate. By contrast, in regular MC, we can estimate the MC standard error, discussed in Section 11.2.2. Randomized QMC (see e.g., [L’E18]) provides a solution to this problem. The basic idea is to repeat the QMC method R times, by perturbing the sequence of N points by a

6. More details on QMC can be found at http://roth.cs.kuleuven.be/wiki/Main_Page. For connections to Bayesian quadrature, see e.g., [DKS13; HKO22].
random amount. In particular, define

$$y_{i,r} = x_i + u_r \pmod 1 \qquad (11.64)$$

where $x_1, \ldots, x_N$ is a low-discrepancy sequence, and $u_r \sim \mathrm{Unif}(0,1)^D$ is a random perturbation. The set $\{y_j\}$ is low discrepancy, and satisfies that each $y_j \sim \mathrm{Unif}(0,1)^D$, for $j = 1 : N \times R$. This has much lower variance than standard MC. (Typically we take R to be a power of 2.) Recently, [OR20] proved a strong law of large numbers for RQMC.
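The random-shift construction of Equation (11.64) can be sketched with a hand-rolled 2d Halton sequence (bases 2 and 3). The integrand $f(x) = x_1 x_2$ (exact integral 1/4), $N = 256$, $R = 32$, and the helper names are illustrative assumptions; the sketch compares the spread of the $R$ replicate estimates with that of plain MC using the same number of points.

```python
import math
import random


def radical_inverse(i, base):
    """Van der Corput radical inverse of the positive integer i in the given base."""
    inv, scale = 0.0, 1.0 / base
    while i > 0:
        i, digit = divmod(i, base)
        inv += digit * scale
        scale /= base
    return inv


def halton(n, bases=(2, 3)):
    """First n points of the Halton sequence in len(bases) dimensions."""
    return [[radical_inverse(i, b) for b in bases] for i in range(1, n + 1)]


def f(x):
    return x[0] * x[1]       # exact integral over [0, 1]^2 is 1/4


def rqmc_estimates(points, R, rng):
    """R randomly shifted copies of a QMC point set, Eq. (11.64)."""
    ests = []
    for _ in range(R):
        u = [rng.random() for _ in points[0]]                 # u_r ~ Unif(0, 1)^D
        shifted = [[(xi + ui) % 1.0 for xi, ui in zip(x, u)] for x in points]
        ests.append(sum(f(y) for y in shifted) / len(shifted))
    return ests


def mc_estimates(N, R, rng, D=2):
    return [sum(f([rng.random() for _ in range(D)]) for _ in range(N)) / N for _ in range(R)]


def mean_sd(vals):
    m = sum(vals) / len(vals)
    return m, math.sqrt(sum((v - m) ** 2 for v in vals) / (len(vals) - 1))


rng = random.Random(14)
rqmc_mean, rqmc_sd = mean_sd(rqmc_estimates(halton(256), 32, rng))
mc_mean, mc_sd = mean_sd(mc_estimates(256, 32, rng))
print(rqmc_mean, rqmc_sd, mc_mean, mc_sd)
```

The spread across the $R$ replicates gives the uncertainty estimate that plain QMC lacks.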

QMC and RQMC can be used inside of MCMC inference (see e.g., [OT05]) and variational inference (see e.g., [BWM18]). They are also commonly used to select the initial set of query points for Bayesian optimization (Section 6.8).

Another technique that can be used is orthogonal Monte Carlo, where the samples are condi- 
tioned to be pairwise orthogonal, but with the marginal distributions matching the original ones (see 
e.g., [Lin+20]). 
12 Markov Chain Monte Carlo inference


12.1 Introduction 


In Chapter 11, we considered non-iterative Monte Carlo methods, including rejection sampling and importance sampling, which generate independent samples from some target distribution. The trouble with these methods is that they often do not work well in high dimensional spaces. In this chapter, we discuss a popular method for sampling from high-dimensional distributions known as Markov chain Monte Carlo or MCMC. In a survey by SIAM News¹, MCMC was placed in the top 10 most important algorithms of the 20th century.

The basic idea behind MCMC is to construct a Markov chain (Section 2.6) on the state space $\mathcal{X}$ whose stationary distribution is the target density $p^*(x)$ of interest. (In a Bayesian context, this is usually a posterior, $p^*(x) \propto p(x|\mathcal{D})$, but MCMC can be applied to generate samples from any kind of distribution.) That is, we perform a random walk on the state space, in such a way that the fraction of time we spend in each state $x$ is proportional to $p^*(x)$. By drawing (correlated) samples $x_0, x_1, x_2, \ldots$ from the chain, we can perform Monte Carlo integration wrt $p^*$.

Note that the initial samples from the chain do not come from the stationary distribution, and 
should be discarded; the amount of time it takes to reach stationarity is called the mixing time or 
burn-in time; reducing this is one of the most important factors in making the algorithm fast, as 
we will see. 

The MCMC algorithm has an interesting history. It was discovered by physicists working on the atomic bomb at Los Alamos during World War II, and was first published in the open literature in [Met+53] in a chemistry journal. An extension was published in the statistics literature in [Has70], but was largely unnoticed. A special case (Gibbs sampling, Section 12.3) was independently invented in [GG84] in the context of Ising models (Section 4.3.2.1). But it was not until [GS90] that the algorithm became well-known to the wider statistical community. Since then it has become wildly popular in Bayesian statistics, and is becoming increasingly popular in machine learning.

In the rest of this chapter, we give a brief introduction to MCMC methods. For more details on the theory, see e.g., [GRS96; BZ20]. For more details on the implementation side, see e.g., [Lao+20]. And for an interactive visualization of many of these algorithms in 2d, see http://chi-feng.github.io/mcmc-demo/app.html.

1. Source: http://www.siam.org/pdf/news/637.pdf.
12.2 Metropolis Hastings algorithm

In this section, we describe the simplest kind of MCMC algorithm, known as the Metropolis Hastings or MH algorithm.


12.2.1 Basic idea 


The basic idea in MH is that at each step, we propose to move from the current state $x$ to a new state $x'$ with probability $q(x'|x)$, where $q$ is called the proposal distribution (also called the kernel). The user is free to use any kind of proposal they want, subject to some conditions which we explain below. This makes MH quite a flexible method.

Having proposed a move to $x'$, we then decide whether to accept this proposal, or to reject it, according to some formula, which ensures that the long-term fraction of time spent in each state is proportional to $p^*(x)$. If the proposal is accepted, the new state is $x'$, otherwise the new state is the same as the current state, $x$ (i.e., we repeat the sample).

If the proposal is symmetric, so $q(x'|x) = q(x|x')$, the acceptance probability is given by the following formula:

$$A = \min\left(1, \frac{p^*(x')}{p^*(x)}\right) \qquad (12.1)$$

We see that if $x'$ is more probable than $x$, we definitely move there (since $\frac{p^*(x')}{p^*(x)} > 1$), but if $x'$ is less probable, we may still move there anyway, depending on the relative probabilities. So instead of greedily moving to only more probable states, we occasionally allow “downhill” moves to less probable states. In Section 12.2.2, we prove that this procedure ensures that the fraction of time we spend in each state $x$ is equal to $p^*(x)$.

If the proposal is asymmetric, so $q(x'|x) \ne q(x|x')$, we need the Hastings correction, given by the following:

$$A = \min(1, \alpha) \qquad (12.2)$$
$$\alpha = \frac{p^*(x')q(x|x')}{p^*(x)q(x'|x)} = \frac{p^*(x')/q(x'|x)}{p^*(x)/q(x|x')} \qquad (12.3)$$

This correction is needed to compensate for the fact that the proposal distribution itself (rather than just the target distribution) might favor certain states.


An important reason why MH is a useful algorithm is that, when evaluating $\alpha$, we only need to know the target density up to a normalization constant. In particular, suppose $p^*(x) = \frac{1}{Z}\tilde{p}(x)$, where $\tilde{p}(x)$ is an unnormalized distribution and $Z$ is the normalization constant. Then

$$\alpha = \frac{(\tilde{p}(x')/Z)\, q(x|x')}{(\tilde{p}(x)/Z)\, q(x'|x)} \qquad (12.4)$$

so the $Z$’s cancel. Hence we can sample from $p^*$ even if $Z$ is unknown.

A proposal distribution $q$ is valid or admissible if it “covers” the support of the target. Formally, we can write this as

$$\mathrm{supp}(p^*) \subseteq \cup_x \mathrm{supp}(q(\cdot|x)) \qquad (12.5)$$

With this, we can state the overall algorithm as in Algorithm 26.
Algorithm 26: Metropolis Hastings algorithm

1 Initialize $x^0$
2 for $s = 0, 1, 2, \ldots$ do
3     Define $x = x^s$
4     Sample $x' \sim q(x'|x)$
5     Compute acceptance probability $\alpha = \frac{\tilde{p}(x')q(x|x')}{\tilde{p}(x)q(x'|x)}$
6     Compute $A = \min(1, \alpha)$
7     Sample $u \sim U(0, 1)$
8     Set new sample to $x^{s+1} = \begin{cases} x' & \text{if } u < A \text{ (accept)} \\ x^s & \text{if } u \ge A \text{ (reject)} \end{cases}$
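Algorithm 26 can be sketched in a few lines, working in log space for numerical stability. This is an illustrative implementation under our own naming (`metropolis_hastings`, `log_p_tilde`, `propose`, `log_q` are not from the book's code), demonstrated on an unnormalized $\mathcal{N}(3, 1)$ target with a symmetric random walk proposal; the step size 2.4 and the burn-in length are arbitrary choices.

```python
import math
import random


def metropolis_hastings(log_p_tilde, propose, log_q, x0, n_steps, rng):
    """Algorithm 26 in log space. log_q(a, b) returns log q(a | b) up to a constant."""
    x, samples, n_accept = x0, [], 0
    for _ in range(n_steps):
        x_new = propose(x, rng)                                  # x' ~ q(x' | x)
        log_alpha = (log_p_tilde(x_new) + log_q(x, x_new)
                     - log_p_tilde(x) - log_q(x_new, x))         # log of alpha in Eq. (12.3)
        if rng.random() < math.exp(min(0.0, log_alpha)):         # u < A = min(1, alpha)
            x, n_accept = x_new, n_accept + 1
        samples.append(x)                                        # a rejection repeats x
    return samples, n_accept / n_steps


# Example: unnormalized N(3, 1) target with a symmetric Gaussian random walk proposal
chain, acc = metropolis_hastings(
    log_p_tilde=lambda x: -0.5 * (x - 3.0) ** 2,
    propose=lambda x, rng: x + 2.4 * rng.gauss(0.0, 1.0),
    log_q=lambda a, b: 0.0,          # symmetric proposal, so the q terms cancel
    x0=0.0, n_steps=20_000, rng=random.Random(15))
kept = chain[1000:]                  # drop burn-in
mean = sum(kept) / len(kept)
var = sum((c - mean) ** 2 for c in kept) / len(kept)
print(acc, mean, var)
```

Because only differences of `log_p_tilde` enter the acceptance test, the normalization constant never needs to be known, as in Equation (12.4).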


12.2.2 Why MH works 


To prove that the MH procedure generates samples from $p^*$, we need a bit of Markov chain theory, as discussed in Section 2.6.4.

The MH algorithm defines a Markov chain with the following transition matrix:

$$p(x'|x) = \begin{cases} q(x'|x)A(x'|x) & \text{if } x' \ne x \\ q(x|x) + \sum_{x'' \ne x} q(x''|x)\left(1 - A(x''|x)\right) & \text{otherwise} \end{cases} \qquad (12.6)$$

This follows from a case analysis: if you move to $x'$ from $x$, you must have proposed it (with probability $q(x'|x)$) and it must have been accepted (with probability $A(x'|x)$); otherwise you stay in state $x$, either because that is what you proposed (with probability $q(x|x)$), or because you proposed something else (with probability $q(x''|x)$) but it was rejected (with probability $1 - A(x''|x)$).

Let us analyse this Markov chain. Recall that a chain satisfies detailed balance if

$$p(x'|x)p^*(x) = p(x|x')p^*(x') \qquad (12.7)$$

This means the in-flow to state $x'$ from $x$ is equal to the out-flow from state $x'$ back to $x$, and vice versa. We also showed that if a chain satisfies detailed balance, then $p^*$ is its stationary distribution. Our goal is to show that the MH algorithm defines a transition function that satisfies detailed balance and hence that $p^*$ is its stationary distribution. (If Equation (12.7) holds, we say that $p^*$ is an invariant distribution wrt the Markov transition kernel $p(x'|x)$.)

Theorem 12.2.1. If the transition matrix defined by the MH algorithm (given by Equation (12.6)) is ergodic and irreducible, then $p^*$ is its unique limiting distribution.


Proof. Consider two states $x$ and $x'$. Either

$$p^*(x)q(x'|x) < p^*(x')q(x|x') \qquad (12.8)$$

or

$$p^*(x)q(x'|x) \ge p^*(x')q(x|x') \qquad (12.9)$$

Without loss of generality, assume that $p^*(x)q(x'|x) \ge p^*(x')q(x|x')$. Hence

$$\alpha(x'|x) = \frac{p^*(x')q(x|x')}{p^*(x)q(x'|x)} \le 1 \qquad (12.10)$$

Hence we have $A(x'|x) = \alpha(x'|x)$ and $A(x|x') = 1$.

Now to move from $x$ to $x'$ we must first propose $x'$ and then accept it. Hence

$$p(x'|x) = q(x'|x)A(x'|x) = q(x'|x)\frac{p^*(x')q(x|x')}{p^*(x)q(x'|x)} = \frac{p^*(x')}{p^*(x)}q(x|x') \qquad (12.11)$$

Hence

$$p^*(x)p(x'|x) = p^*(x')q(x|x') \qquad (12.12)$$

The backwards probability is

$$p(x|x') = q(x|x')A(x|x') = q(x|x') \qquad (12.13)$$

since $A(x|x') = 1$. Inserting this into Equation (12.12) we get

$$p^*(x)p(x'|x) = p^*(x')p(x|x') \qquad (12.14)$$

so detailed balance holds wrt $p^*$. Hence, from Theorem 2.6.3, $p^*$ is a stationary distribution. Furthermore, from Theorem 2.6.2, this distribution is unique, since the chain is ergodic and irreducible. □
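The proof can also be checked numerically on a small discrete example. The sketch below, using exact rational arithmetic from Python's `fractions` module, builds the MH transition matrix of Equation (12.6) for an illustrative 3-state target and an asymmetric proposal of our own choosing, and verifies Equation (12.14) for every pair of states. Dropping the Hastings correction breaks detailed balance for this proposal.

```python
from fractions import Fraction as Fr

p_tilde = [Fr(1), Fr(2), Fr(3)]                  # unnormalized target on states {0, 1, 2}
p_star = [v / sum(p_tilde) for v in p_tilde]     # normalized: 1/6, 1/3, 1/2
q = [[Fr(2, 10), Fr(5, 10), Fr(3, 10)],          # q[x][y] = q(y | x), an asymmetric proposal
     [Fr(4, 10), Fr(4, 10), Fr(2, 10)],
     [Fr(1, 10), Fr(6, 10), Fr(3, 10)]]
K = len(p_tilde)


def transition_matrix(hastings=True):
    """P[x][y] = p(y | x) from Eq. (12.6)."""
    def A(y, x):
        ratio = p_tilde[y] / p_tilde[x]
        if hastings:
            ratio *= q[y][x] / q[x][y]           # Hastings correction, Eq. (12.3)
        return min(Fr(1), ratio)

    P = [[Fr(0)] * K for _ in range(K)]
    for x in range(K):
        for y in range(K):
            if y != x:
                P[x][y] = q[x][y] * A(y, x)
        # Staying put: propose x itself, or propose y != x and get rejected
        P[x][x] = q[x][x] + sum(q[x][y] * (1 - A(y, x)) for y in range(K) if y != x)
    return P


def detailed_balance(P):
    return all(p_star[x] * P[x][y] == p_star[y] * P[y][x] for x in range(K) for y in range(K))


P = transition_matrix()
print(detailed_balance(P), detailed_balance(transition_matrix(hastings=False)))
```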


12.2.3 Proposal distributions

In this section, we discuss some common proposal distributions. Note, however, that good proposal design is often intimately dependent on the form of the target distribution (most often the posterior).


34 12.2.3.1 Independence sampler 


If we use a proposal of the form $q(x'|x) = q(x')$, where the new state is independent of the old state, we get a method known as the independence sampler, which is similar to importance sampling (Section 11.5). The function $q(x')$ can be any suitable distribution, such as a Gaussian. A Gaussian has non-zero probability density on the entire state space, and hence is a valid proposal for any unconstrained continuous state space.
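A minimal sketch of the independence sampler, under assumptions of our own: a hypothetical unnormalized Gaussian target with mean 1 and standard deviation 0.5, and a fixed $\mathcal{N}(0, 2^2)$ proposal. Because the proposal ignores the current state, the Hastings correction does not cancel: the acceptance ratio is $\alpha = p^*(x')q(x) / (p^*(x)q(x'))$, a ratio of importance weights $p^*/q$, which is the connection to importance sampling.

```python
import math
import random


def log_target(x):
    # Hypothetical unnormalized target: N(1, 0.5^2) up to an additive constant
    return -0.5 * ((x - 1.0) / 0.5) ** 2


def log_proposal(x, scale):
    # log q(x) for q = N(0, scale^2), dropping the constant (it cancels in the ratio)
    return -0.5 * (x / scale) ** 2


def independence_sampler(n_steps, x0=0.0, scale=2.0, seed=0):
    rng = random.Random(seed)
    x, samples, n_accept = x0, [], 0
    for _ in range(n_steps):
        x_new = rng.gauss(0.0, scale)  # the proposal ignores the current state
        # log alpha = log [p*(x') q(x)] - log [p*(x) q(x')]
        log_alpha = ((log_target(x_new) + log_proposal(x, scale))
                     - (log_target(x) + log_proposal(x_new, scale)))
        # u in (0, 1], so this accepts with probability min(1, alpha)
        if math.log(1.0 - rng.random()) <= log_alpha:
            x, n_accept = x_new, n_accept + 1
        samples.append(x)
    return samples, n_accept / n_steps


samples, accept_rate = independence_sampler(20000)
kept = samples[1000:]  # discard a short burn-in
mean = sum(kept) / len(kept)
var = sum((s - mean) ** 2 for s in kept) / len(kept)
print(f"mean={mean:.2f} var={var:.3f} acceptance={accept_rate:.2f}")
```

The proposal is deliberately wider than the target; if $q$ had lighter tails than $p^*$, the weights $p^*/q$ would be unbounded and the chain could get stuck for long periods.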


12.2.3.2 Random walk Metropolis (RWM) algorithm


The random walk Metropolis algorithm corresponds to MH with the following proposal distribution:

$$q(x'|x) = \mathcal{N}(x'|x, \tau^2 \mathbf{I}) \qquad (12.15)$$


[Figure 12.1 panels: (a) MH with $\mathcal{N}(0, 1^2)$ proposal; (b) MH with $\mathcal{N}(0, 500^2)$ proposal; (c) MH with $\mathcal{N}(0, 8^2)$ proposal. Each panel shows the sample trace over iterations and the resulting histogram of samples.]

Figure 12.1: An example of the Metropolis Hastings algorithm for sampling from a mixture of two 1D Gaussians ($\mu = (-20, 20)$, $\pi = (0.3, 0.7)$, $\Sigma = (100, 100)$), using a Gaussian proposal with standard deviation of $\tau \in \{1, 8, 500\}$. (a) When $\tau = 1$, the chain gets trapped near the starting state and fails to sample from the mode at $\mu = -20$. (b) When $\tau = 500$, the chain is very “sticky”, so its effective sample size is low (as reflected by the rough histogram approximation at the end). (c) Using a standard deviation of $\tau = 8$ is just right and leads to a good approximation of the true distribution (shown in red). Compare to Figure 12.4. Generated by mcmc_gmm_demo.ipynb.


Here $\tau$ is a scale factor chosen to facilitate rapid mixing. [RR01b] prove that, if the posterior is Gaussian, the asymptotically optimal value is to use $\tau^2 = 2.38^2/D$, where $D$ is the dimensionality of $x$; this results in an acceptance rate of 0.234, which (in this case) is the optimal tradeoff between exploring widely enough to cover the distribution and not being rejected too often. (See [Béd08] for a more recent account of optimal acceptance rates for random walk Metropolis methods.)

Figure 12.1 shows an example where we use RWM to sample from a mixture of two 1D Gaussians. 
This is a tricky target distribution, since its two modes are separated by a region of relatively low probability.
It is very important to set the variance of the proposal, $\tau^2$, correctly: if the variance is too low, the
chain will only explore one of the modes, as shown in Figure 12.1(a), but if the variance is too large, 
most of the moves will be rejected, and the chain will be very sticky, i.e., it will stay in the same 
state for a long time. This is evident from the long stretches of repeated values in Figure 12.1(b). 
If we set the proposal’s variance just right, we get the trace in Figure 12.1(c), where the samples 
clearly explore the support of the target distribution. 
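The following sketch reproduces the setting of Figure 12.1 in plain Python; it is our own illustration, not the code in mcmc_gmm_demo.ipynb. It runs RWM on the mixture with $\mu = (-20, 20)$, $\pi = (0.3, 0.7)$ and component variances 100, for $\tau \in \{1, 8, 500\}$. Because the Gaussian random walk proposal is symmetric, the Hastings correction cancels and only $p^*(x')/p^*(x)$ is needed. The starting point $x_0 = 0$ and the run lengths are our own choices. Under the true distribution, the fraction of mass below zero is about 0.31.

```python
import math
import random

MU, PI, VAR = (-20.0, 20.0), (0.3, 0.7), (100.0, 100.0)


def log_mixture(x):
    # log p*(x) for the two-component 1D GMM of Figure 12.1, via log-sum-exp
    terms = [math.log(w) - 0.5 * math.log(2 * math.pi * v) - 0.5 * (x - m) ** 2 / v
             for m, w, v in zip(MU, PI, VAR)]
    top = max(terms)
    return top + math.log(sum(math.exp(t - top) for t in terms))


def rwm(log_p, x0, tau, n_steps, seed=0):
    """Random walk Metropolis with proposal q(x'|x) = N(x' | x, tau^2)."""
    rng = random.Random(seed)
    x, lp = x0, log_p(x0)
    samples, n_accept = [], 0
    for _ in range(n_steps):
        x_new = x + tau * rng.gauss(0.0, 1.0)
        lp_new = log_p(x_new)
        # The proposal is symmetric, so the Hastings correction cancels
        if math.log(1.0 - rng.random()) <= lp_new - lp:
            x, lp, n_accept = x_new, lp_new, n_accept + 1
        samples.append(x)
    return samples, n_accept / n_steps


for tau in (1.0, 8.0, 500.0):
    samples, rate = rwm(log_mixture, x0=0.0, tau=tau, n_steps=5000)
    frac_below_zero = sum(s < 0 for s in samples) / len(samples)
    print(f"tau={tau:5.0f}  acceptance={rate:.2f}  fraction below 0={frac_below_zero:.2f}")
```

The acceptance rate alone is misleading here: $\tau = 1$ accepts almost every move yet explores only one mode, while $\tau = 500$ rejects almost everything.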


12.2.3.3 Composing proposals 


If there are several proposals that might be useful, one can combine them using a mixture proposal, 
which is a convex combination of base proposals: 


$$q(x'|x) = \sum_{k=1}^{K} w_k\, q_k(x'|x) \qquad (12.16)$$

where $w_k$ are the mixing weights that sum to one. As long as each $q_k$ is an individually valid proposal, and each $w_k > 0$, the overall mixture proposal will also be valid. In particular, if each proposal is reversible, so it satisfies detailed balance (Section 2.6.4.4), then so does the mixture.
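A minimal sketch of a mixture proposal; the target, weights and scales are our own illustrative choices. With probability 0.8 we take a small random walk step (scale 0.5) that explores locally, and with probability 0.2 a large one (scale 5) that can jump between the modes of a hypothetical bimodal target. Each component is a symmetric random walk, so the mixture is symmetric and the acceptance ratio reduces to $p^*(x')/p^*(x)$.

```python
import math
import random


def log_target(x):
    # Hypothetical bimodal target: equal-weight mixture of N(-3, 1) and N(3, 1), unnormalized
    a, b = -0.5 * (x + 3.0) ** 2, -0.5 * (x - 3.0) ** 2
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def mh_mixture_proposal(n_steps, weights=(0.8, 0.2), scales=(0.5, 5.0), seed=0):
    """MH with q(x'|x) = sum_k w_k N(x' | x, scales[k]^2).

    Every component is a symmetric random walk, so the mixture is symmetric too and
    the acceptance ratio reduces to p*(x') / p*(x).
    """
    rng = random.Random(seed)
    components = range(len(weights))
    x, samples = 0.0, []
    for _ in range(n_steps):
        k = rng.choices(components, weights=weights)[0]  # pick base proposal k with prob w_k
        x_new = x + scales[k] * rng.gauss(0.0, 1.0)
        if math.log(1.0 - rng.random()) <= log_target(x_new) - log_target(x):
            x = x_new
        samples.append(x)
    return samples


samples = mh_mixture_proposal(50000)
frac_positive = sum(s > 0 for s in samples) / len(samples)
print(round(frac_positive, 2))  # the target puts half its mass on each side of 0
```

If the base proposals were not symmetric, the acceptance ratio would have to use the full mixture density $\sum_k w_k q_k$ in the Hastings correction.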
It is also possible to compose individual proposals by chaining them together to get

$$q(x'|x) = \sum_{x_1} \cdots \sum_{x_{K-1}} q_1(x_1|x)\, q_2(x_2|x_1) \cdots q_K(x'|x_{K-1}) \qquad (12.17)$$


A common example is where each base proposal only updates a subset of the variables (see e.g., 
Section 12.3). 


12.2.3.4 Data-driven MCMC 


In the case where the target distribution is a posterior, $p^*(x) = p(x|\mathcal{D})$, it is helpful to condition the proposal not just on the previous hidden state, but also on the visible data, i.e., to use $q(x'|x, \mathcal{D})$. This is called data-driven MCMC (see e.g., [TZ02; Jih+12]).

One way to create such a proposal is to train a recognition network to propose states using $q(x'|x, \mathcal{D}) = f(\mathcal{D})$. If the state space is high-dimensional, it might be hard to predict all the hidden components, so we can alternatively train individual “experts” to predict specific pieces of the hidden state. For example, in the context of estimating the 3d pose of a person from an image, we might combine a face detector with a limb detector. We can then use a mixture proposal of the form

$$q(x'|x, \mathcal{D}) = \pi_0\, q_0(x'|x) + \sum_k \pi_k\, q_k(x'_k|f_k(\mathcal{D})) \qquad (12.18)$$

where $q_0$ is a standard data-independent proposal (e.g., random walk), and $q_k$ updates the $k$'th component of the state space.


The overall procedure is a form of generate and test: the discriminative proposals $q(x'|x, \mathcal{D})$ generate new hypotheses, which are then “tested” by computing the posterior ratio $\frac{p(x'|\mathcal{D})}{p(x|\mathcal{D})}$ to see if the new hypothesis is better or worse. (See also Section 13.3, where we discuss learning proposal distributions for particle filters.)


12.2.3.5 Adaptive MCMC


One can change the parameters of the proposal as the algorithm is running to increase efficiency. This is called adaptive MCMC. For example, one can start with a broad covariance, allowing large moves through the space until a mode is found, followed by a narrowing of the covariance to ensure careful exploration of the region around the mode.

However, one must be careful: if the parameters of the proposal depend on the entire history of the chain, the process is no longer Markov, and $p^*$ need not be its stationary distribution. It turns out that, under mild regularity conditions, a sufficient condition for the chain to still converge to $p^*$ is that the adaptation is “faded out” gradually over time. See e.g., [AT08] for details.
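The following sketch shows one simple adaptive scheme, chosen by us for illustration: random walk Metropolis where $\log \tau$ is updated by a Robbins-Monro recursion toward a target acceptance rate (0.234 by default, the value discussed above). The adaptation step size $\gamma_t = (t+1)^{-0.6}$ decays to zero, which is one way of fading out the adaptation. The exponent 0.6, the starting scale and the test target are assumptions, not prescriptions from [AT08].

```python
import math
import random


def adaptive_rwm(log_p, x0, n_steps, target_accept=0.234, log_tau0=0.0, seed=0):
    """Random walk Metropolis whose log step size is tuned by a Robbins-Monro update.

    The adaptation step size gamma_t = (t + 1) ** -0.6 decays to zero, so the
    adaptation fades out over time ('diminishing adaptation').
    """
    rng = random.Random(seed)
    x, lp, log_tau = x0, log_p(x0), log_tau0
    samples, taus = [], []
    for t in range(n_steps):
        x_new = x + math.exp(log_tau) * rng.gauss(0.0, 1.0)
        lp_new = log_p(x_new)
        accept_prob = math.exp(min(0.0, lp_new - lp))
        if rng.random() < accept_prob:
            x, lp = x_new, lp_new
        # Grow tau when accepting more often than the target rate, shrink it otherwise
        gamma = (t + 1) ** -0.6
        log_tau += gamma * (accept_prob - target_accept)
        samples.append(x)
        taus.append(math.exp(log_tau))
    return samples, taus


# Hypothetical target N(0, 10^2), deliberately started with a much too small scale tau = 1
samples, taus = adaptive_rwm(lambda x: -0.5 * (x / 10.0) ** 2, x0=0.0, n_steps=20000)
print(round(taus[0], 2), round(taus[-1], 1))
```

For a 1D Gaussian with standard deviation 10, an acceptance rate of 0.234 corresponds to $\tau$ of roughly 50, so the final scale should settle near that value; early samples, taken while $\tau$ was still changing quickly, are best discarded.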


12.2.4 Initialization


It is necessary to start MCMC in an initial state that has non-zero probability. A natural approach is to first use an optimizer to find a local mode. However, at such points the gradients of the log joint are zero, which can cause problems for some gradient-based MCMC methods, such as HMC (Section 12.5), so it can be better to start “close” to a MAP estimate (see e.g., [HFM17, Sec 7.]).


12.3 Gibbs sampling 


The major problems with MH are the need to choose the proposal distribution, and the fact that the acceptance rate may be low. In this section, we describe an MH method that exploits conditional independence properties of a graphical model to automatically create a good proposal, with acceptance probability 1. This method is known as Gibbs sampling.² (In physics, this method is known as Glauber dynamics or the heat bath method.) This is the MCMC analog of coordinate descent.³


12.3.1 Basic idea 


The idea behind Gibbs sampling is to sample each variable in turn, conditioned on the values of all the other variables in the distribution. For example, if we have $D = 3$ variables, we use

$$\begin{aligned} x_1^{s+1} &\sim p(x_1|x_2^s, x_3^s) \\ x_2^{s+1} &\sim p(x_2|x_1^{s+1}, x_3^s) \\ x_3^{s+1} &\sim p(x_3|x_1^{s+1}, x_2^{s+1}) \end{aligned}$$

This readily generalizes to $D$ variables. (Note that if $x_i$ is a known variable, we do not sample it, but it may be used as input to another conditional distribution.)

The expression $p(x_i|x_{-i})$ is called the full conditional for variable $i$. In general, $x_i$ may only depend on some of the other variables. If we represent $p(x)$ as a graphical model, we can infer the dependencies by looking at $i$'s Markov blanket, which are its neighbors in the graph (see Section 4.2.4.3), so we can write

$$x_i^{s+1} \sim p(x_i|x_{-i}^s) = p(x_i|x_{\mathrm{nb}_i}^s) \qquad (12.19)$$

(Compare to the equation for mean field variational inference in Equation (10.28).)

We can sample some of the nodes in parallel, without affecting correctness. In particular, suppose 
we can create a coloring of the (moralized) undirected graph, such that no two neighboring nodes 
have the same color. (In general, computing an optimal coloring is NP-complete, but we can use 
efficient heuristics such as those in [Kub04].) Then we can sample all the nodes of the same color in 
parallel, and cycle through the colors sequentially [Gon+11]. 
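The text leaves the coloring heuristic open. As one simple option (not necessarily the heuristics of [Kub04]), the sketch below uses greedy coloring with a largest-degree-first ordering. Nodes of the same color are never adjacent, so given all other nodes they are conditionally independent and can be Gibbs-sampled in parallel. On a 2d grid this recovers the two-color checkerboard pattern of Figure 12.2.

```python
def greedy_coloring(neighbors):
    """Greedy graph coloring: give each node the smallest color unused by its colored neighbors.

    neighbors: dict mapping each node to a list of adjacent nodes (undirected graph).
    Returns a dict node -> color index; adjacent nodes never share a color.
    """
    color = {}
    # Welsh-Powell ordering: visit high-degree nodes first, which often saves colors
    for node in sorted(neighbors, key=lambda n: -len(neighbors[n])):
        used = {color[nb] for nb in neighbors[node] if nb in color}
        c = 0
        while c in used:
            c += 1
        color[node] = c
    return color


def grid_graph(n):
    # 4-connected n x n grid, as in a 2d MRF
    return {(r, c): [(r + dr, c + dc) for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1))
                     if 0 <= r + dr < n and 0 <= c + dc < n]
            for r in range(n) for c in range(n)}


coloring = greedy_coloring(grid_graph(3))
print(len(set(coloring.values())))  # 2
# A parallel Gibbs sweep then loops over colors, updating all nodes of one color at once.
```

Greedy coloring is not optimal in general, but every color class it produces is a valid group for parallel updates, which is all that correctness requires.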


12.3.2 Gibbs sampling is a special case of MH 


It turns out that Gibbs sampling is a special case of MH where we use a sequence of proposals of the form

$$q_i(x'|x) = p(x'_i|x_{-i})\, \mathbb{I}(x'_{-i} = x_{-i}) \qquad (12.20)$$

That is, we move to a new state where $x_i$ is sampled from its full conditional, but $x_{-i}$ is left unchanged.


2. Josiah Willard Gibbs, 1839-1903, was an American physicist. 
3. Several software libraries exist for applying Gibbs sampling to general graphical models, including Nimble, which is 
a C++ library with an R wrapper, and which replaces older programs such as BUGS and JAGS. 
Figure 12.2: Illustration of checkerboard pattern for a 2d MRF. This allows for parallel updates.


We now prove that the acceptance rate of each such proposal is 100%, so the overall algorithm also has an acceptance rate of 100%. We have

$$\alpha = \frac{p(x')\,q_i(x|x')}{p(x)\,q_i(x'|x)} = \frac{p(x'_i|x'_{-i})\,p(x'_{-i})\,p(x_i|x'_{-i})}{p(x_i|x_{-i})\,p(x_{-i})\,p(x'_i|x_{-i})} \qquad (12.21)$$

$$= \frac{p(x'_i|x_{-i})\,p(x_{-i})\,p(x_i|x_{-i})}{p(x_i|x_{-i})\,p(x_{-i})\,p(x'_i|x_{-i})} = 1 \qquad (12.22)$$

where we exploited the fact that $x'_{-i} = x_{-i}$.

The fact that the acceptance rate is 100% does not necessarily mean that Gibbs will converge 
rapidly, since it only updates one coordinate at a time (see Section 12.3.7). However, if we can group 
together correlated variables, then we can sample them as a group, which can significantly help 
mixing. 


12.3.3 Example: Gibbs sampling for Ising models 


In Section 4.3.2.1, we discuss Ising models and Potts models, which are pairwise MRFs with a 2d grid structure. The joint distribution has the form

$$p(x) = \frac{1}{Z} \prod_{i \sim j} \psi_{ij}(x_i, x_j|\theta) \qquad (12.23)$$

where $i \sim j$ means $i$ and $j$ are neighbors in the graph.

To apply Gibbs sampling to such a model, we just need to iteratively sample from each full conditional:

$$p(x_i|x_{-i}) \propto \prod_{j \in \mathrm{nbr}(i)} \psi_{ij}(x_i, x_j) \qquad (12.24)$$


Note that although Gibbs sampling is a sequential algorithm, we can sometimes exploit conditional independence properties to perform parallel updates [RS97a]. In the case of a 2d grid, we can color code nodes using the checkerboard pattern shown in Figure 12.2. This has the property that the black nodes are conditionally independent of each other given the white nodes, and vice versa. Hence we can sample all the black nodes in parallel (as a single group), and then sample all the white nodes, etc.
[Figure 12.3 panel titles include: “sample 5, Gibbs”; “mean after 15 sweeps of Gibbs”.]

Figure 12.3: Example of image denoising using Gibbs sampling. We use an Ising prior with $J = 1$ and a Gaussian noise model with $\sigma = 2$. (a) Sample from the posterior after one sweep over the image. (b) Sample after 5 sweeps. (c) Posterior mean, computed by averaging over 15 sweeps. Compare to Figure 10.3 which shows the results of mean field inference. Generated by ising_image_denoise_demo.ipynb.


To perform the sampling, we need to compute the full conditional in Equation (12.24). In the case of an Ising model with edge potentials $\psi(x_i, x_j) = \exp(J x_i x_j)$, where $x_i \in \{-1, +1\}$, the full conditional becomes

$$p(x_i = +1|x_{-i}) = \frac{\prod_{j \in \mathrm{nbr}(i)} \psi_{ij}(x_i = +1, x_j)}{\prod_{j \in \mathrm{nbr}(i)} \psi_{ij}(x_i = +1, x_j) + \prod_{j \in \mathrm{nbr}(i)} \psi_{ij}(x_i = -1, x_j)} \qquad (12.25)$$

$$= \frac{\exp[J \sum_{j \in \mathrm{nbr}(i)} x_j]}{\exp[J \sum_{j \in \mathrm{nbr}(i)} x_j] + \exp[-J \sum_{j \in \mathrm{nbr}(i)} x_j]} \qquad (12.26)$$

$$= \frac{\exp[J\eta_i]}{\exp[J\eta_i] + \exp[-J\eta_i]} = \sigma(2J\eta_i) \qquad (12.27)$$

where $J$ is the coupling strength, $\eta_i \triangleq \sum_{j \in \mathrm{nbr}(i)} x_j$ and $\sigma(u) = 1/(1 + e^{-u})$ is the sigmoid function. (If we use $x_i \in \{0, 1\}$, this becomes $p(x_i = 1|x_{-i}) = \sigma(J\eta_i)$.) It is easy to see that $\eta_i = x_i(a_i - d_i)$, where $a_i$ is the number of neighbors that agree with (have the same sign as) node $i$, and $d_i$ is the number of neighbors who disagree. If these two numbers are equal, the “forces” on $x_i$ cancel out, so the full conditional is uniform. Some samples from this model are shown in Figure 4.17.
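The update in Equation (12.27), combined with the checkerboard schedule, can be sketched as follows. This is our own plain-Python illustration with free boundary conditions and a random initial state (both assumptions); the two color classes are updated in turn, and since no update reads a site of its own color, each half-sweep could be run in parallel.

```python
import math
import random


def ising_gibbs(n, J, n_sweeps, seed=0):
    """Gibbs sampling for an n x n Ising model, x_i in {-1, +1}, free boundary conditions.

    Uses p(x_i = +1 | x_-i) = sigmoid(2 J eta_i), with eta_i the sum of neighboring spins.
    Each sweep updates the 'black' sites ((r + c) even), then the 'white' sites; sites of
    one color are conditionally independent given the other color, so each half-sweep
    could be done in parallel (here it is simply a loop).
    """
    rng = random.Random(seed)
    x = [[rng.choice((-1, 1)) for _ in range(n)] for _ in range(n)]
    for _ in range(n_sweeps):
        for parity in (0, 1):
            for r in range(n):
                for c in range(n):
                    if (r + c) % 2 != parity:
                        continue
                    eta = sum(x[rr][cc] for rr, cc in ((r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1))
                              if 0 <= rr < n and 0 <= cc < n)
                    p_plus = 1.0 / (1.0 + math.exp(-2.0 * J * eta))
                    x[r][c] = 1 if rng.random() < p_plus else -1
    return x


def agreement(x):
    # Fraction of neighboring pairs (edges of the grid) whose spins agree
    n = len(x)
    pairs = [(x[r][c], x[r + dr][c + dc]) for r in range(n) for c in range(n)
             for dr, dc in ((1, 0), (0, 1)) if r + dr < n and c + dc < n]
    return sum(a == b for a, b in pairs) / len(pairs)


for J in (0.0, 1.0):
    print(J, round(agreement(ising_gibbs(10, J, n_sweeps=500)), 2))
```

With $J = 0$ the spins are independent and about half of neighboring pairs agree; with a strong attractive coupling such as $J = 1$, nearly all of them do.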

One application of Ising models is as a prior for binary image denoising problems. In particular, suppose $y$ is a noisy version of $x$, and we wish to compute the posterior $p(x|y) \propto p(x)p(y|x)$, where $p(x)$ is an Ising prior, and $p(y|x) = \prod_i p(y_i|x_i)$ is a per-site likelihood term. Suppose this is a Gaussian. Let $\psi_i(x_i) = \mathcal{N}(y_i|x_i, \sigma^2)$ be the corresponding “local evidence” term. The full conditional becomes

$$p(x_i = +1|x_{-i}, y) = \frac{\exp[J\eta_i]\,\psi_i(+1)}{\exp[J\eta_i]\,\psi_i(+1) + \exp[-J\eta_i]\,\psi_i(-1)} \qquad (12.28)$$

$$= \sigma\left(2J\eta_i + \log\frac{\psi_i(+1)}{\psi_i(-1)}\right) \qquad (12.29)$$

Now the probability of $x_i$ entering each state is determined both by compatibility with its neighbors (the Ising prior) and compatibility with the data (the local likelihood term).

See Figure 12.3 for an example of this algorithm applied to a simple image denoising problem. The 
results are similar to the mean field results in Figure 10.3. 
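A minimal sketch of the denoising sampler, using Equation (12.29) with Gaussian local evidence, for which $\log \psi_i(+1) - \log \psi_i(-1) = 2y_i/\sigma^2$. The synthetic image (left half $-1$, right half $+1$), the initialization at $\mathrm{sign}(y)$, the sequential raster-scan update order and the number of sweeps are our own choices; this is not the code in ising_image_denoise_demo.ipynb.

```python
import math
import random


def gibbs_denoise(y, J=1.0, sigma=2.0, n_sweeps=15, seed=0):
    """Gibbs sampling from p(x | y) with an Ising prior (coupling J) and Gaussian
    local evidence psi_i(x_i) = N(y_i | x_i, sigma^2), x_i in {-1, +1}.

    Returns the final sample and the per-pixel average of the samples over all sweeps
    (a Monte Carlo estimate of the posterior mean).
    """
    rng = random.Random(seed)
    n, m = len(y), len(y[0])
    x = [[1 if y[r][c] > 0 else -1 for c in range(m)] for r in range(n)]  # start at sign(y)
    avg = [[0.0] * m for _ in range(n)]
    for sweep in range(n_sweeps):
        for r in range(n):
            for c in range(m):
                eta = sum(x[rr][cc] for rr, cc in ((r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1))
                          if 0 <= rr < n and 0 <= cc < m)
                # Equation (12.29): log psi_i(+1) - log psi_i(-1) = 2 y_i / sigma^2 here
                logit = 2.0 * J * eta + 2.0 * y[r][c] / sigma ** 2
                x[r][c] = 1 if rng.random() < 1.0 / (1.0 + math.exp(-logit)) else -1
        for r in range(n):
            for c in range(m):
                avg[r][c] += (x[r][c] - avg[r][c]) / (sweep + 1)  # running mean
    return x, avg


# Hypothetical test image: left half -1, right half +1, corrupted by N(0, 2^2) noise
rng = random.Random(1)
n = 20
clean = [[-1 if c < n // 2 else 1 for c in range(n)] for r in range(n)]
noisy = [[clean[r][c] + rng.gauss(0.0, 2.0) for c in range(n)] for r in range(n)]
_, post_mean = gibbs_denoise(noisy, J=1.0, sigma=2.0, n_sweeps=30)


def error_rate(est):
    # Fraction of pixels whose thresholded estimate disagrees with the clean image
    return sum((est[r][c] > 0) != (clean[r][c] > 0) for r in range(n) for c in range(n)) / n ** 2


print(error_rate(noisy), error_rate(post_mean))
```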


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 




12.3.4 Example: Gibbs sampling for Potts models 


We can extend Section 12.3.3 to the Potts models as follows. Recall that the model has the following form:

$$p(x) = \frac{1}{Z} \exp(-\mathcal{E}(x)) \qquad (12.30)$$

$$\mathcal{E}(x) = -J \sum_{i \sim j} \mathbb{I}(x_i = x_j) \qquad (12.31)$$

For a node $i$ with neighbors $\mathrm{nbr}(i)$, the full conditional is thus given by

$$p(x_i = k|x_{-i}) = \frac{\exp\left(J \sum_{n \in \mathrm{nbr}(i)} \mathbb{I}(x_n = k)\right)}{\sum_{k'} \exp\left(J \sum_{n \in \mathrm{nbr}(i)} \mathbb{I}(x_n = k')\right)}$$

So if $J > 0$, a node $i$ is more likely to enter a state $k$ if most of its neighbors are already in state $k$, corresponding to an attractive MRF. If $J < 0$, a node $i$ is more likely to enter a different state from its neighbors, corresponding to a repulsive MRF. See Figure 4.18 for some samples from this model created using this method.
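The Potts full conditional above only needs, for each state $k$, the number of neighbors currently in state $k$. The sketch below is our own illustration on an $n \times n$ grid with free boundaries and a sequential raster-scan order. It compares attractive, neutral and repulsive couplings through the fraction of neighboring pairs that share a state.

```python
import math
import random


def potts_gibbs(n, K, J, n_sweeps, seed=0):
    """Single-site Gibbs sampling for a K-state Potts model on an n x n grid.

    Full conditional: p(x_i = k | x_-i) proportional to exp(J * #{neighbors in state k}).
    """
    rng = random.Random(seed)
    x = [[rng.randrange(K) for _ in range(n)] for _ in range(n)]
    for _ in range(n_sweeps):
        for r in range(n):
            for c in range(n):
                counts = [0] * K
                for rr, cc in ((r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1)):
                    if 0 <= rr < n and 0 <= cc < n:
                        counts[x[rr][cc]] += 1
                weights = [math.exp(J * cnt) for cnt in counts]
                x[r][c] = rng.choices(range(K), weights=weights)[0]
    return x


def potts_agreement(x):
    # Fraction of neighboring pairs in the same state
    n = len(x)
    pairs = [(x[r][c], x[r + dr][c + dc]) for r in range(n) for c in range(n)
             for dr, dc in ((1, 0), (0, 1)) if r + dr < n and c + dc < n]
    return sum(a == b for a, b in pairs) / len(pairs)


for J in (-2.0, 0.0, 2.0):
    print(J, round(potts_agreement(potts_gibbs(10, 3, J, n_sweeps=300)), 2))
```

With $K = 3$ and $J = 0$, about one third of neighboring pairs agree by chance; positive $J$ pushes this toward one and negative $J$ toward zero.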


12.3.5 Example: Gibbs sampling for GMMs 


In this section, we consider sampling from a Bayesian Gaussian mixture model of the form 


$$p(\theta) = \mathrm{Dir}(\pi|\alpha) \prod_{k=1}^{K} \mathcal{N}(\mu_k|m_0, V_0)\, \mathrm{IW}(\Sigma_k|S_0, \nu_0) \qquad (12.34)$$


12.3.5.1 Known parameters 


Suppose, initially, that the parameters $\theta$ are known. We can easily draw independent samples from $p(x|\theta)$ by using ancestral sampling: first sample $z$ and then $x$. However, for illustrative purposes, we will use Gibbs sampling to draw correlated samples. The full conditional for $p(x|z = k, \theta)$ is just $\mathcal{N}(x|\mu_k, \Sigma_k)$, and the full conditional for $p(z = k|x)$ is given by Bayes rule:

$$p(z = k|x, \theta) = \frac{\pi_k\, \mathcal{N}(x|\mu_k, \Sigma_k)}{\sum_{k'} \pi_{k'}\, \mathcal{N}(x|\mu_{k'}, \Sigma_{k'})} \qquad (12.35)$$


An example of this procedure, applied to a mixture of two 1D Gaussians with means at $-20$ and $+20$, is shown in Figure 12.4. We see that the samples are autocorrelated, meaning that if we are in state 1, we will likely stay in that state for a while, and generate values near $\mu_1$; then we will stochastically jump to state 2, and stay near there for a while, etc. (See Section 12.6.3 for a way to measure this.) By contrast, independent samples from the joint would not be correlated at all.

In Section 12.3.5.2, we modify this example to sample the parameters of the GMM from their posterior, $p(\theta|\mathcal{D})$, instead of sampling from $p(\mathcal{D}|\theta)$.
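The procedure can be sketched in a few lines of plain Python (our own illustration, not the code in mcmc_gmm_demo.ipynb). It alternates $x \sim \mathcal{N}(\mu_z, \Sigma_z)$ and $z \sim p(z|x, \theta)$ from Equation (12.35), using the parameters of Figures 12.1 and 12.4: means $\pm 20$, weights $(0.3, 0.7)$, and standard deviation 10 (variance 100). The run length and the initial value $z = 0$ are assumptions.

```python
import math
import random


def gibbs_gmm_known(n_steps, mus=(-20.0, 20.0), pis=(0.3, 0.7), sds=(10.0, 10.0), seed=0):
    """Gibbs sampling of (x, z) for a 1D GMM with known parameters.

    Alternates x ~ N(mu_z, sd_z^2) and z ~ p(z | x, theta) from Equation (12.35).
    """
    rng = random.Random(seed)
    z, xs, zs = 0, [], []
    for _ in range(n_steps):
        x = rng.gauss(mus[z], sds[z])
        # log [pi_k N(x | mu_k, sd_k^2)] up to a constant shared by all k
        logw = [math.log(p) - math.log(s) - 0.5 * ((x - m) / s) ** 2
                for m, p, s in zip(mus, pis, sds)]
        top = max(logw)
        z = rng.choices(range(len(mus)), weights=[math.exp(lw - top) for lw in logw])[0]
        xs.append(x)
        zs.append(z)
    return xs, zs


xs, zs = gibbs_gmm_known(100000)
frac_state2 = sum(zs) / len(zs)                     # should approach pi_2 = 0.7
switch_rate = sum(a != b for a, b in zip(zs, zs[1:])) / (len(zs) - 1)
print(round(frac_state2, 2), round(switch_rate, 3))
```

The switch rate is small, which is the autocorrelation visible in Figure 12.4: the chain lingers in one component before jumping to the other.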
Figure 12.4: (a) Some samples from a mixture of two 1d Gaussians generated using Gibbs sampling. Color denotes the value of $z$, vertical location denotes the value of $x$. Horizontal axis represents time (sample number). (b) Traceplot of $x$ over time, and the resulting empirical distribution is shown in blue. The true distribution is shown in red. Compare to Figure 12.1. Generated by mcmc_gmm_demo.ipynb.


12.3.5.2 Unknown parameters 


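Now suppose the parameters are unknown, so we want to fit the model to data. If we use a conditionally conjugate factored prior, then the full joint distribution is given by

$$p(x, z, \mu, \Sigma, \pi) = p(x|z, \mu, \Sigma)\, p(z|\pi)\, p(\pi) \prod_{k=1}^{K} p(\mu_k)\, p(\Sigma_k) \qquad (12.36)$$

$$= \left( \prod_{i=1}^{N} \prod_{k=1}^{K} \left( \pi_k\, \mathcal{N}(x_i|\mu_k, \Sigma_k) \right)^{\mathbb{I}(z_i = k)} \right) \times \qquad (12.37)$$

$$\mathrm{Dir}(\pi|\alpha) \prod_{k=1}^{K} \mathcal{N}(\mu_k|m_0, V_0)\, \mathrm{IW}(\Sigma_k|S_0, \nu_0) \qquad (12.38)$$

We use the same prior for each mixture component.
The full conditionals are as follows. For the discrete indicators, we have

$$p(z_i = k|x_i, \mu, \Sigma, \pi) \propto \pi_k\, \mathcal{N}(x_i|\mu_k, \Sigma_k) \qquad (12.39)$$

For the mixing weights, we have (using results from Section 3.2.2)

$$p(\pi|z) = \mathrm{Dir}\left(\left\{\alpha_k + \sum_{i=1}^{N} \mathbb{I}(z_i = k)\right\}_{k=1}^{K}\right) \qquad (12.40)$$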
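For the means, we have (using results from Section 3.3.1)

$$p(\mu_k|\Sigma_k, z, x) = \mathcal{N}(\mu_k|m_k, V_k) \qquad (12.41)$$

$$V_k^{-1} = V_0^{-1} + N_k \Sigma_k^{-1} \qquad (12.42)$$

$$m_k = V_k\left(\Sigma_k^{-1} N_k \bar{x}_k + V_0^{-1} m_0\right) \qquad (12.43)$$

$$N_k \triangleq \sum_{i=1}^{N} \mathbb{I}(z_i = k) \qquad (12.44)$$

$$\bar{x}_k \triangleq \frac{\sum_{i=1}^{N} \mathbb{I}(z_i = k)\, x_i}{N_k} \qquad (12.45)$$

For the covariances, we have (using results from Section 3.3.2)

$$p(\Sigma_k|\mu_k, z, x) = \mathrm{IW}(\Sigma_k|S_k, \nu_k) \qquad (12.46)$$

$$S_k = S_0 + \sum_{i=1}^{N} \mathbb{I}(z_i = k)(x_i - \mu_k)(x_i - \mu_k)^\top \qquad (12.47)$$

$$\nu_k = \nu_0 + N_k \qquad (12.48)$$


12.3.6 Metropolis within Gibbs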


When implementing Gibbs sampling, we have to sample from the full conditionals. If the distributions are conjugate, we can compute the full conditional in closed form, but in the general case, we will need to devise special algorithms to sample from the full conditionals.

One approach is to use the MH algorithm; this is called Metropolis within Gibbs. In particular, to sample from $x_i^{s+1} \sim p(x_i|x_{1:i-1}^{s+1}, x_{i+1:D}^s)$, we proceed in 3 steps:

1. Propose $x'_i \sim q(x'_i|x_i^s)$.

2. Compute the acceptance probability $A_i = \min(1, \alpha_i)$ where

$$\alpha_i = \frac{p(x_{1:i-1}^{s+1}, x'_i, x_{i+1:D}^s)\,/\,q(x'_i|x_i^s)}{p(x_{1:i-1}^{s+1}, x_i^s, x_{i+1:D}^s)\,/\,q(x_i^s|x'_i)} \qquad (12.49)$$

3. Sample $u \sim \mathcal{U}(0, 1)$ and set $x_i^{s+1} = x'_i$ if $u < A_i$, and set $x_i^{s+1} = x_i^s$ otherwise.


12.3.7 Blocked Gibbs sampling


Gibbs sampling can be quite slow, since it only updates one variable at a time (so-called single site updating). If the variables are highly correlated, the chain will move slowly through the state space. This is illustrated in Figure 12.5, where we illustrate sampling from a 2d Gaussian. The ellipse represents the covariance matrix. The size of the moves taken by Gibbs sampling is controlled by the variance of the conditional distributions. If the conditional standard deviation is $\ell$ along some coordinate direction, but the support of the distribution is $L$ along this dimension, then we need $O((L/\ell)^2)$ steps to obtain an independent sample (since the chain behaves like a random walk with step size about $\ell$, it covers a distance of only about $\ell\sqrt{n}$ in $n$ steps).
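The slowdown can be quantified for a 2d Gaussian with unit variances and correlation $\rho$, where the full conditionals are $x_1|x_2 \sim \mathcal{N}(\rho x_2, 1 - \rho^2)$ and vice versa. One systematic-scan step gives $x_1^{s+1} = \rho^2 x_1^s + \text{noise}$, so the $x_1$ chain is an AR(1) process with lag-1 autocorrelation $\rho^2$, which approaches 1 as $|\rho| \to 1$. The sketch below (our own illustration) checks this empirically.

```python
import random


def gibbs_bivariate_gaussian(rho, n_steps, seed=0):
    """Gibbs sampling from N(0, [[1, rho], [rho, 1]]) using the full conditionals
    x1 | x2 ~ N(rho * x2, 1 - rho^2) and x2 | x1 ~ N(rho * x1, 1 - rho^2)."""
    rng = random.Random(seed)
    sd = (1.0 - rho ** 2) ** 0.5  # conditional standard deviation (the 'l' in the text)
    x1 = x2 = 0.0
    trace = []
    for _ in range(n_steps):
        x1 = rng.gauss(rho * x2, sd)
        x2 = rng.gauss(rho * x1, sd)
        trace.append((x1, x2))
    return trace


def lag1_autocorr(values):
    # Empirical correlation between consecutive samples
    n = len(values)
    m = sum(values) / n
    num = sum((values[i] - m) * (values[i + 1] - m) for i in range(n - 1))
    den = sum((v - m) ** 2 for v in values)
    return num / den


for rho in (0.5, 0.95, 0.99):
    x1s = [a for a, _ in gibbs_bivariate_gaussian(rho, 50000)]
    print(rho, round(lag1_autocorr(x1s), 3), round(rho ** 2, 3))
```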
Figure 12.5: Illustration of potentially slow sampling when using Gibbs sampling for a skewed 2D Gaussian. Adapted from Figure 11.11 of [Bis06]. Generated by gibbs_gauss_demo.ipynb.


In some cases we can efficiently sample groups of variables at a time. This is called blocked 
Gibbs sampling [JKK95; WY02], and can make much bigger moves through the state space. 

As an example, suppose we want to perform Bayesian inference for a state-space model, such as an HMM, i.e., we want to sample from

$$p(\theta, z|x) \propto p(\theta) \prod_{t=1}^{T} p(x_t|z_t, \theta)\, p(z_t|z_{t-1}, \theta) \qquad (12.50)$$

We can use blocked Gibbs sampling, where we alternate between sampling from $p(\theta|z, x)$ and $p(z|x, \theta)$. The former is easy to do (assuming conjugate priors), since all variables in the model are observed (see Section 29.8.4.1). The latter can be done using the forwards-filtering backwards-sampling algorithm (Section 8.2.8).
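The $p(z|x, \theta)$ block can be sketched for a discrete HMM as follows. This is our own plain-Python illustration; the $p(\theta|z, x)$ block depends on the chosen conjugate priors and is omitted. The forward pass computes normalized filtered marginals $p(z_t|x_{1:t})$. The backward pass samples $z_T$ from the last filtered marginal and then $z_t \propto p(z_t|x_{1:t})\,p(z_{t+1}|z_t)$. The per-step likelihoods `obs_lik` are assumed to be precomputed from the current $\theta$.

```python
import random


def ffbs(obs_lik, trans, init, seed=0):
    """Forwards-filtering backwards-sampling for a discrete HMM.

    obs_lik[t][k] = p(x_t | z_t = k, theta), trans[j][k] = p(z_t = k | z_{t-1} = j),
    init[k] = p(z_1 = k). Returns one joint sample z_{1:T} ~ p(z | x, theta).
    """
    rng = random.Random(seed)
    T, K = len(obs_lik), len(init)
    # Forward pass: filt[t][k] = p(z_t = k | x_{1:t})
    filt, prev = [], None
    for t in range(T):
        if t == 0:
            pred = list(init)
        else:
            pred = [sum(prev[j] * trans[j][k] for j in range(K)) for k in range(K)]
        unnorm = [pred[k] * obs_lik[t][k] for k in range(K)]
        total = sum(unnorm)
        prev = [u / total for u in unnorm]
        filt.append(prev)
    # Backward pass: z_T ~ p(z_T | x_{1:T}), then z_t ~ p(z_t | z_{t+1}, x_{1:t})
    z = [0] * T
    z[T - 1] = rng.choices(range(K), weights=filt[T - 1])[0]
    for t in range(T - 2, -1, -1):
        w = [filt[t][j] * trans[j][z[t + 1]] for j in range(K)]
        z[t] = rng.choices(range(K), weights=w)[0]
    return z


# Hypothetical two-state HMM with three observations
trans = [[0.9, 0.1], [0.2, 0.8]]
init = [0.5, 0.5]
obs_lik = [[0.7, 0.2], [0.1, 0.6], [0.5, 0.5]]
print(ffbs(obs_lik, trans, init, seed=3))
```

Because the whole sequence $z_{1:T}$ is drawn jointly, this block update avoids the slow mixing of resampling each $z_t$ given its neighbors.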


12.3.8 Collapsed Gibbs sampling 


We can sometimes gain even greater speedups by analytically integrating out some of the unknown 
quantities. This is called a collapsed Gibbs sampler, and it tends to be more efficient, since 
it is sampling in a lower dimensional space. This can result in lower variance, as discussed in 
Section 11.6.2. 

As an example, consider a GMM with a fully conjugate prior. This can be represented as a DPGM as shown in Figure 12.6a. Since the prior is conjugate, we can analytically integrate out the model parameters $\mu_k$, $\Sigma_k$, and $\pi$, so the only remaining hidden variables are the discrete indicator variables $z$. However, once we integrate out $\pi$, all the $z_i$ nodes become inter-dependent. Similarly, once we integrate out $\theta_k = (\mu_k, \Sigma_k)$, all the $x_i$ nodes become inter-dependent, as shown in Figure 12.6b. Nevertheless, we can easily compute the full conditionals, and hence implement a Gibbs sampler, as we explain below. In particular, the full conditional for the latent indicators is given by

$$p(z_i = k|z_{-i}, x, \alpha, \beta) \propto p(z_i = k|z_{-i}, \alpha, \beta)\, p(x|z_i = k, z_{-i}, \alpha, \beta) \qquad (12.51)$$

$$\propto p(z_i = k|z_{-i}, \alpha)\, p(x_i|x_{-i}, z_i = k, z_{-i}, \beta)\, p(x_{-i}|z_i = k, z_{-i}, \beta) \qquad (12.52)$$

$$\propto p(z_i = k|z_{-i}, \alpha)\, p(x_i|x_{-i}, z_i = k, z_{-i}, \beta) \qquad (12.53)$$
Figure 12.6: (a) A mixture model represented as an “unrolled” DPGM. (b) After integrating out the continuous latent parameters.


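where $\beta = (m_0, V_0, S_0, \nu_0)$ are the hyper-parameters for the class-conditional densities. (The factor $p(x_{-i}|z_i = k, z_{-i}, \beta)$ can be dropped in going from (12.52) to (12.53) because it does not depend on $k$: once $x_i$ is marginalized out, $z_i$ has no influence on $x_{-i}$.) We now discuss how to compute these terms.

Suppose we use a symmetric prior of the form $\pi \sim \mathrm{Dir}(\alpha)$, where $\alpha_k = \alpha/K$, for the mixing weights. Then we can obtain the first term in Equation (12.53) from Equation (3.29), where

$$p(z_1, \ldots, z_N|\alpha) = \frac{\Gamma(\alpha)}{\Gamma(N + \alpha)} \prod_{k=1}^{K} \frac{\Gamma(N_k + \alpha/K)}{\Gamma(\alpha/K)} \qquad (12.54)$$

Hence

$$p(z_i = k|z_{-i}, \alpha) = \frac{p(z_{1:N}|\alpha)}{p(z_{-i}|\alpha)} = \frac{1/\Gamma(N + \alpha)}{1/\Gamma(N + \alpha - 1)}\, \frac{\Gamma(N_k + \alpha/K)}{\Gamma(N_{k,-i} + \alpha/K)} \qquad (12.55)$$

$$= \frac{\Gamma(N + \alpha - 1)\,\Gamma(N_{k,-i} + 1 + \alpha/K)}{\Gamma(N + \alpha)\,\Gamma(N_{k,-i} + \alpha/K)} = \frac{N_{k,-i} + \alpha/K}{N + \alpha - 1} \qquad (12.56)$$

where $N_{k,-i} \triangleq \sum_{n \neq i} \mathbb{I}(z_n = k)$ (which equals $N_k - 1$ when $z_i = k$), and where we exploited the fact that $\Gamma(x + 1) = x\Gamma(x)$.

To obtain the second term in Equation (12.53), which is the posterior predictive distribution for $x_i$ given all the other data and all the assignments, we use the fact that

$$p(x_i|x_{-i}, z_{-i}, z_i = k, \beta) = p(x_i|\mathcal{D}_{-i,k}, \beta) \qquad (12.57)$$

where $\mathcal{D}_{-i,k} = \{x_j : z_j = k, j \neq i\}$ is all the data assigned to cluster $k$ except for $x_i$. If we use a conjugate prior for $\theta_k$, we can compute $p(x_i|\mathcal{D}_{-i,k}, \beta)$ in closed form. Furthermore, we can efficiently update these predictive likelihoods by caching the sufficient statistics for each cluster. To compute the above expression, we remove $x_i$'s statistics from its current cluster (namely $z_i$), and then evaluate $x_i$ under each cluster's posterior predictive distribution. Once we have picked a new cluster, we add $x_i$'s statistics to this new cluster.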


Some pseudo-code for one step of the algorithm is shown in Algorithm 27, based on [Sud06, p94]. (We update the nodes in random order to improve the mixing time, as suggested in [RS97b].) We can initialize the sample by sequentially sampling from $p(z_i|z_{1:i-1}, x_{1:i})$. In the case of GMMs, both the naive sampler and collapsed sampler take $O(NKD)$ time per step.


Algorithm 27: Collapsed Gibbs sampler for a mixture model

1 for each $i = 1 : N$ in random order do
2     Remove $x_i$'s sufficient statistics from old cluster $z_i$
3     for each $k = 1 : K$ do
4         Compute $p_k(x_i|\beta) = p(x_i|\{x_j : z_j = k, j \neq i\}, \beta)$
5         Compute $p(z_i = k|z_{-i}, x) \propto (N_{k,-i} + \alpha/K)\, p_k(x_i|\beta)$
6     Sample $z_i \sim p(z_i|\cdot)$
7     Add $x_i$'s sufficient statistics to new cluster $z_i$
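The following is a runnable sketch of Algorithm 27 under a simplifying assumption that differs from the NIW model in the text: 1D data, each cluster with a known variance $\sigma^2$, and a conjugate $\mathcal{N}(m_0, v_0)$ prior on each cluster mean. The posterior predictive for a cluster holding $n$ points with sum $s$ is then $\mathcal{N}(x|m_n, v_n + \sigma^2)$, with $v_n = (1/v_0 + n/\sigma^2)^{-1}$ and $m_n = v_n(m_0/v_0 + s/\sigma^2)$, so the cached sufficient statistics are just counts and sums. The constant denominator $N + \alpha - 1$ from Equation (12.56) is dropped. The hyper-parameter values are illustrative assumptions.

```python
import math
import random


def collapsed_gibbs_gmm(x, K, alpha=1.0, sigma2=1.0, m0=0.0, v0=100.0, n_iters=50, seed=0):
    """Collapsed Gibbs sampler (in the style of Algorithm 27) for a 1D Gaussian mixture.

    Simplifying assumption: known cluster variance sigma2 and a N(m0, v0) prior on each
    cluster mean, so the means and the mixing weights are integrated out analytically.
    Cached sufficient statistics per cluster: count N_k and sum S_k.
    """
    rng = random.Random(seed)
    N = len(x)
    z = [rng.randrange(K) for _ in range(N)]
    counts, sums = [0] * K, [0.0] * K
    for xi, zi in zip(x, z):
        counts[zi] += 1
        sums[zi] += xi

    def log_pred(xi, n, s):
        # Posterior predictive N(xi | m_n, v_n + sigma2) given n points with sum s
        v_n = 1.0 / (1.0 / v0 + n / sigma2)
        m_n = v_n * (m0 / v0 + s / sigma2)
        var = v_n + sigma2
        return -0.5 * math.log(2 * math.pi * var) - 0.5 * (xi - m_n) ** 2 / var

    for _ in range(n_iters):
        for i in rng.sample(range(N), N):      # visit the points in random order
            counts[z[i]] -= 1                  # remove x_i's statistics from its cluster
            sums[z[i]] -= x[i]
            logw = [math.log(counts[k] + alpha / K) + log_pred(x[i], counts[k], sums[k])
                    for k in range(K)]
            top = max(logw)
            z[i] = rng.choices(range(K), weights=[math.exp(lw - top) for lw in logw])[0]
            counts[z[i]] += 1                  # add x_i's statistics to the new cluster
            sums[z[i]] += x[i]
    return z


# Hypothetical data: two well-separated groups of 50 points each
rng = random.Random(1)
data = [rng.gauss(-5.0, 1.0) for _ in range(50)] + [rng.gauss(5.0, 1.0) for _ in range(50)]
z = collapsed_gibbs_gmm(data, K=2)
print(sorted(set(z[:50])), sorted(set(z[50:])))  # labels used within each true group
```

Labels are only identified up to permutation, so a correct clustering shows up as one label per true group, whichever label that happens to be.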


The primary advantage of using the collapsed sampler is that it extends to the case where we 
have an “infinite” number of mixture components, as in the Dirichlet process mixture model of 
Section 31.3.2. 


12.4 Auxiliary variable MCMC 


Sometimes we can dramatically improve the efficiency of sampling by introducing auxiliary variables, in order to reduce correlation between the original variables. If the original variables are denoted by $x$, and the auxiliary variables by $v$, then the augmented distribution becomes $p(x, v)$. We assume it is easier to sample from this than from the marginal distribution $p(x)$. If so, we can draw joint samples $(x^s, v^s) \sim p(x, v)$, and then just “throw away” the $v^s$, and the result will be samples from the desired marginal, $x^s \sim \sum_v p(x, v)$. We give some examples of this below.


12.4.1 Slice sampling 


Consider sampling from a univariate, but multimodal, distribution $p(x) = \tilde{p}(x)/Z_p$, where $\tilde{p}(x)$ is unnormalized, and $Z_p = \int \tilde{p}(x)\,dx$. We can sometimes improve the ability to make large moves by adding a uniform auxiliary variable $v$. We define the joint distribution as follows:

$$p(x, v) = \begin{cases} 1/Z_p & \text{if } 0 \le v \le \tilde{p}(x) \\ 0 & \text{otherwise} \end{cases} \qquad (12.58)$$

The marginal distribution over $x$ is given by

$$\int p(x, v)\,dv = \int_0^{\tilde{p}(x)} \frac{1}{Z_p}\,dv = \frac{\tilde{p}(x)}{Z_p} = p(x) \qquad (12.59)$$

so we can sample from $p(x)$ by sampling from $p(x, v)$ and then ignoring $v$. To do this, we will use a technique called slice sampling [Nea03].
Figure 12.7: Slice sampling. (a) Illustration of one step of the algorithm in 1d. Given a previous sample $x^t$, we sample $u^{t+1}$ uniformly on $[0, f(x^t)]$, where $f = \tilde{p}$ is the (unnormalized) target density. We then sample $x^{t+1}$ along the slice where $f(x) \ge u^{t+1}$. From Figure 15 of [And+03]. Used with kind permission of Nando de Freitas. (b) Output of slice sampling applied to a 1d distribution. Generated by slice_sampling_demo_1d.ipynb.

Figure 12.8: Posterior for binomial regression for 1d data. Left: Slice sampling approximation. Right: Grid approximation. Generated by slice_sampling_demo_2d.ipynb.


This works as follows. Given the previous sample $x^t$, we sample $v^{t+1}$ from

$$p(v|x^t) = \mathcal{U}_{[0, \tilde{p}(x^t)]}(v) \qquad (12.60)$$

This amounts to uniformly picking a point on the vertical line between 0 and $\tilde{p}(x^t)$. We use this to construct a “slice” of the density at or above this height, by computing $A^{t+1} = \{x : \tilde{p}(x) \ge v^{t+1}\}$. We then sample $x^{t+1}$ uniformly from this set. See Figure 12.7(a) for an illustration.

To compute the level set $A$, we can use an iterative search procedure called stepping out, in which we start with an interval $x_{\min} \le x \le x_{\max}$ of some width around the current point $x^t$, and then we keep extending it until the endpoints fall outside the slice. We can then use rejection sampling to sample from the interval. For the details, see [Nea03].
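A minimal univariate sketch of this procedure, following the stepping-out and shrinkage scheme of [Nea03]. Shrinkage is a more efficient form of the rejection step: each rejected point becomes a new endpoint of the interval. The height is drawn on the log scale, using the fact that $\log U = -E$ for $U \sim \mathcal{U}(0, 1)$ and $E \sim \mathrm{Expon}(1)$. The bracket width `w`, the step cap `max_steps` and the bimodal test target are our own illustrative choices.

```python
import math
import random


def slice_sample(log_f, x0, n_samples, w=1.0, max_steps=50, seed=0):
    """Univariate slice sampling with stepping out and shrinkage, following Neal (2003).

    log_f: log of the unnormalized target density p~(x).
    w: width of the initial bracket; max_steps: cap on the number of stepping-out steps.
    """
    rng = random.Random(seed)
    x, samples = x0, []
    for _ in range(n_samples):
        # Height v ~ U(0, p~(x)) on the log scale: log v = log p~(x) - Exp(1)
        log_v = log_f(x) - rng.expovariate(1.0)
        # Stepping out: randomly position a bracket of width w around x, then widen it
        # in steps of w until both ends lie outside the slice (or the step budget runs out)
        left = x - w * rng.random()
        right = left + w
        j = int(max_steps * rng.random())
        k = max_steps - 1 - j
        while j > 0 and log_f(left) >= log_v:
            left -= w
            j -= 1
        while k > 0 and log_f(right) >= log_v:
            right += w
            k -= 1
        # Shrinkage: propose uniformly in the bracket; on rejection, shrink it toward x
        while True:
            x_new = left + (right - left) * rng.random()
            if log_f(x_new) >= log_v:
                x = x_new
                break
            if x_new < x:
                left = x_new
            else:
                right = x_new
        samples.append(x)
    return samples


def log_bimodal(x):
    # Hypothetical target: equal-weight mixture of N(-2, 1) and N(2, 1), unnormalized
    a, b = -0.5 * (x + 2.0) ** 2, -0.5 * (x - 2.0) ** 2
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


samples = slice_sample(log_bimodal, x0=0.0, n_samples=20000)
mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
print(round(mean, 2), round(var, 2))  # the target has mean 0 and variance 1 + 2^2 = 5
```

Note that the sampler only evaluates the unnormalized density; when the drawn height falls below the dip between the modes, the slice covers both modes and the chain can jump across in one step.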


To apply the method to multivariate distributions, we sample one extra auxiliary variable for each dimension. Thus we perform $2D$ sampling operations to draw a single joint sample, where $D$ is the number of random variables. The advantage of this over Gibbs sampling applied to the original (non-augmented) distribution is that it only needs access to the unnormalized joint, not the full conditionals.

Draft of “Probabilistic Machine Learning: Advanced Topics”. August 15, 2022

[Figure 12.9 panels: “Disconnect edges between nodes with different assignments”; “Select a connected component and assign it to a different label”; “Activate remaining edges with some probability and find connected components”.]

Figure 12.9: Illustration of the Swendsen Wang algorithm on a 2d grid. Used with kind permission of Kevin Tang.

Figure 12.7(b) illustrates the algorithm in action on a synthetic 1d problem. Figure 12.8 illustrates its behavior on a slightly harder problem, namely binomial logistic regression. The model has the form $y_i \sim \mathrm{Bin}(n_i, \sigma(\beta_1 + \beta_2 x_i))$. We use a vague Gaussian prior for the $\beta_j$'s. On the left we show the slice sampling approximation to the posterior, and on the right we show a grid-based approximation, as a simple deterministic proxy for the true posterior. We see a close correspondence.


12.4.2 Swendsen Wang 


Consider an Ising model of the following form: $p(x) = \frac{1}{Z} \prod_e \psi(x_e)$, where $x_e = (x_i, x_j)$ for edge $e = (i, j)$, $x_i \in \{+1, -1\}$, and the edge potential is defined by $\psi(x_e) = \begin{pmatrix} e^J & e^{-J} \\ e^{-J} & e^J \end{pmatrix}$, where $J$ is the edge strength. In Section 12.3.3, we discussed how to apply Gibbs sampling to this model. However, this can be slow when $J$ is large in absolute value, because neighboring states can be highly correlated. The Swendsen Wang algorithm [SW87b] is an auxiliary variable MCMC sampler which mixes much faster, at least for the case of attractive or ferromagnetic models, with $J > 0$.
Suppose we introduce auxiliary binary variables, one per edge.⁴ These are called bond variables, and will be denoted by $v$. We then define an extended model $p(x, v)$ of the form $p(x, v) = \frac{1}{Z} \prod_e \psi(x_e, v_e)$, where

$$\psi(x_e, v_e = 0) = \begin{pmatrix} e^{-J} & e^{-J} \\ e^{-J} & e^{-J} \end{pmatrix}, \qquad \psi(x_e, v_e = 1) = \begin{pmatrix} e^{J} - e^{-J} & 0 \\ 0 & e^{J} - e^{-J} \end{pmatrix} \qquad (12.61)$$

It is clear that $\sum_{v_e=0}^{1} \psi(x_e, v_e) = \psi(x_e)$, and hence that $\sum_v p(x, v) = p(x)$, as required.

4. Our presentation of the method is based on notes by David MacKay, available from http://www.inference.phy.cam.ac.uk/mackay/itila/swendsen.pdf.

Fortunately, it is easy to apply Gibbs sampling to this extended model. The full conditional $p(v|x)$ factorizes over the edges, since the bond variables are conditionally independent given the node variables. Furthermore, the full conditional $p(v_e|x_e)$ is simple to compute: if the nodes on either end of the edge are in the same state ($x_i = x_j$), we set the bond $v_e$ to 1 with probability $p = 1 - e^{-2J}$, otherwise we set it to 0. (This follows from Equation (12.61): for agreeing nodes, $p(v_e = 1|x_e) = (e^J - e^{-J})/e^J = 1 - e^{-2J}$, while for disagreeing nodes $\psi(x_e, v_e = 1) = 0$.) In Figure 12.9 (top right), the bonds that could be turned on (because their corresponding nodes are in the same state) are represented by dotted edges. In Figure 12.9 (bottom right), the bonds that are randomly turned on are represented by solid edges.

To sample from $p(x|v)$, we proceed as follows. Find the connected components defined by the graph induced by the bonds that are turned on. (Note that a connected component may consist of a singleton node.) Pick one of these components uniformly at random. All the nodes in each such component must have the same state. Pick a state $\pm 1$ uniformly at random, and set all the variables in this component to adopt this new state. This is illustrated in Figure 12.9 (bottom right), where the green square denotes the selected connected component; we set all the nodes within this square to white, to get the bottom left configuration.
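A minimal sketch of one Swendsen Wang sweep on an $n \times n$ grid, using union-find to find the connected components. One design choice differs from the description above: this sketch gives every component a fresh uniform state, the standard form of the algorithm. Given $v$, the components are independent and each is uniform over $\pm 1$, so resampling a single randomly chosen component, as described above, is also valid. The grid size, coupling and run length in the usage example are our own choices.

```python
import math
import random


def swendsen_wang_step(x, J, rng):
    """One Swendsen-Wang update for a ferromagnetic Ising model (J > 0) on an n x n grid.

    x: list of lists with entries in {-1, +1}; modified in place and returned.
    """
    n = len(x)
    parent = list(range(n * n))  # union-find over sites, indexed r * n + c

    def find(a):
        while parent[a] != a:
            parent[a] = parent[parent[a]]  # path halving
            a = parent[a]
        return a

    p_bond = 1.0 - math.exp(-2.0 * J)
    for r in range(n):
        for c in range(n):
            for rr, cc in ((r + 1, c), (r, c + 1)):  # visit each edge once
                # Bond v_e = 1 with prob 1 - exp(-2J), only if the endpoints agree
                if rr < n and cc < n and x[r][c] == x[rr][cc] and rng.random() < p_bond:
                    parent[find(r * n + c)] = find(rr * n + cc)
    # Given the bonds, give each connected component a fresh uniform state in {-1, +1}
    new_state = {}
    for r in range(n):
        for c in range(n):
            root = find(r * n + c)
            if root not in new_state:
                new_state[root] = rng.choice((-1, 1))
            x[r][c] = new_state[root]
    return x


rng = random.Random(0)
n = 10
x = [[rng.choice((-1, 1)) for _ in range(n)] for _ in range(n)]
mags = []
for _ in range(200):
    swendsen_wang_step(x, J=1.0, rng=rng)
    mags.append(sum(map(sum, x)) / n ** 2)  # magnetization after each step
sign_flips = sum(a * b < 0 for a, b in zip(mags, mags[1:]))
print(round(sum(abs(m) for m in mags[10:]) / len(mags[10:]), 2), sign_flips)
```

At strong coupling the chain stays nearly fully magnetized, yet the sign of the magnetization flips frequently, because a whole cluster changes state in one move; single-site Gibbs sampling would almost never move between the all-up and all-down modes.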

It should be intuitively clear that Swendsen Wang makes much larger moves through the state space than Gibbs sampling. The gains are exponentially large for certain settings of the edge parameter. More precisely, let the edge strength be parameterized by $J/T$, where $T > 0$ is a computational temperature. For large $T$, the nodes are roughly independent, so both methods work equally well. However, as $T$ approaches a critical temperature $T_c$, the typical states of the system have very long correlation lengths, and Gibbs sampling takes a very long time to generate independent samples. As the temperature continues to drop, the typical states are either all on or all off. The frequency with which Gibbs sampling moves between these two modes is exponentially small. By contrast, SW mixes rapidly at all temperatures.

Unfortunately, if any of the edge weights are negative, $J < 0$, the system is frustrated, and there are exponentially many modes, even at low temperature. SW does not work very well in this setting, since it tries to force many neighboring variables to have the same state. In fact, sampling from these kinds of frustrated systems is provably computationally hard for any algorithm [JS93; JS96].


12.5 Hamiltonian Monte Carlo (HMC)

Many MCMC algorithms perform poorly in high dimensional spaces, because they rely on a form of random search based on local perturbations. In this section, we discuss a method known as Hamiltonian Monte Carlo or HMC, that leverages gradient information to guide the local moves. This is an auxiliary variable method (Section 12.4) derived from physics [Dua+87; Nea93; Mac03; Nea10; Bet17].⁵ In particular, the method builds on Hamiltonian mechanics, which we describe below.

5. The method was originally called hybrid MC [Dua+87]. It was introduced to the statistics community in [Nea93], and was renamed to Hamiltonian MC in [Mac03].


12.5.1 Hamiltonian mechanics 


Consider a particle rolling around an energy landscape. We can characterize the motion of the particle in terms of its position $\theta \in \mathbb{R}^D$ (often denoted by $q$) and its momentum $v \in \mathbb{R}^D$ (often denoted by $p$). The set of possible values for $(\theta, v)$ is called the phase space. We define the Hamiltonian function for each point in phase space as follows:


H(0,v) = E(0) + K(v) (12.62) 


where $\mathcal{E}(\theta)$ is the potential energy, $\mathcal{K}(v)$ is the kinetic energy, and the Hamiltonian is the total energy. In a physical setting, the potential energy is due to the pull of gravity, and the kinetic energy is due to the motion of the particle. In a statistical setting, we often take the potential energy to be


$$
\mathcal{E}(\theta) = -\log \tilde{p}(\theta) \tag{12.63}
$$


where $\tilde{p}(\theta)$ is a possibly unnormalized distribution, such as $p(\theta, \mathcal{D})$, and the kinetic energy to be

$$
\mathcal{K}(v) = \frac{1}{2} v^\top \Sigma v \tag{12.64}
$$


where $\Sigma$ is a positive definite matrix, known as the inverse mass matrix.

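Stable orbits are defined by trajectories in phase space that have a constant energy. The trajectory of a particle within an energy level set can be obtained by solving the following continuous time differential equations, known as Hamilton’s equations:

$$
\begin{aligned}
\frac{d\theta}{dt} &= \frac{\partial \mathcal{H}}{\partial v} = \frac{\partial \mathcal{K}}{\partial v} \\
\frac{dv}{dt} &= -\frac{\partial \mathcal{H}}{\partial \theta} = -\frac{\partial \mathcal{E}}{\partial \theta}
\end{aligned} \tag{12.65}
$$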
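To see why energy is conserved, note that

$$
\frac{d\mathcal{H}}{dt} = \sum_{i=1}^{D} \left[ \frac{\partial \mathcal{H}}{\partial \theta_i} \frac{d\theta_i}{dt} + \frac{\partial \mathcal{H}}{\partial v_i} \frac{dv_i}{dt} \right] = \sum_{i=1}^{D} \left[ \frac{\partial \mathcal{H}}{\partial \theta_i} \frac{\partial \mathcal{H}}{\partial v_i} - \frac{\partial \mathcal{H}}{\partial v_i} \frac{\partial \mathcal{H}}{\partial \theta_i} \right] = 0 \tag{12.66}
$$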


Intuitively, we can understand this result as follows: a satellite in orbit around a planet will “want” 
to continue in a straight line due to its momentum, but will get pulled in towards the planet due 
to gravity, and if these forces cancel, the orbit is stable. If the satellite starts spiraling towards the 
planet, its kinetic energy will increase but its potential energy will decrease. 

Note that the mapping from $(\theta(t), v(t))$ to $(\theta(t+s), v(t+s))$ for some time increment $s$ is invertible for small enough time steps. Furthermore, this mapping is volume preserving, so it has a Jacobian determinant of 1. (See e.g., [BZ20, p287] for a proof.) These facts will be important later when we turn this system into an MCMC algorithm.


12.5.2 Integrating Hamilton’s equations 


In this section, we discuss how to simulate Hamilton’s equations in discrete time. 
12.5.2.1 Euler’s method


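The simplest way to model the time evolution is to update the position and momentum simultaneously by a small amount, known as the step size $\eta$:

$$
v_{t+1} = v_t + \eta \frac{dv}{dt}(\theta_t, v_t) = v_t - \eta \frac{\partial \mathcal{E}(\theta_t)}{\partial \theta} \tag{12.67}
$$

$$
\theta_{t+1} = \theta_t + \eta \frac{d\theta}{dt}(\theta_t, v_t) = \theta_t + \eta \frac{\partial \mathcal{K}(v_t)}{\partial v} \tag{12.68}
$$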


If the kinetic energy has the form in Equation (12.64), then the second expression simplifies to

$$
\theta_{t+1} = \theta_t + \eta \Sigma v_t \tag{12.69}
$$


This is known as Euler’s method. 
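To see concretely why Euler’s method is a poor fit for HMC, the sketch below applies Equations (12.67) and (12.69) to an assumed toy problem: a 1D quadratic potential $\mathcal{E}(\theta) = \theta^2/2$ with $\Sigma = 1$ (the function names are our own). For this potential one can check by hand that every Euler step multiplies the Hamiltonian by exactly $1 + \eta^2$, so the simulated energy drifts steadily upward instead of staying constant.

```python
# Assumed toy problem: E(theta) = theta**2 / 2 and K(v) = v**2 / 2 (Sigma = 1).
def grad_energy(theta):
    return theta

def hamiltonian(theta, v):
    return 0.5 * theta**2 + 0.5 * v**2

def euler_step(theta, v, eta):
    """One step of Euler's method: both updates use the old (theta, v)."""
    v_new = v - eta * grad_energy(theta)      # Eq. (12.67)
    theta_new = theta + eta * v               # Eq. (12.69) with Sigma = 1
    return theta_new, v_new

eta = 0.1
theta, v = 1.0, 0.0
energies = [hamiltonian(theta, v)]
for _ in range(100):
    theta, v = euler_step(theta, v, eta)
    energies.append(hamiltonian(theta, v))

growth_per_step = energies[1] / energies[0]   # equals 1 + eta**2 for this potential
total_growth = energies[-1] / energies[0]     # equals (1 + eta**2)**100, about 2.7
print(growth_per_step, total_growth)
```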


12.5.2.2 Modified Euler’s method 


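The modified Euler’s method is slightly more accurate, and works as follows: first update the momentum, and then update the position using the new momentum:

$$
v_{t+1} = v_t + \eta \frac{dv}{dt}(\theta_t, v_t) = v_t - \eta \frac{\partial \mathcal{E}(\theta_t)}{\partial \theta} \tag{12.70}
$$

$$
\theta_{t+1} = \theta_t + \eta \frac{d\theta}{dt}(\theta_t, v_{t+1}) = \theta_t + \eta \frac{\partial \mathcal{K}(v_{t+1})}{\partial v} \tag{12.71}
$$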


Unfortunately, the asymmetry of this method can cause some theoretical problems (see e.g., [BZ20, p287]), which we resolve below.
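The following sketch (our own illustration, again assuming the quadratic potential $\mathcal{E}(\theta) = \theta^2/2$ with $\Sigma = 1$) makes the difference between the two methods concrete. It computes the Jacobian determinant of each one-step map by finite differences: Euler’s method expands phase-space volume by a factor of $1 + \eta^2$, whereas the modified Euler method preserves it. It also shows one consequence of the asymmetry: if we take a modified Euler step, flip the sign of the momentum, take another step, and flip the sign back, we do not return to where we started.

```python
import numpy as np

# Assumed toy problem: E(theta) = theta**2 / 2, K(v) = v**2 / 2 (Sigma = 1).
def grad_energy(theta):
    return theta

def euler(z, eta):
    theta, v = z
    return np.array([theta + eta * v, v - eta * grad_energy(theta)])

def modified_euler(z, eta):
    theta, v = z
    v_new = v - eta * grad_energy(theta)            # momentum first, Eq. (12.70)
    return np.array([theta + eta * v_new, v_new])   # then position, Eq. (12.71)

def jacobian_det(step, z, eta, h=1e-6):
    """Determinant of the central finite-difference Jacobian of z -> step(z)."""
    J = np.zeros((2, 2))
    for j in range(2):
        dz = np.zeros(2)
        dz[j] = h
        J[:, j] = (step(z + dz, eta) - step(z - dz, eta)) / (2 * h)
    return np.linalg.det(J)

def flip_roundtrip(step, z, eta):
    """Step, negate the momentum, step again, negate the momentum again."""
    flip = np.array([1.0, -1.0])
    return flip * step(flip * step(z, eta), eta)

z0, eta = np.array([0.3, -1.2]), 0.1
det_euler = jacobian_det(euler, z0, eta)               # 1 + eta**2 = 1.01
det_modified = jacobian_det(modified_euler, z0, eta)   # 1 (volume preserving)
roundtrip_error = np.abs(flip_roundtrip(modified_euler, z0, eta) - z0).max()
print(det_euler, det_modified, roundtrip_error)
```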


12.5.2.3 Leapfrog integrator


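In this section, we discuss the leapfrog integrator, which is a symmetrized version of the modified Euler method. We first perform a “half” update of the momentum, then a full update of the position, and then finally another “half” update of the momentum:

$$
v_{t+1/2} = v_t - \frac{\eta}{2} \frac{\partial \mathcal{E}(\theta_t)}{\partial \theta} \tag{12.72}
$$

$$
\theta_{t+1} = \theta_t + \eta \frac{\partial \mathcal{K}(v_{t+1/2})}{\partial v} \tag{12.73}
$$

$$
v_{t+1} = v_{t+1/2} - \frac{\eta}{2} \frac{\partial \mathcal{E}(\theta_{t+1})}{\partial \theta} \tag{12.74}
$$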


If we perform multiple leapfrog steps, this is equivalent to performing a half step update of $v$ at the beginning and end of the trajectory, and alternating between full step updates of $\theta$ and $v$ in between.
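A minimal NumPy sketch of a multi-step leapfrog integrator is shown below, with consecutive half steps merged exactly as just described. The target is an assumed 2D Gaussian with precision matrix $A$ (so $\nabla \mathcal{E}(\theta) = A\theta$), and we take $\Sigma = I$. The sketch checks two properties: the energy error after 50 steps with $\eta = 0.1$ stays small, and flipping the momentum and integrating again brings us back to the starting point up to floating point error, which is the time-reversibility that the modified Euler method lacks.

```python
import numpy as np

# Assumed target: zero-mean Gaussian with precision A, E(theta) = 0.5 theta^T A theta.
A = np.array([[2.0, 0.5], [0.5, 1.0]])

def energy(theta):
    return 0.5 * theta @ A @ theta

def grad_energy(theta):
    return A @ theta

def hamiltonian(theta, v):
    return energy(theta) + 0.5 * v @ v     # Sigma = I

def leapfrog(theta, v, eta, n_steps):
    """n_steps leapfrog steps, Eqs. (12.72)-(12.74), with the inner half steps merged."""
    v = v - 0.5 * eta * grad_energy(theta)          # initial half step for v
    for step in range(n_steps):
        theta = theta + eta * v                     # full step for theta
        if step < n_steps - 1:
            v = v - eta * grad_energy(theta)        # full step for v
    v = v - 0.5 * eta * grad_energy(theta)          # final half step for v
    return theta, v

theta0, v0 = np.array([1.0, -0.5]), np.array([0.3, 0.8])
theta1, v1 = leapfrog(theta0, v0, eta=0.1, n_steps=50)
energy_error = abs(hamiltonian(theta1, v1) - hamiltonian(theta0, v0))

# Reversibility: negate the momentum, integrate again, negate back.
theta2, v2 = leapfrog(theta1, -v1, eta=0.1, n_steps=50)
roundtrip_error = max(np.abs(theta2 - theta0).max(), np.abs(-v2 - v0).max())
print(energy_error, roundtrip_error)
```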


12.5.2.4 Higher order integrators




It is possible to define higher order integrators that are more accurate, but take more steps. For details, see [BRSS18].


12.5.3 The HMC algorithm 


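We now describe how to use Hamiltonian dynamics to define an MCMC sampler in the expanded state space $(\theta, v)$. The target distribution has the form

$$
p(\theta, v) = \frac{1}{Z} \exp\left[-\mathcal{H}(\theta, v)\right] = \frac{1}{Z} \exp\left[-\mathcal{E}(\theta) - \frac{1}{2} v^\top \Sigma v\right] \tag{12.75}
$$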


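The marginal distribution over the latent variables of interest has the form

$$
p(\theta) = \int p(\theta, v)\, dv = \frac{1}{Z_\theta} e^{-\mathcal{E}(\theta)} \int \frac{1}{Z_v} e^{-\frac{1}{2} v^\top \Sigma v}\, dv = \frac{1}{Z_\theta} e^{-\mathcal{E}(\theta)} \tag{12.76}
$$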

Suppose the previous state of the Markov chain is $(\theta_{t-1}, v_{t-1})$. To sample the next state, we proceed as follows. We set the initial position to $\theta'_0 = \theta_{t-1}$, and sample a new random momentum, $v'_0 \sim \mathcal{N}(0, \Sigma)$. We then initialize a random trajectory in the phase space, starting at $(\theta'_0, v'_0)$, and follow it for $L$ leapfrog steps, until we get to the final proposed state $(\theta^*, v^*) = (\theta'_L, v'_L)$. If we have simulated Hamiltonian mechanics correctly, the energy should be approximately the same at the start and end of this process; if not, we say the HMC has diverged, and we reject the sample. If the energy is approximately constant, we compute the MH acceptance probability

$$
\alpha = \min\left(1, \frac{p(\theta^*, v^*)}{p(\theta_{t-1}, v_{t-1})}\right) = \min\left(1, \exp\left[-\mathcal{H}(\theta^*, v^*) + \mathcal{H}(\theta_{t-1}, v_{t-1})\right]\right) \tag{12.77}
$$


(The transition probabilities cancel since the proposal is reversible.) Finally, we accept the proposal by setting $(\theta_t, v_t) = (\theta^*, v^*)$ with probability $\alpha$; otherwise we set $(\theta_t, v_t) = (\theta_{t-1}, v_{t-1})$. (In practice we don’t need to keep the momentum term; it is only used inside the leapfrog algorithm.) See Algorithm 28 for the pseudocode.⁶


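Algorithm 28: Hamiltonian Monte Carlo

for t = 1 : T do
    Generate random momentum $v_{t-1} \sim \mathcal{N}(0, \Sigma)$
    Set $(\theta'_0, v'_0) = (\theta_{t-1}, v_{t-1})$
    Half step for momentum: $v'_{1/2} = v'_0 - \frac{\eta}{2} \nabla \mathcal{E}(\theta'_0)$
    for l = 1 : L - 1 do
        $\theta'_l = \theta'_{l-1} + \eta \Sigma v'_{l-1/2}$
        $v'_{l+1/2} = v'_{l-1/2} - \eta \nabla \mathcal{E}(\theta'_l)$
    Full step for location: $\theta'_L = \theta'_{L-1} + \eta \Sigma v'_{L-1/2}$
    Half step for momentum: $v'_L = v'_{L-1/2} - \frac{\eta}{2} \nabla \mathcal{E}(\theta'_L)$
    Compute proposal $(\theta^*, v^*) = (\theta'_L, v'_L)$
    Compute $\alpha = \min\left(1, \exp\left[-\mathcal{H}(\theta^*, v^*) + \mathcal{H}(\theta_{t-1}, v_{t-1})\right]\right)$
    Set $\theta_t = \theta^*$ with probability $\alpha$, otherwise $\theta_t = \theta_{t-1}$.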


6. There are many high-quality implementations of HMC. For example, BlackJAX in JAX. 
We need to sample a new momentum at each iteration to satisfy ergodicity. To see why, recall that $\mathcal{H}(\theta, v)$ stays approximately constant as we move through phase space. If $\mathcal{H}(\theta, v) = \mathcal{E}(\theta) + \frac{1}{2} v^\top \Sigma v$, then clearly $\mathcal{E}(\theta) \le \mathcal{H}(\theta, v) = h$ for all locations $\theta$ along the trajectory. Thus the sampler cannot reach states where $\mathcal{E}(\theta) > h$. To ensure the sampler explores the full space, we must pick a random momentum at the start of each iteration.
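Putting the pieces together, here is a minimal NumPy sketch of Algorithm 28. It is our own illustration rather than a reference implementation (see footnote 6 for production libraries). To keep it short we assume $\Sigma = I$, so the momentum is drawn from $\mathcal{N}(0, I)$ and the kinetic energy is $\frac{1}{2} v^\top v$. The acceptance test compares $\log u$ with $\mathcal{H}_{\text{old}} - \mathcal{H}_{\text{new}}$, which is equivalent to Equation (12.77) but avoids overflow; a diverged trajectory that produces a non-finite energy simply fails this test and is rejected. The target is an assumed correlated 2D Gaussian, so the sample moments can be checked against the truth.

```python
import numpy as np

def hmc(energy, grad_energy, theta0, eta, n_leapfrog, n_samples, seed=0):
    """Sketch of Algorithm 28, assuming Sigma = I for brevity."""
    rng = np.random.default_rng(seed)
    theta = np.asarray(theta0, dtype=float)
    samples = np.empty((n_samples, theta.size))
    n_accept = 0
    for t in range(n_samples):
        v = rng.standard_normal(theta.size)          # fresh momentum every iteration
        h_old = energy(theta) + 0.5 * v @ v
        # Leapfrog: half step for v, alternating full steps, final half step for v.
        th = theta.copy()
        vv = v - 0.5 * eta * grad_energy(th)
        for l in range(n_leapfrog):
            th = th + eta * vv
            if l < n_leapfrog - 1:
                vv = vv - eta * grad_energy(th)
        vv = vv - 0.5 * eta * grad_energy(th)
        h_new = energy(th) + 0.5 * vv @ vv
        # MH test, Eq. (12.77), in log space; NaN or inf energies are rejected.
        if np.log(rng.uniform()) < h_old - h_new:
            theta = th
            n_accept += 1
        samples[t] = theta
    return samples, n_accept / n_samples

# Assumed target: N(mean, cov) with strongly correlated components.
mean = np.array([1.0, -2.0])
cov = np.array([[1.0, 0.8], [0.8, 1.0]])
prec = np.linalg.inv(cov)

def energy(theta):
    d = theta - mean
    return 0.5 * d @ prec @ d

def grad_energy(theta):
    return prec @ (theta - mean)

samples, accept_rate = hmc(energy, grad_energy, np.zeros(2), eta=0.2,
                           n_leapfrog=10, n_samples=5000)
kept = samples[500:]                                  # discard burn-in
print(accept_rate, kept.mean(axis=0), np.cov(kept, rowvar=False))
```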


12.5.4 Tuning HMC 


We need to specify three hyperparameters for HMC: the number of leapfrog steps $L$, the step size $\eta$, and the covariance $\Sigma$.


12.5.4.1 Choosing the number of steps using NUTS 


We want to choose the number of leapfrog steps L to be large enough that the algorithm explores 
the level set of constant energy, but without doubling back on itself, which would waste computation, 
due to correlated samples. Fortunately, there is an algorithm, known as the No-U-Turn Sampler 
or NUTS algorithm [HG14], which can adaptively choose L for us. 


12.5.4.2 Choosing the step size 


When $\Sigma = I$, the ideal step size $\eta$ should be roughly equal to the width of $\mathcal{E}(\theta)$ in the most constrained direction of the local energy landscape. For a locally quadratic potential, this corresponds to the smallest marginal standard deviation of the local covariance matrix (i.e., the square root of the smallest marginal variance). (If we think of the energy surface as a valley, this corresponds to the direction with the steepest sides.) A step size much larger than this will cause moves that are likely to be rejected because they move to places which increase the potential energy too much. On the other hand, if the step size is too small, the proposal distribution will not move much from the starting position, and the algorithm will be very slow.




In [BZ20, Sec 9.5.4] they recommend the following heuristic for picking $\eta$: set $\Sigma = I$ and $L = 1$, and then vary $\eta$ until the acceptance rates are in the range of 40%-80%. Of course, different step sizes might be needed in different parts of the state space. In this case, we can use learning rate schedules from the optimization literature, such as cyclical schedules [Zha+20d].
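This heuristic can be sketched as a simple sweep. Below (an illustration under assumed settings, not code from [BZ20]) we run HMC with $\Sigma = I$ and $L = 1$ on an assumed 2D Gaussian whose narrowest direction has standard deviation 0.1, and record the empirical acceptance rate for several candidate values of $\eta$. Acceptance is close to 1 for step sizes much smaller than this width and close to 0 for step sizes several times larger; one would then keep the values whose acceptance rate falls in the 40%-80% band.

```python
import numpy as np

# Assumed target: N(0, diag(1, 0.1**2)); the most constrained direction has std 0.1.
stds = np.array([1.0, 0.1])

def energy(theta):
    return 0.5 * np.sum((theta / stds) ** 2)

def grad_energy(theta):
    return theta / stds**2

def acceptance_rate(eta, n_iters=2000, seed=0):
    """Empirical MH acceptance rate of HMC with one leapfrog step (L = 1), Sigma = I."""
    rng = np.random.default_rng(seed)
    theta, n_accept = np.zeros(2), 0
    for _ in range(n_iters):
        v = rng.standard_normal(2)
        h_old = energy(theta) + 0.5 * v @ v
        v_half = v - 0.5 * eta * grad_energy(theta)
        proposal = theta + eta * v_half
        v_new = v_half - 0.5 * eta * grad_energy(proposal)
        h_new = energy(proposal) + 0.5 * v_new @ v_new
        if np.log(rng.uniform()) < h_old - h_new:
            theta, n_accept = proposal, n_accept + 1
    return n_accept / n_iters

etas = [0.01, 0.05, 0.1, 0.2, 0.4]
rates = {eta: acceptance_rate(eta) for eta in etas}
good_etas = [eta for eta, r in rates.items() if 0.4 <= r <= 0.8]
print(rates, good_etas)
```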


12.5.4.3 Choosing the covariance (inverse mass) matrix


To allow for larger step sizes, we can use a smarter choice for $\Sigma$, also called the inverse mass matrix. One way to estimate a fixed $\Sigma$ is to run HMC with $\Sigma = I$ for a warmup period, until the chain is “burned in” (see Section 12.6); then we run for a few more steps, so we can compute the empirical covariance matrix $\Sigma = \mathbb{E}\left[(\theta - \overline{\theta})(\theta - \overline{\theta})^\top\right]$. In [Hof+19] they propose an algorithm called NeuTra HMC, which “neutralizes” bad geometry by learning an inverse autoregressive flow model (Section 23.2.4.3) in order to map the warped distribution to an isotropic Gaussian. This is often an order of magnitude faster than vanilla HMC.


12.5.5 Riemann Manifold HMC 


If we let the covariance matrix change as we move position, so $\Sigma$ is a function of $\theta$, the method is known as Riemann Manifold HMC or RM-HMC [GC11; Bet13], since the moves follow a curved manifold, rather than the flat manifold induced by a constant $\Sigma$.

A natural choice for the covariance matrix is to use the Hessian at the current location, to capture 
the local geometry: 


$$
\Sigma(\theta) = \nabla^2 \mathcal{E}(\theta) \tag{12.78}
$$


Since this is not always positive definite, an alternative that can be used for some problems is to use the Fisher information matrix (Section 2.4), given by

$$
\Sigma(\theta) = -\mathbb{E}_{p(x|\theta)}\left[\nabla^2 \log p(x|\theta)\right] \tag{12.79}
$$


Once we have computed $\Sigma(\theta)$, we can compute the kinetic energy as follows:

$$
\mathcal{K}(\theta, v) = \frac{1}{2} \log\left((2\pi)^D |\Sigma(\theta)|\right) + \frac{1}{2} v^\top \Sigma(\theta) v \tag{12.80}
$$


Unfortunately, the Hamiltonian updates of $\theta$ and $v$ are no longer separable, which makes the RM-HMC algorithm more complex to implement, so it is not widely used.


12.5.6 Langevin Monte Carlo (MALA) 


A special case of HMC occurs when we take a single leapfrog step, $L = 1$. This is known as Langevin Monte Carlo (LMC), or Metropolis Adjusted Langevin Algorithm (MALA) [RT96]. This gives rise to the simplified algorithm shown in Algorithm 29.


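Algorithm 29: Langevin Monte Carlo

for t = 1 : T do
    Generate random momentum $v_{t-1} \sim \mathcal{N}(0, \Sigma)$
    $\theta^* = \theta_{t-1} - \frac{\eta^2}{2} \Sigma \nabla \mathcal{E}(\theta_{t-1}) + \eta \Sigma v_{t-1}$
    $v^* = v_{t-1} - \frac{\eta}{2} \nabla \mathcal{E}(\theta_{t-1}) - \frac{\eta}{2} \nabla \mathcal{E}(\theta^*)$
    Compute $\alpha = \min\left(1, \exp[-\mathcal{H}(\theta^*, v^*)] / \exp[-\mathcal{H}(\theta_{t-1}, v_{t-1})]\right)$
    Set $\theta_t = \theta^*$ with probability $\alpha$, otherwise $\theta_t = \theta_{t-1}$.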
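Algorithm 29 is short enough to sketch directly. The NumPy version below is our own illustration, not library code; it assumes $\Sigma = I$ and uses an assumed target made of two independent Gaussians, so that the sample moments can be compared with known values. As in the algorithm, the acceptance ratio is computed from the Hamiltonian at both ends of the single leapfrog step.

```python
import numpy as np

def mala(energy, grad_energy, theta0, eta, n_samples, seed=0):
    """Sketch of Algorithm 29 (one-step HMC / MALA), assuming Sigma = I."""
    rng = np.random.default_rng(seed)
    theta = np.asarray(theta0, dtype=float)
    samples = np.empty((n_samples, theta.size))
    for t in range(n_samples):
        v = rng.standard_normal(theta.size)
        g = grad_energy(theta)
        theta_star = theta - 0.5 * eta**2 * g + eta * v
        v_star = v - 0.5 * eta * g - 0.5 * eta * grad_energy(theta_star)
        h_old = energy(theta) + 0.5 * v @ v
        h_new = energy(theta_star) + 0.5 * v_star @ v_star
        if np.log(rng.uniform()) < h_old - h_new:    # MH test in log space
            theta = theta_star
        samples[t] = theta
    return samples

# Assumed target: independent Gaussians with means (1, -1) and stds (1, 0.5).
mu, sd = np.array([1.0, -1.0]), np.array([1.0, 0.5])

def energy(theta):
    return 0.5 * np.sum(((theta - mu) / sd) ** 2)

def grad_energy(theta):
    return (theta - mu) / sd**2

draws = mala(energy, grad_energy, np.zeros(2), eta=0.5, n_samples=20_000)[2_000:]
print(draws.mean(axis=0), draws.std(axis=0))
```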


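A further simplification is to eliminate the MH acceptance step. In this case, the update becomes

$$
\begin{aligned}
\theta_t &= \theta_{t-1} - \frac{\eta^2}{2} \Sigma \nabla \mathcal{E}(\theta_{t-1}) + \eta \Sigma v_{t-1} &\qquad (12.81) \\
&= \theta_{t-1} - \frac{\eta^2}{2} \Sigma \nabla \mathcal{E}(\theta_{t-1}) + \eta \sqrt{\Sigma}\, \epsilon_{t-1} &\qquad (12.82)
\end{aligned}
$$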


where $v_{t-1} \sim \mathcal{N}(0, \Sigma)$ and $\epsilon_{t-1} \sim \mathcal{N}(0, I)$. This is just like gradient descent with added noise. If we set $\Sigma$ to be the inverse of the Fisher information matrix, this becomes natural gradient descent (Section 6.4) with added noise. If we approximate the gradient with a stochastic gradient, we get a method known as SGLD, or Stochastic Gradient Langevin Dynamics (see Section 12.7.1 for details).

Now suppose $\Sigma = I$, and we set $\eta = \sqrt{2}$. In continuous time, we get the following stochastic differential equation (SDE), known as Langevin diffusion:

$$
d\theta_t = -\nabla \mathcal{E}(\theta_t)\, dt + \sqrt{2}\, dB_t \tag{12.83}
$$


where $B_t$ represents $D$-dimensional Brownian motion. If we use this to generate the samples, the method is known as the Unadjusted Langevin Algorithm or ULA [Par81; RT96].
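Dropping the MH step makes the sampler cheaper, but it also makes it biased, because the discretized dynamics no longer leave the target exactly invariant. The sketch below (an assumed example, using Equation (12.82) with $\Sigma = I$) runs many independent ULA chains on a standard normal target. For this target the update is the linear recursion $\theta_t = (1 - h)\theta_{t-1} + \eta\,\epsilon_{t-1}$ with $h = \eta^2/2$, whose stationary variance works out to $2/(2 - h)$ rather than 1; the simulation reproduces this inflated variance.

```python
import numpy as np

# Assumed target: standard normal, E(theta) = theta**2 / 2, so grad E(theta) = theta.
def grad_energy(theta):
    return theta

def ula(theta0, eta, n_steps, rng):
    """Unadjusted Langevin: Eq. (12.82) with Sigma = I and no MH correction.
    theta0 may be an array, in which case each entry is an independent chain."""
    theta = np.array(theta0, dtype=float)
    for _ in range(n_steps):
        theta = (theta - 0.5 * eta**2 * grad_energy(theta)
                 + eta * rng.standard_normal(theta.shape))
    return theta

rng = np.random.default_rng(0)
eta = 1.0
final_states = ula(np.zeros(50_000), eta, n_steps=200, rng=rng)   # 50,000 chains
empirical_var = final_states.var()
h = 0.5 * eta**2
predicted_var = 2.0 / (2.0 - h)      # 4/3 here, whereas the target variance is 1
print(empirical_var, predicted_var)
```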


12.5.7 Connection between SGD and Langevin sampling 


In this section, we discuss a deep connection between stochastic gradient descent (SGD) and Langevin 
sampling, following the presentation of [BZ20, Sec 10.2.3]. 
Consider the minimization of the additive loss

$$
\mathcal{L}(\theta) = \frac{1}{N} \sum_{n=1}^{N} \mathcal{L}_n(\theta) \tag{12.84}
$$

For example, we may define $\mathcal{L}_n(\theta) = -\log p(y_n | x_n, \theta)$. We will use a minibatch approximation to the gradients:


$$
\nabla_B \mathcal{L}(\theta) = \frac{1}{B} \sum_{j=1}^{B} \nabla \mathcal{L}_{i_j}(\theta) \tag{12.85}
$$

where $S = \{i_1, \ldots, i_B\}$ is a randomly chosen set of indices of size $B$. For simplicity of analysis, we assume the indices are chosen with replacement from $\{1, \ldots, N\}$.
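Let us define the (scaled) error (due to minibatching) in the estimated gradient by

$$
v_t = \sqrt{\eta}\left(\nabla \mathcal{L}(\theta_t) - \nabla_B \mathcal{L}(\theta_t)\right) \tag{12.86}
$$

This is called the diffusion term. Then we can rewrite the SGD update as

$$
\theta_{t+1} = \theta_t - \eta \nabla_B \mathcal{L}(\theta_t) = \theta_t - \eta \nabla \mathcal{L}(\theta_t) + \sqrt{\eta}\, v_t \tag{12.87}
$$

The diffusion term $v_t$ has mean 0, since

$$
\mathbb{E}\left[\nabla_B \mathcal{L}(\theta)\right] = \frac{1}{B} \sum_{j=1}^{B} \mathbb{E}\left[\nabla \mathcal{L}_{i_j}(\theta)\right] = \frac{1}{B} \sum_{j=1}^{B} \frac{1}{N} \sum_{n=1}^{N} \nabla \mathcal{L}_n(\theta) = \nabla \mathcal{L}(\theta) \tag{12.88}
$$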


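To compute the variance of the diffusion term, note that

$$
\mathbb{V}\left[\nabla_B \mathcal{L}(\theta)\right] = \frac{1}{B^2} \sum_{j=1}^{B} \mathbb{V}\left[\nabla \mathcal{L}_{i_j}(\theta)\right] \tag{12.89}
$$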




Figure 12.10: (a) A mixture model. (b) After integrating out the discrete latent variables. 


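where

$$
\begin{aligned}
\mathbb{V}\left[\nabla \mathcal{L}_{i_j}(\theta)\right] &= \mathbb{E}\left[\nabla \mathcal{L}_{i_j}(\theta) \nabla \mathcal{L}_{i_j}(\theta)^\top\right] - \mathbb{E}\left[\nabla \mathcal{L}_{i_j}(\theta)\right] \mathbb{E}\left[\nabla \mathcal{L}_{i_j}(\theta)\right]^\top &\qquad (12.90) \\
&= \left(\frac{1}{N} \sum_{n=1}^{N} \nabla \mathcal{L}_n(\theta) \nabla \mathcal{L}_n(\theta)^\top\right) - \nabla \mathcal{L}(\theta) \nabla \mathcal{L}(\theta)^\top \triangleq D(\theta) &\qquad (12.91)
\end{aligned}
$$

where $D(\theta)$ is called the diffusion matrix. Hence $\mathbb{V}[v_t] = \frac{\eta}{B} D(\theta_t)$.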
[LTW15] prove that the following continuous time stochastic differential equation is a first-order approximation of minibatch SGD (assuming the loss function is Lipschitz continuous):

$$
d\theta(t) = -\nabla \mathcal{L}(\theta(t))\, dt + \sqrt{\frac{\eta}{B} D(\theta(t))}\, dB(t) \tag{12.92}
$$


where B(t) is Brownian motion. Thus the noise from minibatching causes SGD to act like a Langevin 
sampler. (See [Hu+17] for more information.) 

The scale factor for the noise, $\tau = \eta / B$, plays the role of temperature. Thus we see that using a smaller batch size is like using a larger temperature; the added noise helps SGD avoid going into narrow ravines, so that it instead spends most of its time in flat minima, which tend to have better generalization performance [Kes+17]. See Section 17.4.1 for more discussion of this point.
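The claims that the minibatch gradient is unbiased and has covariance $D(\theta)/B$, as follows from Equations (12.89)-(12.91), are easy to check numerically. The sketch below uses an assumed least-squares loss $\mathcal{L}_n(\theta) = \frac{1}{2}(y_n - x_n^\top \theta)^2$ on synthetic data, computes $D(\theta)$ from Equation (12.91), and compares it with the empirical covariance of many minibatch gradients whose indices are drawn with replacement.

```python
import numpy as np

rng = np.random.default_rng(0)
N, B = 1000, 10
X = rng.standard_normal((N, 2))
y = X @ np.array([1.0, -2.0]) + 0.5 * rng.standard_normal(N)
theta = np.zeros(2)                      # point at which we study the gradient noise

# Per-example gradients of the assumed loss L_n(theta) = 0.5 * (y_n - x_n^T theta)^2.
per_example = -(y - X @ theta)[:, None] * X                            # shape (N, 2)
full_grad = per_example.mean(axis=0)                                    # grad of Eq. (12.84)
D = per_example.T @ per_example / N - np.outer(full_grad, full_grad)    # Eq. (12.91)

# Many minibatch gradients, Eq. (12.85), with indices drawn with replacement.
n_reps = 20_000
idx = rng.integers(0, N, size=(n_reps, B))
mb_grads = per_example[idx].mean(axis=1)                                # shape (n_reps, 2)
emp_mean = mb_grads.mean(axis=0)
emp_cov = np.cov(mb_grads, rowvar=False)
print(emp_mean, full_grad)
print(emp_cov, D / B)
```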


12.5.8 Applying HMC to constrained parameters 


To apply HMC, we require that all the latent quantities be continuous (real-valued) and have unconstrained support, i.e., $\theta \in \mathbb{R}^D$, so discrete latent variables need to be marginalized out (although some recent work, such as [NDL20; Zho20], relaxes this requirement).

As an example of how this can be done, consider a GMM. We can easily write the likelihood without discrete latents as follows:


$$
p(x_n | \theta) = \sum_{k=1}^{K} \pi_k \mathcal{N}(x_n | \mu_k, \Sigma_k) \tag{12.93}
$$


The corresponding “collapsed” model is shown in Figure 12.10(b). (Note that this is the opposite of Section 12.3.8, where we integrated out the continuous parameters in order to apply Gibbs sampling to the discrete latents.) We can apply similar techniques to other discrete latent variable models. For example, to apply HMC to HMMs, we can use the forwards algorithm (Section 8.2.2) to efficiently compute $p(x_n | \theta) = \sum_{z_{n,1:T}} p(x_n, z_{n,1:T} | \theta)$.
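In code, the collapsed likelihood (12.93) is usually evaluated in log space with the log-sum-exp trick, which gives a smooth function of the continuous parameters that HMC can differentiate. The sketch below is our own 1D illustration (the parameter values are arbitrary assumptions). It checks the stable version against a direct evaluation, and shows that the direct version underflows to $-\infty$ for a far-away point while the log-sum-exp version stays finite.

```python
import numpy as np

def log_normal(x, mu, sigma):
    return -0.5 * np.log(2 * np.pi * sigma**2) - 0.5 * ((x - mu) / sigma) ** 2

def logsumexp(a, axis):
    """log(sum(exp(a))) along `axis`, computed stably by factoring out the max."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))

def gmm_log_lik(x, log_pi, mu, sigma):
    """sum_n log sum_k pi_k N(x_n | mu_k, sigma_k^2): Eq. (12.93) in 1D, z_n summed out."""
    comp = log_pi[None, :] + log_normal(x[:, None], mu[None, :], sigma[None, :])  # (N, K)
    return logsumexp(comp, axis=1).sum()

# Assumed parameters for a two-component 1D mixture.
log_pi = np.log(np.array([0.3, 0.7]))
mu = np.array([-1.0, 2.0])
sigma = np.array([0.5, 1.0])

x = np.array([-1.0, 0.0, 2.5])
stable = gmm_log_lik(x, log_pi, mu, sigma)
direct = np.sum(np.log(np.sum(np.exp(log_pi + log_normal(x[:, None], mu, sigma)), axis=1)))

far = np.array([60.0])                  # both component densities underflow to 0 here
with np.errstate(divide="ignore"):      # silence the log(0) warning of the naive version
    direct_far = np.log(np.sum(np.exp(log_pi + log_normal(far[:, None], mu, sigma))))
stable_far = gmm_log_lik(far, log_pi, mu, sigma)
print(stable, direct, direct_far, stable_far)
```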

In addition to marginalizing out any discrete latent variables, we need to ensure the remaining continuous latent variables are unconstrained. This often requires performing a change of variables using a bijector. For example, instead of sampling the discrete probability vector from the probability simplex $\pi \in \mathcal{S}_K$, we should sample the logits $\eta \in \mathbb{R}^K$. After sampling, we can transform back, since bijectors are invertible. (For a practical example, see change_of_variable_hmc.ipynb.)
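When we sample in the unconstrained space, the target density there must include the log absolute Jacobian of the bijector; otherwise HMC targets the wrong distribution. The sketch below illustrates this with a simpler assumed case than the simplex example above: a positive scale $\sigma$ with an $\text{Expon}(1)$ prior, sampled via $u = \log \sigma$. The transformed log density $\log p_u(u) = \log p_\sigma(e^u) + u$ integrates to 1 over $u$, whereas dropping the $+u$ term does not give a normalized density.

```python
import numpy as np

# Assumed example: sigma > 0 with prior Expon(1), so log p_sigma(sigma) = -sigma.
# HMC works with the unconstrained variable u = log(sigma) in R instead.
def log_p_sigma(sigma):
    return -sigma

def log_p_u(u):
    """Change of variables: p_u(u) = p_sigma(e^u) |d sigma / du| = p_sigma(e^u) e^u."""
    return log_p_sigma(np.exp(u)) + u

def grad_energy_u(u):
    """Gradient of E(u) = -log p_u(u) = e^u - u, which is what HMC would use."""
    return np.exp(u) - 1.0

grid = np.linspace(-20.0, 5.0, 200_001)
du = grid[1] - grid[0]
total_mass = np.sum(np.exp(log_p_u(grid))) * du                    # close to 1
without_jacobian = np.sum(np.exp(log_p_sigma(np.exp(grid)))) * du  # far from 1
print(total_mass, without_jacobian)
```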


12.5.9 Speeding up HMC 


Although HMC uses gradient information to explore the typical set, sometimes the geometry of the 
typical set can be difficult to sample from. See Section 12.5.4.3 for ways to estimate the mass matrix, 
which can help with such difficult cases. 

Another issue is the cost of evaluating the target distribution, $\mathcal{E}(\theta) = -\log \tilde{p}(\theta)$. For many ML applications, this has the form $\log \tilde{p}(\theta) = \log p_0(\theta) + \sum_{n=1}^{N} \log p(x_n | \theta)$. This takes $O(N)$ time to compute. We can speed this up by using stochastic gradient methods; see Section 12.7 for details.


12.6 MCMC convergence 


We start MCMC from an arbitrary initial state. As we explained in Section 2.6.4, the samples will be coming from the chain’s stationary distribution only when the chain has “forgotten” where it started from. The amount of time it takes to enter the stationary distribution is called the mixing time (see Section 12.6.1 for details). Samples collected before the chain has reached its stationary distribution do not come from $p^*$, and are usually thrown away. The initial period, whose samples will be ignored, is called the burn-in phase.


For example, consider a uniform distribution on the integers $\{0, 1, \ldots, 20\}$. Suppose we sample from this using a symmetric random walk. In Figure 12.11, we show two runs of the algorithm. On the left, we start in state 10; on the right, we start in state 17. Even in this small problem it takes over 200 steps until the chain has “forgotten” where it started from. Proposal distributions that make larger changes can converge faster. For example, [BD92; Man] prove that it takes about 7 riffle shuffles to properly mix a deck of 52 cards (i.e., to ensure the distribution is uniform).
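The experiment in Figure 12.11 can be reproduced exactly, without sampling, by propagating the distribution $p_t = p_0 T^t$. The sketch below is our own reconstruction, not the notebook that generated the figure; it assumes the symmetric random walk proposes $x \pm 1$ with probability 1/2 each and rejects (stays put) when the proposal leaves $\{0, \ldots, 20\}$. It prints the $L_1$ distance to the uniform distribution at several times for both starting states.

```python
import numpy as np

def random_walk_matrix(n_states=21):
    """Transition matrix of the assumed walk: propose x-1 or x+1 with prob 1/2 each;
    proposals outside {0, ..., n_states-1} are rejected, so the chain stays put."""
    T = np.zeros((n_states, n_states))
    for x in range(n_states):
        for x_new in (x - 1, x + 1):
            if 0 <= x_new < n_states:
                T[x, x_new] += 0.5
            else:
                T[x, x] += 0.5
    return T

T = random_walk_matrix()
uniform = np.full(21, 1 / 21)

def l1_to_uniform(x0, t):
    p0 = np.zeros(21)
    p0[x0] = 1.0                                  # delta distribution at x0
    p_t = p0 @ np.linalg.matrix_power(T, t)       # row vector times T^t
    return np.abs(p_t - uniform).sum()

for x0 in (10, 17):
    print(x0, [round(l1_to_uniform(x0, t), 3) for t in (1, 10, 100, 200, 400)])
```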


In Section 12.6.1 we discuss how to compute the mixing time theoretically. In practice, this can be very hard [BBM10] (this is one of the fundamental weaknesses of MCMC), so in Section 12.6.2, we discuss practical heuristics.
12.6. MCMC CONVERGENCE


[Figure 12.11 panels: (a) initial condition $x_0 = 10$; (b) initial condition $x_0 = 17$. Each panel shows $p_t(x)$ over $x \in \{0, \ldots, 20\}$ for several values of $t$, including $t = 0, 1, 2, 3, 10, 100$.]
Figure 12.11: Illustration of convergence to the uniform distribution over $\{0, 1, \ldots, 20\}$ using a symmetric random walk starting from (left) state 10, and (right) state 17. Adapted from Figures 29.14 and 29.15 of [Mac03]. Generated by random_walk_integers.ipynb.


12.6.1 Mixing rates of Markov chains 


The amount of time it takes for a Markov chain to converge to the stationary distribution, and forget its initial state, is called the mixing time. More formally, for a given constant $\epsilon > 0$, we say that the mixing time from state $x_0$ is the minimal time such that


$$
\tau_\epsilon(x_0) \triangleq \min\left\{t : \|\delta_{x_0}(x) T^t - p^*\|_1 \le \epsilon\right\} \tag{12.94}
$$


where $\delta_{x_0}(x)$ is a distribution with all its mass in state $x_0$, $T$ is the transition matrix of the chain (which depends on the target $p^*$ and the proposal $q$), and $\delta_{x_0}(x) T^t$ is the distribution after $t$ steps. The mixing time of the chain is defined as


$$
\tau_\epsilon \triangleq \max_{x_0} \tau_\epsilon(x_0) \tag{12.95}
$$


This is the maximum amount of time it takes for the chain’s distribution to get within $\epsilon$ of $p^*$ from any starting state.

The mixing time is determined by the eigengap $\gamma = \lambda_1 - \lambda_2$, which is the difference between the first and second eigenvalues of the transition matrix. For a finite state chain, one can show $\tau_\epsilon \le O\left(\frac{1}{\gamma} \log \frac{n}{\epsilon}\right)$, where $n$ is the number of states.
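For the random walk of Figure 12.11, assuming (as our own modeling choice) that proposals which leave $\{0, \ldots, 20\}$ are rejected, the transition matrix is symmetric, so its eigenvalues are real and easy to compute. The sketch below finds the eigengap and evaluates $\frac{1}{\gamma}\log\frac{n}{\epsilon}$; the constant hidden by the $O(\cdot)$ is ignored, so this only indicates the scale of the mixing time. For this particular chain the eigenvalues are $\cos(\pi k/21)$ for $k = 0, \ldots, 20$, so the most negative eigenvalue has the same magnitude as $\lambda_2$ and decays at the same rate.

```python
import numpy as np

n = 21
T = np.zeros((n, n))
for x in range(n):
    for x_new in (x - 1, x + 1):
        T[x, min(max(x_new, 0), n - 1)] += 0.5   # an out-of-range proposal leaves x unchanged

eigvals = np.sort(np.linalg.eigvalsh(T))[::-1]   # T is symmetric, so eigvalsh applies
gap = eigvals[0] - eigvals[1]                    # gamma = lambda_1 - lambda_2
eps = 0.01
mixing_scale = np.log(n / eps) / gap             # ignores the constant in the O(.) bound
print(eigvals[:2], gap, mixing_scale)
```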

We can also study the problem by examining the geometry of the state space. For example, consider the chain in Figure 12.12. We see that the state space consists of two “islands”, which are connected to each other via a narrow “bottleneck”. (If they were completely disconnected, the chain would not be ergodic, and there would no longer be a unique stationary distribution, as discussed in Section 2.6.4.3.) We define the conductance $\phi$ of a chain as the minimum probability, over all subsets $S$ of states, of transitioning from that set to its complement:

$$
\phi \triangleq \min_{S:\, 0 < p^*(S) \le 0.5} \frac{\sum_{x \in S,\, x' \in S^c} p^*(x)\, T(x \to x')}{p^*(S)} \tag{12.96}
$$
Figure 12.12: A Markov chain with low conductance. The dotted arcs represent transitions with very low probability. From Figure 12.6 of [KF09a]. Used with kind permission of Daphne Koller.


One can show that $\tau_\epsilon \le O\left(\frac{1}{\phi^2} \log \frac{n}{\epsilon}\right)$. Hence chains with low conductance have high mixing time. For example, distributions with well-separated modes usually have high mixing time. Simple MCMC methods, such as MH and Gibbs, often do not work well in such cases, and more advanced algorithms, such as parallel tempering, are necessary (see e.g., [ED05; Kat+06; BZ20]).
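For very small chains, Equation (12.96) can be evaluated by brute force over all subsets. The sketch below (our own illustration) does this for an assumed 6-state path-shaped chain with a uniform stationary distribution, first with every neighboring move having probability 1/2, and then with the middle link weakened to probability 0.01 to create a bottleneck like the one in Figure 12.12. The conductance drops from 1/6 to 0.01/3, matching the intuition that a narrow bottleneck slows mixing. The search is exponential in the number of states, so it is only a teaching aid.

```python
import itertools
import numpy as np

def path_chain(n, middle_prob=0.5):
    """Symmetric walk on a path of n states; the link between the two middle
    states has probability middle_prob. Leftover mass becomes a self-loop."""
    T = np.zeros((n, n))
    for x in range(n - 1):
        p = middle_prob if x == n // 2 - 1 else 0.5
        T[x, x + 1] = T[x + 1, x] = p
    T += np.diag(1.0 - T.sum(axis=1))
    return T

def conductance(T, p_star):
    """Brute-force Eq. (12.96): min over S with 0 < p*(S) <= 0.5 of flow(S -> S^c) / p*(S)."""
    n = len(p_star)
    best = np.inf
    for size in range(1, n):
        for S in itertools.combinations(range(n), size):
            S = list(S)
            mass = p_star[S].sum()
            if mass > 0.5 + 1e-12:
                continue
            Sc = [x for x in range(n) if x not in S]
            flow = (p_star[S][:, None] * T[np.ix_(S, Sc)]).sum()
            best = min(best, flow / mass)
    return best

n = 6
p_star = np.full(n, 1 / n)       # uniform is stationary because T is symmetric
phi_path = conductance(path_chain(n), p_star)                          # 1/6
phi_bottleneck = conductance(path_chain(n, middle_prob=0.01), p_star)  # 0.01/3
print(phi_path, phi_bottleneck)
```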


12.6.2 Practical convergence diagnostics 


Computing the mixing time of a chain is in general quite difficult, since the transition matrix is usually very hard to compute. Furthermore, diagnosing convergence is computationally intractable in general [BBM10]. Nevertheless, various heuristics have been proposed; see e.g., [Gey92; CC96; BR98; Veh+19]. We discuss some of the current recommended approaches below, following [Veh+19].


12.6.2.1 Trace plots


One of the simplest approaches to assessing if the method has converged is to run multiple chains (typically 3 or 4) from very different overdispersed starting points, and to plot the samples of some quantity of interest, such as the value of a certain component of the state vector, or some event such as the value taking on an extreme value. This is called a trace plot. If the chain has mixed, it should have “forgotten” where it started from, so the trace plots should converge to the same distribution, and thus overlap with each other.


To illustrate this, we will consider a very simple, but enlightening, example from [McE20, Sec 9.5]. The model is a univariate Gaussian, $y_i \sim \mathcal{N}(\alpha, \sigma)$, with just 2 observations, $y_1 = -1$ and $y_2 = +1$. We first consider a very diffuse prior, $\alpha \sim \mathcal{N}(0, 1000)$ and $\sigma \sim \text{Expon}(0.0001)$, both of which allow for very large values of $\alpha$ and $\sigma$. We fit the model with HMC, using 3 chains and 500 samples per chain. The result is shown in Figure 12.13. On the right, we show the trace plots for $\alpha$ and $\sigma$ for 3 different chains. We see that they do not overlap much with each other. In addition, the numerous black vertical lines at the bottom of the plot indicate that HMC had many divergences.


The problem is caused by the overly diffuse priors, which do not get overwhelmed by the likelihood because we only have 2 data points. Thus the posterior is also diffuse. We can fix this by using slightly stronger priors that keep the parameters close to more sensible values. For example, suppose we use $\alpha \sim \mathcal{N}(1, 10)$ and $\sigma \sim \text{Expon}(1)$. Now we get the results in Figure 12.14. On the right we see
Figure 12.13: Marginals (left) and trace plot (right) for the univariate Gaussian using the diffuse prior. Black vertical lines indicate HMC divergences. Adapted from Figures 9.9-9.10 of [McE20]. Generated by mcmc_traceplots_unigauss.ipynb.

Figure 12.14: Marginals (left) and trace plot (right) for the univariate Gaussian using the sensible prior. Adapted from Figures 9.9-9.10 of [McE20]. Generated by mcmc_traceplots_unigauss.ipynb.


that the trace plots overlap. On the left, we see that the marginal distributions from each chain have support over a reasonable interval, and have a peak at the “right” place (the MLE for $\alpha$ is 0, and for $\sigma$ is 1). And we don’t see any divergence warnings (vertical black markers in the plot).

Since trace plots of converging chains correspond to overlapping lines, it can be hard to distinguish success from failure. An alternative plot, known as a trace rank plot, was recently proposed in [Veh+19]. (In [McE20], this is called a trankplot, a term we borrow.) The idea is to compute the rank of each sample based on all the samples from all the chains, after burn-in. We then plot a histogram of the ranks for each chain separately. If the chains have converged, the distribution over ranks should be uniform, since there should be no preference for high or low scoring samples amongst the chains.
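The rank computation behind a trankplot takes only a few lines. In the sketch below (our own illustration, with synthetic draws standing in for MCMC output), ranks are computed over the pooled draws of all chains and then histogrammed per chain. For three chains drawn from the same distribution the histograms are roughly flat; when the chains are centered at different locations, as happens when they have not mixed, the histograms are strongly skewed.

```python
import numpy as np

def rank_histograms(chains, n_bins=10):
    """chains: array of shape (M, N) holding post-burn-in draws of one scalar.
    Ranks are computed over all M * N draws pooled together, then histogrammed
    separately for each chain (ties, if any, are broken arbitrarily)."""
    M, N = chains.shape
    ranks = np.argsort(np.argsort(chains.ravel())).reshape(M, N)   # 0 .. M*N - 1
    edges = np.linspace(0, M * N, n_bins + 1)
    return np.stack([np.histogram(r, bins=edges)[0] for r in ranks])

rng = np.random.default_rng(0)
mixed = rng.standard_normal((3, 500))                  # 3 chains, same distribution
stuck = mixed + np.array([[-2.0], [0.0], [2.0]])       # chains centered in different places
h_mixed = rank_histograms(mixed)
h_stuck = rank_histograms(stuck)
print(h_mixed)
print(h_stuck)
```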

The trankplot for the model with the diffuse prior is shown in Figure 12.15. (The x-axis is from 1 
Figure 12.15: Trace rank plot for the univariate Gaussian using the diffuse prior. Adapted from Figures 9.9-9.10 of [McE20]. Generated by mcmc_traceplots_unigauss.ipynb.

Figure 12.16: Trace rank plot for the univariate Gaussian using the sensible prior. Adapted from Figures 9.9-9.10 of [McE20]. Generated by mcmc_traceplots_unigauss.ipynb.


to the total number of samples, which in this example is 1500, since we use 3 chains and draw 500 samples from each.) We can see that the different chains are clearly not mixing. The trankplot for the model with the sensible prior is shown in Figure 12.16; this looks much better.


12.6.2.2 Estimated potential scale reduction (EPSR)


In this section, we discuss a way to assess convergence more quantitatively. The basic idea is this: if one or more chains has not mixed well, then the variance of all the chains combined together will be higher than the variance of the individual chains. So we will compare the variance of the quantity of interest computed between and within chains.


More precisely, suppose we have $M$ chains, and we draw $N$ samples from each. Let $x_{nm}$ denote the quantity of interest derived from the $n$’th sample from the $m$’th chain. We compute the between


and within-sequence variances as follows: 


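$$
B = \frac{N}{M-1} \sum_{m=1}^{M} \left(\overline{x}_{\cdot m} - \overline{x}_{\cdot\cdot}\right)^2, \quad \text{where } \overline{x}_{\cdot m} = \frac{1}{N} \sum_{n=1}^{N} x_{nm}, \quad \overline{x}_{\cdot\cdot} = \frac{1}{M} \sum_{m=1}^{M} \overline{x}_{\cdot m} \tag{12.97}
$$

$$
W = \frac{1}{M} \sum_{m=1}^{M} s_m^2, \quad \text{where } s_m^2 = \frac{1}{N-1} \sum_{n=1}^{N} \left(x_{nm} - \overline{x}_{\cdot m}\right)^2 \tag{12.98}
$$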


The formula for $s_m^2$ is the usual unbiased estimate for the variance from a set of $N$ samples; $W$ is just the average of this. The formula for $B$ is similar, but scaled up by $N$ since it is based on the variance of $\overline{x}_{\cdot m}$, which are averaged over $N$ values.

Next we compute the following average variance:

$$
\hat{V}^+ \triangleq \frac{N-1}{N} W + \frac{1}{N} B \tag{12.99}
$$

Finally, we compute the following quantity, known as the estimated potential scale reduction or R-hat:

$$
\hat{R} \triangleq \sqrt{\frac{\hat{V}^+}{W}} \tag{12.100}
$$


In [Veh+19], they recommend checking if $\hat{R} < 1.01$ before declaring convergence.

For example, consider the $\hat{R}$ values for various samplers for our univariate GMM example. In particular, consider the 3 MH samplers in Figure 12.1, and the Gibbs sampler in Figure 12.4. The $\hat{R}$ values are 1.493, 1.039, 1.005 and 1.007. So this diagnostic has correctly identified that the first two samplers are unreliable, as is evident from the figure.

In practice, it is recommended to use a slightly different quantity, known as split-$\hat{R}$. This can be computed by splitting each chain into the first and second halves, thus doubling the number of chains $M$ (but halving the number of samples $N$ from each), before computing $\hat{R}$. This can detect non-stationarity within a single chain.
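Equations (12.97)-(12.100) and the split variant translate directly into code. The sketch below is our own (it is not the rank-normalized estimator recommended in [Veh+19]) and is applied to three assumed synthetic cases: well-mixed chains, chains where one is offset from the others, and chains that all share the same slow drift. The last case illustrates why splitting helps: the plain $\hat{R}$ cannot see the drift because every chain has the same mean, but the two halves of each chain disagree.

```python
import numpy as np

def rhat(chains):
    """Potential scale reduction, Eqs. (12.97)-(12.100). chains has shape (M, N)."""
    M, N = chains.shape
    chain_means = chains.mean(axis=1)
    B = N / (M - 1) * np.sum((chain_means - chain_means.mean()) ** 2)  # between-chain
    W = chains.var(axis=1, ddof=1).mean()                                # within-chain
    V_plus = (N - 1) / N * W + B / N
    return np.sqrt(V_plus / W)

def split_rhat(chains):
    """Split each chain into halves (2M chains of N // 2 draws), then apply rhat."""
    half = chains.shape[1] // 2
    return rhat(np.concatenate([chains[:, :half], chains[:, half:2 * half]], axis=0))

rng = np.random.default_rng(0)
M, N = 4, 1000
mixed = rng.standard_normal((M, N))
offset = mixed + np.array([[0.0], [0.0], [0.0], [3.0]])          # one chain stuck elsewhere
drift = np.linspace(-2.0, 2.0, N) + rng.standard_normal((M, N))  # same trend in all chains
for name, c in [("mixed", mixed), ("offset", offset), ("drift", drift)]:
    print(name, round(rhat(c), 3), round(split_rhat(c), 3))
```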


12.6.3 Effective sample size 


Although MCMC lets us draw samples from a target distribution (assuming it has converged), the 
samples are not independent, so we may need to draw a lot of them to get a reliable estimate. In 
this section, we discuss how to compute the effective sample size or ESS from a set of (possibly 
correlated) samples. 

To start, suppose we draw $N$ independent samples from the target distribution, and let $\hat{x} = \frac{1}{N} \sum_{n=1}^{N} x_n$ be our empirical estimate of the mean of the quantity of interest. The variance of this estimate is given by

$$
\mathbb{V}[\hat{x}] = \mathbb{V}\left[\frac{1}{N} \sum_{n=1}^{N} x_n\right] = \frac{1}{N^2} \sum_{n=1}^{N} \mathbb{V}[x_n] = \frac{\sigma^2}{N} \tag{12.101}
$$

where $\sigma^2 = \mathbb{V}[X]$. If the samples are correlated, the variance of the estimate will be higher, as we show below.
Figure 12.17: Autocorrelation functions for various MCMC samplers for the mixture of two 1D Gaussians. (a-c) These are the MH samplers in Figure 12.1. (d) This is the Gibbs sampler in Figure 12.4. Generated by mcmc_gmm_demo.ipynb.


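Recall that for $N$ (not necessarily independent) random variables we have

$$
\mathbb{V}\left[\sum_{n=1}^{N} x_n\right] = \sum_{i=1}^{N} \sum_{j=1}^{N} \text{Cov}[x_i, x_j] = \sum_{i=1}^{N} \mathbb{V}[x_i] + 2 \sum_{1 \le i < j \le N} \text{Cov}[x_i, x_j] \tag{12.102}
$$

Let $\overline{x} = \frac{1}{N} \sum_{n=1}^{N} x_n$ be our estimate based on these correlated samples. The variance of this estimate is given by

$$
\mathbb{V}[\overline{x}] = \frac{1}{N^2} \sum_{i=1}^{N} \sum_{j=1}^{N} \text{Cov}[x_i, x_j] \tag{12.103}
$$

We now rewrite this in a more convenient form. First recall that the correlation of $x_i$ and $x_j$ is given by

$$
\text{corr}[x_i, x_j] = \frac{\text{Cov}[x_i, x_j]}{\sqrt{\mathbb{V}[x_i]\, \mathbb{V}[x_j]}} \tag{12.104}
$$

Since we assume we are drawing samples from the target distribution, we have $\mathbb{V}[x_i] = \sigma^2$, where $\sigma^2$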


is the true variance. Hence 


$$
\mathbb{V}[\overline{x}] = \frac{\sigma^2}{N^2} \sum_{i=1}^{N} \sum_{j=1}^{N} \text{corr}[x_i, x_j] \tag{12.105}
$$


For a fixed $i$, we can think of $\text{corr}[x_i, x_j]$ as a function of $j$. This will usually decay as $j$ gets further from $i$. As $N \to \infty$ we can approximate the sum of correlations by

$$
\sum_{j=1}^{N} \text{corr}[x_i, x_j] \to \sum_{\ell=-\infty}^{\infty} \text{corr}[x_i, x_{i+\ell}] = 1 + 2 \sum_{\ell=1}^{\infty} \text{corr}[x_i, x_{i+\ell}] \tag{12.106}
$$

since $\text{corr}[x_i, x_i] = 1$ and $\text{corr}[x_i, x_{i-\ell}] = \text{corr}[x_i, x_{i+\ell}]$ for lag $\ell > 0$. Since we assume the samples are coming from a stationary distribution, the index $i$ does not matter. Thus we can define the autocorrelation time as

$$
\rho = 1 + 2 \sum_{\ell=1}^{\infty} \rho(\ell) \tag{12.107}
$$

where $\rho(\ell)$ is the autocorrelation function (ACF), defined as

$$
\rho(\ell) \triangleq \text{corr}[x_0, x_\ell] \tag{12.108}
$$


The ACF can be approximated efficiently by convolving the signal x with itself. In Figure 12.17, 
we plot the ACF for our four samplers for the GMM. We see that the ACF of the Gibbs sampler 
(bottom right) dies off to 0 much more rapidly than the MH samplers. Intuitively this indicates that 
each Gibbs sample is “worth” more than each MH sample. We quantify this below. 

From Equation (12.105), we can compute the variance of our estimate in terms of the ACF as follows: $\mathbb{V}[\overline{x}] = \frac{\sigma^2}{N^2} \sum_{i=1}^{N} \rho = \frac{\sigma^2}{N} \rho$. By contrast, the variance of the estimate from independent samples is $\mathbb{V}[\hat{x}] = \frac{\sigma^2}{N}$. So we see that the variance is a factor $\rho$ larger when there is correlation. We therefore define the effective sample size of our set of samples to be

$$
N_{\text{eff}} \triangleq \frac{N}{\rho} = \frac{N}{1 + 2 \sum_{\ell=1}^{\infty} \rho(\ell)} \tag{12.109}
$$


In practice, we truncate the sum at lag $L$, which is the last integer at which $\rho(L)$ is positive. Also, if we run $M$ chains, the numerator should be $NM$, so we get the following estimate:

$$
\hat{N}_{\text{eff}} = \frac{NM}{1 + 2 \sum_{\ell=1}^{L} \hat{\rho}(\ell)} \tag{12.110}
$$

In [Veh+19], they propose various extensions of the above estimator, such as using rank statistics, to make the estimate more robust.
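A minimal sketch of this estimator follows. It is our own simplified version, not the rank-normalized estimator of [Veh+19]: it computes the ACF of each chain with the FFT (zero-padding to avoid wrap-around), averages the ACFs across chains (an assumption on our part), and truncates the sum just before the first non-positive autocorrelation. As a check, we use an AR(1) process $x_t = a x_{t-1} + \epsilon_t$, whose ACF is $\rho(\ell) = a^\ell$, so that $\rho = (1 + a)/(1 - a)$ and $N_{\text{eff}} \approx NM/3$ for $a = 0.5$.

```python
import numpy as np

def autocorr(x):
    """Sample ACF rho_hat(l) for l = 0, ..., N-1, computed with the FFT.
    Zero-padding to length 2N avoids circular wrap-around."""
    x = np.asarray(x, dtype=float) - np.mean(x)
    n = len(x)
    f = np.fft.rfft(x, n=2 * n)
    acov = np.fft.irfft(f * np.conj(f))[:n] / n
    return acov / acov[0]

def ess(chains):
    """Eq. (12.110): N M / (1 + 2 sum_{l=1}^{L} rho_hat(l)), where L is the last lag
    before the first non-positive autocorrelation. chains has shape (M, N)."""
    M, N = chains.shape
    rho = np.mean([autocorr(c) for c in chains], axis=0)   # ACF averaged over chains
    tau = 1.0
    for lag in range(1, N):
        if rho[lag] <= 0:
            break
        tau += 2.0 * rho[lag]
    return N * M / tau

# Check on an AR(1) process, for which rho(l) = a**l and tau = (1 + a) / (1 - a) = 3.
rng = np.random.default_rng(0)
a, M, N = 0.5, 4, 25_000
chains = np.zeros((M, N))
for t in range(1, N):
    chains[:, t] = a * chains[:, t - 1] + rng.standard_normal(M)
n_eff = ess(chains)                               # roughly N * M / 3
iid_ess = ess(rng.standard_normal((M, N)))        # roughly N * M
print(n_eff, iid_ess)
```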


12.6.4 Improving speed of convergence 


There are many possible things you could try if the $\hat{R}$ value is too large, and/or the effective sample size is too low. Here is a brief list:
[Figure 12.18 panels: (a) Neal’s funnel; (b) centered parameterization; (c) non-centered parameterization.]


Figure 12.18: Neal’s funnel. (a) Joint density. (b) HMC samples from centered representation. (c) HMC samples from non-centered representation. Generated by neals_funnel.ipynb.


• Try using a non-centered parameterization (see Section 12.6.5).
• Try sampling variables in groups or blocks (see Section 12.3.7).
• Try using Rao-Blackwellisation, i.e., analytically integrating out some of the variables (see Section 12.3.8).
• Try adding auxiliary variables (see Section 12.4).
• Try using adaptive proposal distributions (see Section 12.2.3.5).

More details can be found in [Rob+18].


12.6.5 Non-centered parameterizations and Neal’s funnel 


A common problem that arises when applying sampling to hierarchical Bayesian models is when a set of parameters at one level of the model has a tight dependence on parameters at the level above. We show some practical examples of this in the hierarchical Gaussian 8-schools example in Section 3.7.2.2 and the hierarchical radon regression example in Section 15.5.3.2. Here, we focus on the following simple toy model that captures the essence of the problem:


v ~ N(0,3) (12.111) 
x ~ N (0, exp(v)) (12.112) 


The corresponding joint density $p(x, v)$ is shown in Figure 12.18a. This is known as Neal’s funnel, named after [Nea03]. It is hard for a sampler to “descend” into the narrow “neck” of the distribution, corresponding to areas where $v$, and hence the variance of $x$, is small [BG13].

Fortunately, we can represent this model in an equivalent way that makes it easier to sample from, provided we use a non-centered parameterization [PR03]. This has the form


v ~ N(0,3) (12.113) 
z ~ N (0,1) (12.114) 
x = z exp(v) (12.115)
This is easier to sample from, since p(z, v) is a product of two independent Gaussians, and we can derive x deterministically from these Gaussian samples. The advantage of this reparameterization is shown in Figure 12.18 (compare panels b and c). A method to automatically derive such reparameterizations is discussed in [GMH20].
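The following pure-Python sketch (not the book's neals_funnel.ipynb, which uses HMC) illustrates the change of variables. It assumes, following the N(mean, variance) notation, that N(0, 3) has variance 3 and that exp(v) is the variance of x given v, so x = z exp(v/2). For brevity it uses random-walk Metropolis (our choice) in the non-centered (v, z) space, where the target is a product of independent Gaussians, and maps each sample back to (v, x). All function names are our own.

```python
import math
import random

# Assumption: N(0, 3) means variance 3, and x | v ~ N(0, exp(v)) has variance exp(v).
V_VAR = 3.0


def log_normal(x, mean, var):
    """Log density of the univariate Gaussian N(x | mean, var)."""
    return -0.5 * (math.log(2.0 * math.pi * var) + (x - mean) ** 2 / var)


def log_joint_centered(v, x):
    """log p(v) + log p(x | v) in the original (centered) coordinates."""
    return log_normal(v, 0.0, V_VAR) + log_normal(x, 0.0, math.exp(v))


def log_joint_noncentered(v, z):
    """log p(v) + log p(z): two independent Gaussians, so no funnel geometry."""
    return log_normal(v, 0.0, V_VAR) + log_normal(z, 0.0, 1.0)


def to_centered(v, z):
    """Recover x deterministically; Var[x | v] = exp(v) requires scaling z by exp(v / 2)."""
    return v, z * math.exp(v / 2.0)


def random_walk_metropolis(log_target, init, n_steps, step=1.0, seed=0):
    """Random-walk Metropolis with an isotropic Gaussian proposal."""
    rng = random.Random(seed)
    state = list(init)
    logp = log_target(*state)
    samples = []
    for _ in range(n_steps):
        proposal = [s + step * rng.gauss(0.0, 1.0) for s in state]
        logp_prop = log_target(*proposal)
        # Accept with probability min(1, p(proposal) / p(state)).
        if rng.random() < math.exp(min(0.0, logp_prop - logp)):
            state, logp = proposal, logp_prop
        samples.append(tuple(state))
    return samples


# Sample in (v, z) space, then map each sample back to (v, x).
nc_samples = random_walk_metropolis(log_joint_noncentered, (0.0, 0.0), 5000)
funnel_samples = [to_centered(v, z) for v, z in nc_samples]
```

Because |dx/dz| = exp(v/2), the centered log density plus v/2 equals the non-centered log density, which is exactly the change-of-variables identity that makes the two parameterizations equivalent.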


12.7 Stochastic gradient MCMC 


Consider an unnormalized target distribution of the following form: 


$$\pi(\theta) \propto p(\theta, \mathcal{D}) = p_0(\theta) \prod_{n=1}^{N} p(x_n|\theta) \tag{12.116}$$

where D = (x_1, ..., x_N). Alternatively we can define the target distribution in terms of an energy function (negative log joint) as follows:

$$p(\theta, \mathcal{D}) \propto \exp(-\mathcal{E}(\theta)) \tag{12.117}$$


The energy function can be decomposed over data samples:

$$\mathcal{E}(\theta) = \sum_{n=1}^{N} \mathcal{E}_n(\theta) \tag{12.118}$$
$$\mathcal{E}_n(\theta) = -\log p(x_n|\theta) - \frac{1}{N} \log p_0(\theta) \tag{12.119}$$


Evaluating the full energy (e.g., to compute an acceptance probability in the Metropolis Hastings 
algorithm, or to compute the gradient in HMC) takes O(N) time, which does not scale to large data. 
In this section, we discuss some solutions to this problem. 


12.7.1 Stochastic Gradient Langevin Dynamics (SGLD) 
Recall from Equation (12.83) that the Langevin diffusion SDE has the following form:

$$d\theta_t = -\nabla \mathcal{E}(\theta_t)\, dt + \sqrt{2}\, dW_t \tag{12.120}$$


where dW_t is a Wiener noise (also called Brownian noise) process. In discrete time, we can use the following Euler approximation:

$$\theta_{t+1} \approx \theta_t - \eta_t \nabla \mathcal{E}(\theta_t) + \sqrt{2\eta_t}\, \mathcal{N}(0, I) \tag{12.121}$$


Computing the gradient g(θ_t) = ∇E(θ_t) at each step takes O(N) time. We can compute an unbiased minibatch approximation to the gradient term in O(B) time using

$$\hat{g}(\theta_t) = \frac{N}{B} \sum_{n \in \mathcal{B}_t} \nabla \mathcal{E}_n(\theta_t) = -\frac{N}{B} \left( \sum_{n \in \mathcal{B}_t} \nabla \log p(x_n|\theta_t) + \frac{B}{N} \nabla \log p_0(\theta_t) \right) \tag{12.122}$$

where B_t is the minibatch at step t. This gives rise to the following approximate update:

$$\theta_{t+1} = \theta_t - \eta_t \hat{g}(\theta_t) + \sqrt{2\eta_t}\, \mathcal{N}(0, I) \tag{12.123}$$


This is called stochastic gradient Langevin dynamics or SGLD [Well11]. The resulting update step is identical to SGD, except for the addition of a Gaussian noise term. (See [Neg+21] for some recent analysis of this method; they also suggest setting η_t ∝ N^{-2/3}.)
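A minimal SGLD sketch in plain Python, assuming a hypothetical toy model (not from the text): x_n ~ N(θ, 1) with prior θ ~ N(0, 10), for which the exact posterior mean is known in closed form. It uses a constant step size rather than a decaying schedule η_t, which is a simplification; the function names are our own.

```python
import math
import random

# Hypothetical toy model: x_n ~ N(theta, 1), prior theta ~ N(0, PRIOR_VAR).
PRIOR_VAR = 10.0


def grad_energy_n(theta, x_n, n_data):
    """Gradient of E_n(theta) = -log N(x_n | theta, 1) - (1/N) log N(theta | 0, PRIOR_VAR)."""
    return -(x_n - theta) + theta / (n_data * PRIOR_VAR)


def sgld(data, n_steps, batch_size, step_size, theta0=0.0, seed=0):
    """SGLD with a constant step size (a simplification of a decaying eta_t)."""
    rng = random.Random(seed)
    n_data = len(data)
    theta = theta0
    samples = []
    for _ in range(n_steps):
        batch = rng.sample(data, batch_size)  # minibatch drawn without replacement
        # Unbiased minibatch estimate of the full-data gradient, Equation (12.122).
        g_hat = (n_data / batch_size) * sum(grad_energy_n(theta, x_n, n_data) for x_n in batch)
        # SGD step plus Gaussian noise with variance 2 * step_size, Equation (12.123).
        theta = theta - step_size * g_hat + math.sqrt(2.0 * step_size) * rng.gauss(0.0, 1.0)
        samples.append(theta)
    return samples


data_rng = random.Random(1)
data = [data_rng.gauss(2.0, 1.0) for _ in range(100)]
samples = sgld(data, n_steps=6000, batch_size=10, step_size=1e-3)
# Exact conjugate posterior mean for comparison: sum(x) / (N + 1 / PRIOR_VAR).
exact_mean = sum(data) / (len(data) + 1.0 / PRIOR_VAR)
sgld_mean = sum(samples[1000:]) / len(samples[1000:])
```

Because the gradient here is linear in θ, the long-run average of the chain matches the exact posterior mean; with a constant step size the spread of the samples is only approximately correct.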


12.7.2 Preconditioning


As in SGD, we can get better results (especially for models such as neural networks) if we use 
preconditioning to scale the gradient updates. In [PT13], they use the Fisher information matrix 
(FIM) as the preconditioner; this method is known as Stochastic Gradient Riemannian Langevin 
Dynamics or SGRLD. 

Unfortunately, computing the FIM is often hard. In [Li+16], they propose to use the same kind of 
diagonal approximation as used by RMSprop; this is called preconditioned SGLD. An alternative 
is to use an Adam-like preconditioner, as proposed in [KSL21]. This is called SGLD-Adam. For 
more details, see [CSN21]. 


12.7.3 Reducing the variance of the gradient estimate 


The variance of the noise introduced by minibatching can be quite large, which can hurt the 
performance of methods such as SGLD [BDM18]. In [Bak+17], they propose to reduce the variance 
of this estimate by using a control variate estimator; this method is therefore called SGLD-CV. 
Specifically they use the following gradient approximation: 


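$$\hat{\nabla}_{cv} \mathcal{E}(\theta_t) = \nabla \mathcal{E}(\hat{\theta}) + \frac{N}{B} \sum_{n \in \mathcal{B}_t} \left( \nabla \mathcal{E}_n(\theta_t) - \nabla \mathcal{E}_n(\hat{\theta}) \right) \tag{12.124}$$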


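Here θ̂ is any fixed value, but it is often taken to be an approximate MAP estimate (e.g., based on one epoch of SGD). Equation (12.124) is valid because the terms we add and subtract are equal in expectation, and hence we get an unbiased estimate:

$$\mathbb{E}\left[\hat{\nabla}_{cv} \mathcal{E}(\theta_t)\right] = \nabla \mathcal{E}(\hat{\theta}) + \mathbb{E}\left[\frac{N}{B} \sum_{n \in \mathcal{B}_t} \left( \nabla \mathcal{E}_n(\theta_t) - \nabla \mathcal{E}_n(\hat{\theta}) \right)\right] \tag{12.125}$$
$$= \nabla \mathcal{E}(\hat{\theta}) + \nabla \mathcal{E}(\theta_t) - \nabla \mathcal{E}(\hat{\theta}) = \nabla \mathcal{E}(\theta_t) \tag{12.126}$$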


Note that the first term, ∇E(θ̂) = Σ_{n=1}^{N} ∇E_n(θ̂), requires a single pass over the entire dataset, but only has to be computed once (e.g., while estimating θ̂).
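To see the variance reduction concretely, the sketch below assumes a hypothetical 1d Bayesian logistic regression with eight made-up data points, and enumerates every minibatch of size B = 2, so the mean and variance of each estimator are computed exactly rather than by simulation. Both estimators average to the full gradient; the control-variate one has much smaller variance when θ_t is close to θ̂. Names and data are our own.

```python
import itertools
import math

# Hypothetical model: p(y = 1 | x, theta) = sigmoid(theta * x), prior theta ~ N(0, PRIOR_VAR).
PRIOR_VAR = 10.0
DATA = [(-2.0, 0), (-1.0, 0), (-0.5, 1), (0.0, 0), (0.5, 1), (1.0, 1), (1.5, 0), (2.0, 1)]


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def grad_energy_n(theta, n):
    """Gradient of E_n(theta) = -log p(y_n | x_n, theta) - (1/N) log p0(theta)."""
    x, y = DATA[n]
    return (sigmoid(theta * x) - y) * x + theta / (len(DATA) * PRIOR_VAR)


def full_gradient(theta):
    return sum(grad_energy_n(theta, n) for n in range(len(DATA)))


def naive_gradient(theta, batch):
    """Plain minibatch estimate, Equation (12.122)."""
    return len(DATA) / len(batch) * sum(grad_energy_n(theta, n) for n in batch)


def cv_gradient(theta, theta_hat, full_grad_hat, batch):
    """Control-variate estimate, Equation (12.124); full_grad_hat is computed once."""
    correction = sum(grad_energy_n(theta, n) - grad_energy_n(theta_hat, n) for n in batch)
    return full_grad_hat + len(DATA) / len(batch) * correction


def mean_and_variance(values):
    m = sum(values) / len(values)
    return m, sum((v - m) ** 2 for v in values) / len(values)


theta_hat = 0.5                    # reference point, e.g. a rough MAP estimate
full_grad_hat = full_gradient(theta_hat)
theta = 0.55                       # current iterate, close to the reference point
batches = list(itertools.combinations(range(len(DATA)), 2))  # every minibatch with B = 2
naive = mean_and_variance([naive_gradient(theta, b) for b in batches])
cv = mean_and_variance([cv_gradient(theta, theta_hat, full_grad_hat, b) for b in batches])
```

Each data index appears in a fraction B/N of all minibatches, which is why averaging over every batch recovers the full gradient for both estimators.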


One disadvantage of SGLD-CV is that the reference point θ̂ has to be precomputed, and is then fixed. An alternative is to update the reference point online, by performing periodic full batch estimates. This is called SVRG-LD [Dub+16; Cha+18], where SVRG stands for stochastic variance reduced gradient, and LD stands for Langevin Dynamics. If we use θ̃_t to denote the most recent snapshot (reference point), the corresponding gradient estimate is given by

$$\hat{\nabla}_{svrg} \mathcal{E}(\theta_t) = \nabla \mathcal{E}(\tilde{\theta}_t) + \frac{N}{B} \sum_{n \in \mathcal{B}_t} \left( \nabla \mathcal{E}_n(\theta_t) - \nabla \mathcal{E}_n(\tilde{\theta}_t) \right) \tag{12.127}$$


We recompute the snapshot every τ steps (known as the epoch length). See Algorithm 30 for the pseudo-code.


Algorithm 30: SVRG Langevin Descent
Initialize θ_0
for t = 0 : T - 1 do
    if t mod τ = 0 then
        θ̃ = θ_t
        g̃ = Σ_{n=1}^{N} ∇E_n(θ̃)
    Sample minibatch B_t ⊂ {1, ..., N}
    g_t = g̃ + (N/B) Σ_{n ∈ B_t} (∇E_n(θ_t) - ∇E_n(θ̃))
    θ_{t+1} = θ_t - η_t g_t + √(2η_t) N(0, I)


The disadvantage of SVRG is that it needs to perform a full pass over the data every τ steps. An
alternative approach, called SAGA-LD [Dub+16; Cha+18] (which stands for stochastic averaged 
gradient acceleration), avoids this by storing all N gradient vectors, and then doing incremental 
updates. Unfortunately the memory requirements of this algorithm usually make it impractical. 
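Algorithm 30 can be sketched in plain Python as follows, assuming a scalar parameter and a constant step size. The usage example assumes a hypothetical toy model, x_n ~ N(θ, 1) with θ ~ N(0, 10). Because its gradient is linear in θ, the variance-reduced estimate happens to equal the full gradient exactly, so the demo checks the bookkeeping rather than the amount of variance reduction. Function names are our own.

```python
import math
import random


def svrg_ld(grad_n, n_data, theta0, n_steps, batch_size, step_size, epoch_length, seed=0):
    """SVRG Langevin dynamics following Algorithm 30 (scalar theta, constant step size).

    grad_n(theta, n) returns the gradient of the per-example energy E_n at theta.
    """
    rng = random.Random(seed)
    theta = theta0
    samples = []
    for t in range(n_steps):
        if t % epoch_length == 0:
            # Refresh the snapshot and its full-batch gradient (one pass over the data).
            theta_snap = theta
            g_snap = sum(grad_n(theta_snap, n) for n in range(n_data))
        batch = rng.sample(range(n_data), batch_size)
        # Variance-reduced gradient estimate, Equation (12.127).
        g = g_snap + n_data / batch_size * sum(
            grad_n(theta, n) - grad_n(theta_snap, n) for n in batch)
        theta = theta - step_size * g + math.sqrt(2.0 * step_size) * rng.gauss(0.0, 1.0)
        samples.append(theta)
    return samples


# Usage on the hypothetical toy model x_n ~ N(theta, 1), theta ~ N(0, PRIOR_VAR).
PRIOR_VAR = 10.0
data_rng = random.Random(1)
data = [data_rng.gauss(2.0, 1.0) for _ in range(100)]


def grad_n(theta, n):
    return -(data[n] - theta) + theta / (len(data) * PRIOR_VAR)


samples = svrg_ld(grad_n, len(data), 0.0, n_steps=6000, batch_size=10,
                  step_size=1e-3, epoch_length=100)
exact_mean = sum(data) / (len(data) + 1.0 / PRIOR_VAR)
svrg_mean = sum(samples[1000:]) / len(samples[1000:])
```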


12.7.4 SG-HMC 


We discussed Hamiltonian Monte Carlo (HMC) in Section 12.5, which uses auxiliary momentum 
variables to improve performance over Langevin MC. In this section, we discuss a way to speed it up 
by approximating the gradients using minibatches. This is called SG-HMC [CFG14; ZG21],
where SG stands for “stochastic gradient”. 

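Recall that the leapfrog updates have the following form:

$$v_{t+1/2} = v_t - \frac{\eta}{2} \nabla \mathcal{E}(\theta_t) \tag{12.128}$$
$$\theta_{t+1} = \theta_t + \eta v_{t+1/2} = \theta_t + \eta v_t - \frac{\eta^2}{2} \nabla \mathcal{E}(\theta_t) \tag{12.129}$$
$$v_{t+1} = v_{t+1/2} - \frac{\eta}{2} \nabla \mathcal{E}(\theta_{t+1}) = v_t - \frac{\eta}{2} \nabla \mathcal{E}(\theta_t) - \frac{\eta}{2} \nabla \mathcal{E}(\theta_{t+1}) \tag{12.130}$$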


We can replace the full batch gradient with a stochastic approximation, to get

$$\theta_{t+1} = \theta_t + \eta v_t - \frac{\eta^2}{2} \hat{g}(\theta_t, \xi_t) \tag{12.131}$$
$$v_{t+1} = v_t - \frac{\eta}{2} \hat{g}(\theta_t, \xi_t) - \frac{\eta}{2} \hat{g}(\theta_{t+1}, \xi_{t+1/2}) \tag{12.132}$$

where ξ_t and ξ_{t+1/2} are independent sources of randomness (e.g., batch indices). In [ZG21], they show that this algorithm (even without the MH rejection step) provides a good approximation to the posterior (in the sense of having small Wasserstein-2 distance) for the case where the energy function is strongly convex. Furthermore, performance can be considerably improved if we use the variance reduction methods discussed in Section 12.7.3.
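A minimal sketch of Equations (12.131) and (12.132) on a hypothetical toy model, x_n ~ N(θ, 1) with θ ~ N(0, 10). Two assumptions go beyond the text: as in standard HMC, the momentum is redrawn from N(0, 1) at the start of each trajectory of `n_leapfrog` steps, and no MH correction is applied. The helper names are our own.

```python
import math
import random

PRIOR_VAR = 10.0  # hypothetical toy model: x_n ~ N(theta, 1), theta ~ N(0, PRIOR_VAR)


def minibatch_grad(theta, data, batch_size, rng):
    """Unbiased minibatch estimate g_hat(theta, xi); the random batch plays the role of xi."""
    n = len(data)
    batch = rng.sample(data, batch_size)
    return n / batch_size * sum(-(x - theta) + theta / (n * PRIOR_VAR) for x in batch)


def sg_hmc(data, theta0, n_iters, n_leapfrog, step_size, batch_size, seed=0):
    rng = random.Random(seed)
    eta = step_size
    theta = theta0
    samples = []
    for _ in range(n_iters):
        v = rng.gauss(0.0, 1.0)  # assumption: fresh momentum for every trajectory
        for _ in range(n_leapfrog):
            g_t = minibatch_grad(theta, data, batch_size, rng)          # g_hat(theta_t, xi_t)
            theta_next = theta + eta * v - 0.5 * eta ** 2 * g_t         # Equation (12.131)
            g_next = minibatch_grad(theta_next, data, batch_size, rng)  # fresh batch xi_{t+1/2}
            v = v - 0.5 * eta * g_t - 0.5 * eta * g_next                # Equation (12.132)
            theta = theta_next
        samples.append(theta)  # no MH accept/reject step
    return samples


data_rng = random.Random(1)
data = [data_rng.gauss(2.0, 1.0) for _ in range(100)]
sg_samples = sg_hmc(data, theta0=0.0, n_iters=2000, n_leapfrog=10, step_size=0.02, batch_size=10)
exact_mean = sum(data) / (len(data) + 1.0 / PRIOR_VAR)
sghmc_mean = sum(sg_samples[200:]) / len(sg_samples[200:])
```

Note that θ_{t+1} and the first half of the velocity update share the same noisy gradient ĝ(θ_t, ξ_t), while the second half uses an independent batch, mirroring the two sources of randomness in the equations.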
12.7.5 Underdamped Langevin Dynamics


The underdamped Langevin dynamics (ULD) has the form of the following SDE [CDC15; LMS16; Che+18a; Che+18d]:

$$\begin{aligned} d\theta_t &= v_t\, dt \\ dv_t &= -g(\theta_t)\, dt - \gamma v_t\, dt + \sqrt{2\gamma}\, dW_t \end{aligned} \tag{12.133}$$

where g(θ_t) = ∇E(θ_t) is the gradient or force acting on the particle, γ > 0 is the friction parameter, and dW_t is Wiener noise.

Equation (12.133) is like the Langevin dynamics of Equation (12.83) but with an added momentum term v. We can solve the dynamics using various integration methods. It can be shown (see e.g., [LMS16]) that these methods are accurate to second order, whereas solving standard (overdamped) Langevin is only accurate to first order, and thus will require more sampling steps to achieve a given accuracy.
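One widely used integrator for Equation (12.133) is the BAOAB splitting: half "kicks" (B) from the force, half "drifts" (A) of the position, and an exact Ornstein-Uhlenbeck step (O) for the friction and noise. We use it as an illustrative choice; the text does not say which integrator it has in mind. The sketch assumes unit mass and unit temperature, and targets a standard Gaussian so the result is easy to check.

```python
import math
import random


def uld_baoab(grad_energy, theta0, n_steps, h=0.1, gamma=1.0, seed=0):
    """Integrate Equation (12.133) with the BAOAB splitting (unit mass, unit temperature).

    The stationary distribution of theta is proportional to exp(-E(theta)),
    and the velocity v is N(0, 1) at equilibrium.
    """
    rng = random.Random(seed)
    theta, v = theta0, 0.0
    c1 = math.exp(-gamma * h)      # friction decay over one step
    c2 = math.sqrt(1.0 - c1 ** 2)  # noise scale that keeps N(0, 1) invariant for v
    g = grad_energy(theta)
    samples = []
    for _ in range(n_steps):
        v -= 0.5 * h * g                       # B: half kick from the force -g
        theta += 0.5 * h * v                   # A: half drift
        v = c1 * v + c2 * rng.gauss(0.0, 1.0)  # O: exact Ornstein-Uhlenbeck step
        theta += 0.5 * h * v                   # A: half drift
        g = grad_energy(theta)
        v -= 0.5 * h * g                       # B: half kick
        samples.append(theta)
    return samples


# Standard Gaussian target: E(theta) = theta^2 / 2, so g(theta) = theta.
samples = uld_baoab(lambda th: th, theta0=3.0, n_steps=40000)
kept = samples[4000:]
mean = sum(kept) / len(kept)
var = sum((s - mean) ** 2 for s in kept) / len(kept)
```

Only one gradient evaluation is needed per step, since the force computed at the end of one step is reused for the first half kick of the next.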


12.8 Reversible jump (trans-dimensional) MCMC 


Suppose we have a set of models with different numbers of parameters, e.g., mixture models in which the number of mixture components is unknown. Let the model be denoted by m, and let its unknowns (e.g., parameters) be denoted by x_m ∈ X_m (e.g., X_m = R^{n_m}, where n_m is the dimensionality of model m). Sampling in spaces of differing dimensionality is called trans-dimensional MCMC. We could sample the model indicator m ∈ {1, ..., M} and sample all the parameters from the product space ∏_{m=1}^{M} X_m, but this is very inefficient, and only works if M is finite. It is more parsimonious to sample in the union space X = ∪_{m=1}^{M} {m} × X_m, where we only worry about parameters for the currently active model.


The difficulty with this approach arises when we move between models of different dimensionality. The trouble is that when we compute the MH acceptance ratio, we are comparing densities defined on spaces of different dimensionality, which is not well defined. For example, comparing densities on two points of a sphere makes sense, but comparing a density on a sphere to a density on a circle does not, as there is a dimensional mismatch in the two concepts. The solution, proposed by [Gre98] and known as reversible jump MCMC or RJMCMC, is to augment the low dimensional space with extra random variables so that the two spaces have a common measure. This is illustrated in Figure 12.19.


We give a sketch of the algorithm below. For more details, see e.g., [Gre03; HG12]. 


12.8.1 Basic idea


To explain the method in more detail, we follow the presentation of [And+03]. To ensure a common measure, we need to define a way to extend each pair of subspaces X_m and X_n to X_{m,n} = X_m × U_{m,n} and X_{n,m} = X_n × U_{n,m}. We also need to define a deterministic, differentiable and invertible mapping

$$(x_m, u_{m,n}) = f_{n \to m}(x_n, u_{n,m}) = \left(f^x_{n \to m}(x_n, u_{n,m}),\; f^u_{n \to m}(x_n, u_{n,m})\right) \tag{12.134}$$


Invertibility means that

$$f_{m \to n}(f_{n \to m}(x_n, u_{n,m})) = (x_n, u_{n,m}) \tag{12.135}$$


Draft of “Probabilistic Machine Learning: Advanced Topics”. August 15, 2022 




[Figure 12.19 image: a bivariate density p(x_1, x_2) is compared point-wise with a univariate density p(x_1) that has been expanded to a uniformly expanded density p(x_1, x*) by proposing x* uniformly.]


Figure 12.19: To compare a 1d model against a 2d model, we first have to map the 1d model to 2d space so 
the two have a common measure. Note that we assume the ridge has finite support, so it is integrable. From 
Figure 17 of [And+03]. Used with kind permission of Nando de Freitas. 


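Finally, we need to define proposals q_{n→m}(u_{n,m}|n, x_n) and q_{m→n}(u_{m,n}|m, x_m). Suppose we are in state (n, x_n). We move to (m, x*_m) by generating u_{n,m} ~ q_{n→m}(·|n, x_n), and then computing (x*_m, u_{m,n}) = f_{n→m}(x_n, u_{n,m}). We then accept the move with probability

$$A = \min\left(1, \frac{p(m, x_m^*)\, q(n|m)\, q_{m \to n}(u_{m,n}|m, x_m^*)}{p(n, x_n)\, q(m|n)\, q_{n \to m}(u_{n,m}|n, x_n)}\, |\det J_f|\right) \tag{12.136}$$

where x*_m = f^x_{n→m}(x_n, u_{n,m}), and J_f is the Jacobian of the transformation

$$J_f = \frac{\partial f_{n \to m}(x_n, u_{n,m})}{\partial (x_n, u_{n,m})} \tag{12.137}$$

and |det J_f| is the absolute value of the determinant of the Jacobian.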
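The acceptance probability in Equation (12.136) is easiest to compute in log space, which avoids overflow when the densities are very small. This helper simply adds up the log terms; the function and argument names are our own, and the numbers in the usage example are made up.

```python
import math


def rj_log_acceptance(log_target_new, log_target_old,
                      log_q_model_reverse, log_q_model_forward,
                      log_q_u_reverse, log_q_u_forward, log_abs_det_jac):
    """Log of the RJMCMC acceptance probability in Equation (12.136).

    log_target_new = log p(m, x*_m),        log_target_old = log p(n, x_n)
    log_q_model_reverse = log q(n | m),     log_q_model_forward = log q(m | n)
    log_q_u_reverse = log q_{m->n}(u_{m,n} | m, x*_m)
    log_q_u_forward = log q_{n->m}(u_{n,m} | n, x_n)
    log_abs_det_jac = log |det J_f|
    """
    log_ratio = (log_target_new + log_q_model_reverse + log_q_u_reverse
                 - log_target_old - log_q_model_forward - log_q_u_forward
                 + log_abs_det_jac)
    return min(0.0, log_ratio)


# Hypothetical 1d -> 2d move that appends u ~ U(0, 1) via the identity map:
# |det J_f| = 1, log U(u | 0, 1) = 0, and the reverse move needs no auxiliary variable.
log_a = rj_log_acceptance(
    log_target_new=-2.3, log_target_old=-1.9,
    log_q_model_reverse=math.log(0.5), log_q_model_forward=math.log(0.5),
    log_q_u_reverse=0.0, log_q_u_forward=0.0,
    log_abs_det_jac=0.0,
)
accept_prob = math.exp(log_a)
```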


12.8.2 Example 


Let us consider an example from [AFD01]. They consider an RBF network for nonlinear regression of the form

$$y = \sum_{j=1}^{k} a_j K(\|x - \mu_j\|) + \beta^\top x + \beta_0 + \epsilon \tag{12.138}$$

where K() is some kernel function (e.g., a Gaussian), k is the number of such basis functions, and ε is a Gaussian noise term. If k = 0, the model corresponds to linear regression.
Figure 12.20: Fitting an RBF network to some 1d data using RJMCMC. (a) Prediction on train set. (b) Prediction on test set. (c) Plot of p(k|D) vs iteration. (d) Final posterior p(k|D). Adapted from Figure 4 of [AFD01]. Generated by rjmcmc_rbf, written by Nando de Freitas.


They fit this model to the data in Figure 12.20(a). The predictions on the test set are shown in Figure 12.20(b). Estimates of p(k|D), the (distribution over the) number of basis functions, are shown in Figure 12.20(c) as a function of the iteration number; the posterior at the final iteration is shown in Figure 12.20(d). The posterior clearly puts the most mass on k = 2, which makes sense given the two "bumps" in the data.


To generate these results, they consider several kinds of proposal. One of them is to split a current basis function μ into two new ones using

$$\mu_1 = \mu - u_{n,n+1}\, \beta, \qquad \mu_2 = \mu + u_{n,n+1}\, \beta \tag{12.139}$$

where β is a parameter of the proposal, and u_{n,m} is sampled from some distribution (e.g., uniform). To ensure reversibility, they define a corresponding merge move

$$\mu = \frac{\mu_1 + \mu_2}{2} \tag{12.140}$$
where μ_1 is chosen at random, and μ_2 is its nearest neighbor. To ensure these moves are reversible, we require ||μ_1 - μ_2|| < 2β.
The acceptance ratio for the split move is given by

$$A_{split} = \min\left(1, \frac{p(k+1, \mu_{k+1})}{p(k, \mu_k)} \times \frac{1/(k+1)}{1/k} \times \frac{1}{p(u_{n,m})} \times |\det J_{split}|\right) \tag{12.141}$$

where 1/k is the probability of choosing one of the k bases uniformly at random. The Jacobian is

$$J_{split} = \frac{\partial(\mu_1, \mu_2)}{\partial(\mu, u_{n,m})} = \begin{pmatrix} 1 & -\beta \\ 1 & \beta \end{pmatrix} \tag{12.142}$$

so |det J_split| = 2β. The acceptance ratio for the merge move is given by

$$A_{merge} = \min\left(1, \frac{p(k-1, \mu_{k-1})}{p(k, \mu_k)} \times \frac{1/(k-1)}{1/k} \times |\det J_{merge}|\right) \tag{12.143}$$

where |det J_merge| = 1/(2β).
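The split and merge moves, and their acceptance probabilities, can be sketched for scalar centers as follows. We assume u_{n,m} ~ U(0, 1), so p(u) = 1. The `log_p_u` argument of the merge helper is our addition: it carries the density of the implied u for the reverse move and is zero for uniform u, in which case the helper agrees with Equation (12.143). All names are our own.

```python
import math


def split(mu, u, beta):
    """Split move, Equation (12.139): one center becomes two."""
    return mu - u * beta, mu + u * beta


def merge(mu1, mu2, beta):
    """Inverse of split: the merged center, Equation (12.140), and the implied u."""
    return 0.5 * (mu1 + mu2), (mu2 - mu1) / (2.0 * beta)


def split_jacobian_det(beta):
    """det of d(mu1, mu2) / d(mu, u) = det [[1, -beta], [1, beta]] = 2 * beta."""
    return 1.0 * beta - (-beta) * 1.0


def split_acceptance(log_post_new, log_post_old, k, log_p_u, beta):
    """A_split, Equation (12.141); log_post_* are the log p(k, mu) values."""
    log_ratio = (log_post_new - log_post_old
                 + math.log(k / (k + 1))   # (1/(k+1)) / (1/k)
                 - log_p_u
                 + math.log(2.0 * beta))   # |det J_split|
    return math.exp(min(0.0, log_ratio))


def merge_acceptance(log_post_new, log_post_old, k, beta, log_p_u=0.0):
    """A_merge, Equation (12.143), for k >= 2; |det J_merge| = 1 / (2 beta)."""
    log_ratio = (log_post_new - log_post_old
                 + math.log(k / (k - 1))   # (1/(k-1)) / (1/k)
                 + log_p_u                 # zero when u is uniform on (0, 1)
                 - math.log(2.0 * beta))
    return math.exp(min(0.0, log_ratio))


mu1, mu2 = split(0.3, 0.4, 0.5)   # hypothetical center, draw of u, and beta
```

Since u ∈ (0, 1), the two new centers are always within 2β of each other, which is the reversibility condition stated above.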


Algorithm 31: Generic reversible jump MCMC (single step)
Sample u ~ U(0, 1)
if u < b_k then birth move
else if u < (b_k + d_k) then death move
else if u < (b_k + d_k + s_k) then split move
else if u < (b_k + d_k + s_k + m_k) then merge move
else update parameters


The overall pseudo-code for the algorithm, assuming the current model has index k, is given in Algorithm 31. Here b_k is the probability of a birth move, d_k is the probability of a death move, s_k is the probability of a split move, and m_k is the probability of a merge move. If we don't make a dimension-changing move, we just update the parameters of the current model using random walk MH.
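The move-selection logic of Algorithm 31 amounts to inverting a cumulative distribution over move types. The probabilities used below (0.1 each) are hypothetical; in practice b_k, d_k, s_k and m_k depend on k (for example, a death move is impossible when k = 0, and a merge move needs at least two bases). The function name is our own.

```python
import random


def choose_move(u, b_k, d_k, s_k, m_k):
    """Map u ~ U(0, 1) to a move type, following Algorithm 31."""
    if u < b_k:
        return "birth"
    if u < b_k + d_k:
        return "death"
    if u < b_k + d_k + s_k:
        return "split"
    if u < b_k + d_k + s_k + m_k:
        return "merge"
    return "update"  # fixed-dimension random-walk MH update of the current model


rng = random.Random(0)
counts = {}
for _ in range(10000):
    move = choose_move(rng.random(), 0.1, 0.1, 0.1, 0.1)
    counts[move] = counts.get(move, 0) + 1
```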


12.8.3 Discussion 


RJMCMC algorithms can be quite tricky to implement. If, however, the continuous parameters can 
be integrated out (resulting in a method called collapsed RJMCMC), much of the difficulty goes 
away, since we are just left with a discrete state space, where there is no need to worry about change 
of measure. For example, if we fix the centers μ_j in Equation (12.138) (e.g., using samples from the
data, or using K-means clustering), we are left with a linear model, where we can integrate out the 
parameters. All that is left to do is sample which of these fixed basis functions to include in the 
model, which is a discrete variable selection problem. See e.g., [Den+02] for details. 

In Chapter 31, we discuss Bayesian nonparametric models, which allow for an infinite number of 
different models. Surprisingly, such models are often easier to deal with computationally (as well as 
more realistic, statistically) than working with a finite set of different models. 
Figure 12.21: (a) A peaky distribution. (b) Corresponding energy function. Generated by simulated_annealing_2d_demo.ipynb.


12.9 Annealing methods 


Many distributions are multimodal and hence hard to sample from. However, by analogy to the way 
metals are heated up and then cooled down in order to make the molecules align, we can imagine 
using a computational temperature parameter to “smooth out” a distribution, gradually cooling it to 
recover the original “bumpy” distribution. We first explain this idea in more detail in the context of 
an algorithm for MAP estimation. We then discuss extensions to the sampling case. 


12.9.1 Simulated annealing 


In this section, we discuss the simulated annealing algorithm [KJV83; LA87], which is a variant of the Metropolis Hastings algorithm designed to find the global optimum of a blackbox function. (Other approaches to blackbox optimization are discussed in Section 6.9.)
Annealing is a physical process of heating a solid until thermal stresses are released, then cooling it very slowly until the crystals are perfectly arranged, achieving a minimum energy state. Depending on how fast or slow the temperature is cooled, the results will have worse or better quality. We can apply this approach to probability distributions, to control the number of modes (low energy states) that they have, by defining


$$p_T(x) \propto \exp(-\mathcal{E}(x)/T) \tag{12.144}$$


where T is the temperature, which is reduced over time. As an example, consider the peaks function:

$$p(x, y) \propto \left| 3(1-x)^2 e^{-x^2-(y+1)^2} - 10\left(\frac{x}{5} - x^3 - y^5\right) e^{-x^2-y^2} - \frac{1}{3} e^{-(x+1)^2-y^2} \right| \tag{12.145}$$


This is plotted in Figure 12.21a. The corresponding energy is in Figure 12.21b. We plot annealed versions of this distribution in Figure 12.22. At high temperatures, T ≫ 1, the surface is approximately flat, and hence it is easy to move around (i.e., to avoid local optima). As the temperature cools, the largest peaks become larger, and the smallest peaks disappear. By cooling slowly enough, it is possible to "track" the largest peak, and thus find the global optimum (minimum energy state). This is an example of a continuation method.

[Figure 12.22 image; panel titles include T = 10.00 and T = 2.00.]

Figure 12.22: Annealed version of the distribution in Figure 12.21a at different temperatures. Generated by simulated_annealing_2d_demo.ipynb.

In more detail, at each step, we sample a new state according to some proposal distribution x' ~ q(·|x_t). For real-valued parameters, this is often simply a random walk proposal centered on the current iterate, x' = x_t + ε_{t+1}, where ε_{t+1} ~ N(0, Σ). (The matrix Σ is often diagonal, and may be updated over time using the method in [Cor+87].) Having proposed a new state, we compute the acceptance probability

$$\alpha_{t+1} = \exp\left(-(\mathcal{E}(x') - \mathcal{E}(x_t))/T_t\right) \tag{12.146}$$

where T_t is the temperature of the system. We then accept the new state (i.e., set x_{t+1} = x') with probability min(1, α_{t+1}), otherwise we stay in the current state (i.e., set x_{t+1} = x_t). This means that if the new state has lower energy (is more probable), we will definitely accept it, but if it has higher energy (is less probable), we might still accept, depending on the current temperature. Thus the algorithm allows "downhill" moves in probability space (uphill in energy space), but less frequently as the temperature drops.

Figure 12.23: Simulated annealing applied to the distribution in Figure 12.21a. (a) Temperature vs iteration and probability of each visited point vs iteration. (b) Visited samples, superimposed on the target distribution. The big red dot is the highest probability point found. Generated by simulated_annealing_2d_demo.ipynb.

The rate at which the temperature changes over time is called the cooling schedule. It has been shown [Haj88] that if one cools according to a logarithmic schedule, T_t ∝ 1/log(t + 1), then the method is guaranteed to find the global optimum under certain assumptions. However, this schedule is often too slow. In practice it is common to use an exponential cooling schedule of the form T_{t+1} = γT_t, where γ ∈ (0, 1] is the cooling rate. Cooling too quickly means one can get stuck in a local maximum, but cooling too slowly just wastes time. The best cooling schedule is difficult to determine; this is one of the main drawbacks of simulated annealing.

In Figure 12.23a, we show a cooling schedule using γ = 0.9. If we combine this with a Gaussian random walk proposal with σ = 10 and apply it to the peaky distribution in Figure 12.21a, we get the results shown in Figure 12.23a and Figure 12.23b. We see that the algorithm concentrates its samples near the global optimum (the peak on the middle right).
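A plain-Python sketch of simulated annealing with a Gaussian random walk proposal and the exponential schedule T_{t+1} = γT_t. The energy is E(x) = -log|peaks(x)|, matching Equation (12.145). The step size, schedule and seed are illustrative assumptions, not the settings used for Figure 12.23, and the function names are our own.

```python
import math
import random


def peaks(x, y):
    """The peaks function inside the absolute value of Equation (12.145)."""
    return (3.0 * (1.0 - x) ** 2 * math.exp(-x ** 2 - (y + 1.0) ** 2)
            - 10.0 * (x / 5.0 - x ** 3 - y ** 5) * math.exp(-x ** 2 - y ** 2)
            - math.exp(-(x + 1.0) ** 2 - y ** 2) / 3.0)


def peaks_energy(z):
    """E(z) = -log |peaks(z)|, so that p(z) is proportional to exp(-E(z))."""
    val = abs(peaks(z[0], z[1]))
    return -math.log(val) if val > 0.0 else math.inf


def simulated_annealing(energy, x0, n_steps, sigma=1.0, t0=10.0, cooling=0.9, seed=0):
    """Metropolis moves at temperature T_t, with T_{t+1} = cooling * T_t.

    Returns the lowest-energy state visited, its energy, and the temperature schedule.
    """
    rng = random.Random(seed)
    x = list(x0)
    e = energy(x)
    best_x, best_e = list(x), e
    temp = t0
    temps = []
    for _ in range(n_steps):
        prop = [xi + sigma * rng.gauss(0.0, 1.0) for xi in x]
        e_prop = energy(prop)
        delta = e_prop - e
        # Accept with probability min(1, exp(-delta / T)), Equation (12.146).
        if delta <= 0.0 or (temp > 0.0 and rng.random() < math.exp(-delta / temp)):
            x, e = prop, e_prop
            if e < best_e:
                best_x, best_e = list(x), e
        temps.append(temp)
        temp *= cooling
    return best_x, best_e, temps


peak_x, peak_e, peak_temps = simulated_annealing(
    peaks_energy, [0.0, 0.0], 2000, sigma=0.5, t0=10.0, cooling=0.995, seed=0)
```

Tracking the best state separately means the returned point never gets worse than the starting point, even if the chain later accepts an uphill move.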


12.9.2 Parallel tempering 


Another way to combine MCMC and annealing is to run multiple chains in parallel at different temperatures, and allow one chain to sample from another chain at a neighboring temperature. In this way, the high temperature chain can make long distance moves through the state space, and have this influence lower temperature chains. This is known as parallel tempering. See e.g., [ED05; Kat+06] for details.
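Since the text defers the details to [ED05; Kat+06], the sketch below uses the standard replica-exchange swap rule as an assumption: chain i targets exp(-β_i E(x)), and swapping the states of chains i and j is accepted with probability min(1, exp((β_i - β_j)(E(x_i) - E(x_j)))). The bimodal target, inverse temperatures and step sizes are hypothetical. A single random-walk chain at β = 1 would rarely cross the barrier between the two modes; with swaps, the cold chain visits both.

```python
import math
import random


def log_normal(x, mean, var):
    return -0.5 * (math.log(2.0 * math.pi * var) + (x - mean) ** 2 / var)


def energy(x):
    """E(x) = -log p(x) for an equal mixture of N(-4, 0.25) and N(4, 0.25)."""
    a = math.log(0.5) + log_normal(x, -4.0, 0.25)
    b = math.log(0.5) + log_normal(x, 4.0, 0.25)
    hi = max(a, b)
    return -(hi + math.log(math.exp(a - hi) + math.exp(b - hi)))  # log-sum-exp


def parallel_tempering(energy, betas, n_steps, step=0.5, seed=0):
    """Chain i targets exp(-betas[i] * E(x)); betas[0] = 1 is the chain we keep."""
    rng = random.Random(seed)
    xs = [0.0 for _ in betas]
    cold = []
    for _ in range(n_steps):
        # Within-chain random-walk Metropolis; hotter chains take larger steps.
        for i, beta in enumerate(betas):
            prop = xs[i] + step / math.sqrt(beta) * rng.gauss(0.0, 1.0)
            log_a = -beta * (energy(prop) - energy(xs[i]))
            if rng.random() < math.exp(min(0.0, log_a)):
                xs[i] = prop
        # Propose swapping the states of a random pair of neighboring temperatures.
        i = rng.randrange(len(betas) - 1)
        log_a = (betas[i] - betas[i + 1]) * (energy(xs[i]) - energy(xs[i + 1]))
        if rng.random() < math.exp(min(0.0, log_a)):
            xs[i], xs[i + 1] = xs[i + 1], xs[i]
        cold.append(xs[0])
    return cold


cold = parallel_tempering(energy, betas=[1.0, 0.3, 0.1, 0.03], n_steps=20000)
frac_right = sum(x > 0 for x in cold) / len(cold)
```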
13 Sequential Monte Carlo inference


13.1 Introduction 


In this chapter, we discuss sequential Monte Carlo or SMC algorithms, which can be used to 
sample from a sequence of related probability distributions. SMC is most commonly used to solve 
filtering in state-space models (SSM, Chapter 29), but it can also be applied to other problems, such 
as sampling from a static (but possibly multi-modal) distribution, or for sampling rare events from 
some process. 

Our presentation is based on the excellent tutorial [NLS19], and differs from traditional presenta- 
tions, such as [Aru+02], by emphasizing the fact that we are sampling sequences of related variables, 
not just computing the filtering distribution of an SSM. This more general perspective will let us 
tackle static estimation problems, as we will see. For another good introduction to SMC, see [DJ11]. 
For a more formal (measure theoretic) treatment of SMC, using the Feynman-Kac formalism, see 
[CP20b]. 


13.1.1 Problem statement 


In SMC, the goal is to sample from a sequence of related distributions of the form 


$$\pi_t(z_{1:t}) = \frac{\gamma_t(z_{1:t})}{Z_t} \tag{13.1}$$

for t = 1 : T, where γ_t is the unnormalized target distribution, π_t is the normalized version, and z_{1:t} are the random variables of interest. In some applications (e.g., filtering in an SSM), we care about each intermediate marginal distribution, π_t(z_t), for t = 1 : T; this is called particle filtering. (The word "particle" just means "sample".) In other applications, we only care about the final distribution, π_T(z_T), and the intermediate steps are introduced just for computational reasons;
this is called an SMC sampler. We briefly review both of these below, and go into more detail in 
later sections. 


13.1.2 Particle filtering for state-space models 


An important application of SMC is to sequential (online) inference (state estimation) in SSMs. As 
an example, consider a Markovian state-space model with the following joint distribution: 


$$\pi_T(z_{1:T}) \propto p(z_{1:T}, y_{1:T}) = p(z_1)\, p(y_1|z_1) \prod_{t=2}^{T} p(z_t|z_{t-1})\, p(y_t|z_t) \tag{13.2}$$
[Figure 13.1 image; panel titles: "Noisy observations from hidden trajectory", with legend entries "state space" and "filtered".]

Figure 13.1: Illustration of particle filtering (using the dynamical prior as the proposal) applied to a 2d nonlinear dynamical system. (a) True underlying state and observed data. (b) PF estimate of the posterior mean. Generated by bootstrap_filter_spiral.ipynb.


A common choice is to define the unnormalized target distribution at step t to be

$$\gamma_t(z_{1:t}) = p(z_{1:t}, y_{1:t}) = p(z_1)\, p(y_1|z_1) \prod_{s=2}^{t} p(z_s|z_{s-1})\, p(y_s|z_s) \tag{13.3}$$

Note that this is a distribution over an (ever growing) sequence of latent variables. However, we often only care about the most recent marginal of this distribution, in which case we just need to compute γ_t(z_t), which avoids having to store the full history.


For example, consider the following 2d nonlinear tracking problem (the same one as in Section 8.5.4.1):

$$\begin{aligned} p(z_t|z_{t-1}) &= \mathcal{N}(z_t | f(z_{t-1}), qI) \\ p(y_t|z_t) &= \mathcal{N}(y_t | z_t, rI) \\ f(z) &= (z_1 + \Delta \sin(z_2),\; z_2 + \Delta \cos(z_1)) \end{aligned} \tag{13.4}$$

where Δ is the step size of the underlying continuous system, q is the variance of the system noise, and r is the variance of the observation noise. (We treat Δ, q and r as fixed constants; see Supplementary Section 13.1.3 for a discussion of joint state and parameter estimation.) The true underlying state trajectory, and the corresponding noisy measurements, are shown in Figure 13.1a. The posterior mean estimate of the state, computed using 2000 samples in a simple form of SMC called the bootstrap filter (Section 13.2.3.1), is shown in Figure 13.1b.
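The generative model in Equation (13.4) can be simulated as follows. The values of Δ, q, r and the initial state are hypothetical, since they are not given here; the function names are our own.

```python
import math
import random


def f(z, delta):
    """Deterministic part of the dynamics in Equation (13.4)."""
    z1, z2 = z
    return (z1 + delta * math.sin(z2), z2 + delta * math.cos(z1))


def simulate_spiral(z1, n_steps, delta=0.4, q=0.1, r=0.5, seed=0):
    """Sample states z_{1:T} and observations y_{1:T} from the model of Equation (13.4)."""
    rng = random.Random(seed)
    zs, ys = [], []
    z = tuple(z1)
    for t in range(n_steps):
        if t > 0:
            mean = f(z, delta)
            # z_t ~ N(f(z_{t-1}), q I)
            z = tuple(m + math.sqrt(q) * rng.gauss(0.0, 1.0) for m in mean)
        # y_t ~ N(z_t, r I)
        y = tuple(zi + math.sqrt(r) * rng.gauss(0.0, 1.0) for zi in z)
        zs.append(z)
        ys.append(y)
    return zs, ys


noisy_zs, noisy_ys = simulate_spiral((1.5, 0.0), 50)
```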


Particle filtering can also be applied to non-Markovian models, where z_t may depend on all the past hidden states, z_{1:t-1}, and y_t depends on the current z_t and possibly also all the past hidden states, z_{1:t-1}, and optionally the past observations, y_{1:t-1}. In this case, the unnormalized target distribution at step t is

$$\gamma_t(z_{1:t}) = p(z_1)\, p(y_1|z_1) \prod_{s=2}^{t} p(z_s|z_{1:s-1})\, p(y_s|z_{1:s}) \tag{13.5}$$


For example, consider a 1d Gaussian sequence model where the dynamics are first-order Markov, but 
the observations depend on the entire past sequence (this is example 1.2.1 from [NLS19]): 


$$\begin{aligned} p(z_t|z_{1:t-1}) &= \mathcal{N}(z_t|\phi z_{t-1}, q) \\ p(y_t|z_{1:t}) &= \mathcal{N}\left(y_t \,\middle|\, \sum_{s=1}^{t} \beta^{t-s} z_s, r\right) \end{aligned} \tag{13.6}$$

If we set β = 0, we get p(y_t|z_{1:t}) = N(y_t|z_t, r) (where we define 0^0 = 1), so the model becomes a linear-Gaussian SSM. As β gets larger, the dependence on the past increases, making the inference problem harder. (We will revisit this example below.)
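A sketch for simulating from Equation (13.6). The initial distribution z_1 ~ N(0, q) is our assumption (it is not given above), and the default parameters are the ones listed in the caption of Figure 13.2. The observation mean satisfies m_t = βm_{t-1} + z_t, so the simulator does constant work per step; `observation_mean` computes the same sum directly from its definition. Function names are our own.

```python
import math
import random


def observation_mean(zs, t, beta):
    """sum_{s=1}^{t} beta^(t-s) z_s for 1-indexed t; in Python, 0.0 ** 0 == 1.0."""
    return sum(beta ** (t - s) * zs[s - 1] for s in range(1, t + 1))


def simulate_nonmarkov(n_steps, phi=0.9, q=10.0, beta=0.5, r=1.0, seed=0):
    """Sample z_{1:T} and y_{1:T} from Equation (13.6).

    Assumption: z_1 ~ N(0, q), obtained by starting the AR(1) recursion from z_0 = 0.
    """
    rng = random.Random(seed)
    zs, ys = [], []
    z_prev, m = 0.0, 0.0
    for _ in range(n_steps):
        z = phi * z_prev + math.sqrt(q) * rng.gauss(0.0, 1.0)
        m = beta * m + z  # running value of sum_s beta^(t-s) z_s
        zs.append(z)
        ys.append(m + math.sqrt(r) * rng.gauss(0.0, 1.0))
        z_prev = z
    return zs, ys
```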


13.1.3 SMC samplers for static parameter estimation 


Now consider the problem of parameter estimation from a fixed dataset, D = {y_n : n = 1 : N_D}. We suppose the observations are conditionally iid, so the posterior has the form p(z|D) ∝ p(z) ∏_{n=1}^{N_D} p(y_n|z), where z is the unknown parameter. It is not immediately obvious how to approximate p(z|D) using SMC, since we just have one distribution. However, we can convert this into a sequential inference problem in several different ways. One approach, known as data tempering, defines the (marginal) target distribution at step t as γ_t(z_t) = p(z_t) p(y_{1:t}|z_t). In this case, the number of time steps T is the same as the number of data samples, N_D. Another approach, known as likelihood tempering, defines the (marginal) target distribution at step t as γ_t(z_t) = p(z_t) p(D|z_t)^{τ_t}, where 0 = τ_1 < ··· < τ_T = 1 is a temperature parameter. In this case, the number of steps T depends on how quickly we anneal the distribution from the initial prior p(z_1) to the final target p(z_T) p(D|z_T).

Once we have defined the marginal target distributions γ_t(z_t), we need a way to expand this to a joint target distribution over a sequence of variables, γ_t(z_{1:t}), so the distributions become connected to each other. We explain how to do this in Section 13.6. We can then treat the model as an SSM and apply particle filtering. At the end, we extract the final joint target distribution, γ_T(z_{1:T}) = p(z_{1:T}) p(D|z_T), from which we can compute the marginal target distribution γ_T(z_T) = p(z_T, D), from which we can get the posterior p(z|D) by normalizing. We give the details in Section 13.6.
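To make the two schedules concrete, the sketch below assumes a hypothetical conjugate model, z ~ N(0, 1) and y_n ~ N(z, 1), in which every intermediate target is Gaussian and its mean and variance can be written in closed form. Both schemes start at the prior and end at the same posterior. Function names are our own.

```python
def data_tempered_posterior(ys, t, prior_var=1.0):
    """Mean and variance of gamma_t(z) = p(z) p(y_{1:t} | z), after normalizing."""
    prec = 1.0 / prior_var + t
    return sum(ys[:t]) / prec, 1.0 / prec


def likelihood_tempered_posterior(ys, tau, prior_var=1.0):
    """Mean and variance of gamma(z) = p(z) p(D | z)^tau, after normalizing."""
    prec = 1.0 / prior_var + tau * len(ys)
    return tau * sum(ys) / prec, 1.0 / prec


ys = [1.0, 2.0, 0.5, 1.5]                    # hypothetical data
taus = [i / 10 for i in range(11)]           # a linear schedule from 0 to 1
variances = [likelihood_tempered_posterior(ys, tau)[1] for tau in taus]
```

Raising the likelihood to the power τ multiplies its precision contribution by τ, so the intermediate targets shrink smoothly from the prior toward the posterior; data tempering instead adds one observation's worth of precision per step.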


13.2 Particle filtering 


In this section, we cover the basics of SMC for state space models, culminating in a method known as the particle filter.


13.2.1 Importance sampling 


We start by reviewing the self-normalized importance sampling method (SNIS, Section 11.5), which 
is the foundation of the particle filter. 
Suppose we are interested in estimating the expectation of some function φ_t with respect to a target distribution π_t, which we denote by

$$\pi_t(\varphi_t) \triangleq \mathbb{E}_{\pi_t}[\varphi_t(z_{1:t})] = \int \frac{\gamma_t(z_{1:t})}{Z_t}\, \varphi_t(z_{1:t})\, dz_{1:t} \tag{13.7}$$

where Z_t = ∫ γ_t(z_{1:t}) dz_{1:t}. Suppose we use SNIS with proposal q_t(z_{1:t}). We then get the following approximation:

$$\pi_t(\varphi_t) \approx \frac{1}{\hat{Z}_t} \frac{1}{N_s} \sum_{i=1}^{N_s} w_t^i\, \varphi_t(z_{1:t}^i) \tag{13.8}$$

where z^i_{1:t} ~ q_t are independent samples from the proposal, and w^i_t are the unnormalized weights defined by

$$w_t^i = \frac{\gamma_t(z_{1:t}^i)}{q_t(z_{1:t}^i)} \tag{13.9}$$

and Ẑ_t is the approximate normalization constant defined by

$$\hat{Z}_t \triangleq \frac{1}{N_s} \sum_{i=1}^{N_s} w_t^i \tag{13.10}$$

To simplify notation, let us define the normalized weights by

$$W_t^i = \frac{w_t^i}{\sum_j w_t^j} \tag{13.11}$$

Then we can write

$$\pi_t(\varphi_t) \approx \sum_{i=1}^{N_s} W_t^i\, \varphi_t(z_{1:t}^i) \tag{13.12}$$


Alternatively, instead of computing the expectation of a specific target function, we can just approximate the target distribution itself, using a sum of weighted samples:

$$\pi_t(z_{1:t}) \approx \sum_{i=1}^{N_s} W_t^i\, \delta(z_{1:t} - z_{1:t}^i) \triangleq \hat{\pi}_t(z_{1:t}) \tag{13.13}$$
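A minimal sketch of SNIS in log space, with a hypothetical unnormalized target γ(z) = exp(-(z - 1)²/2), whose normalizer is Z = √(2π), and a N(0, 4) proposal. It returns the estimate of Equation (13.12), the normalized weights of Equation (13.11), and Ẑ from Equation (13.10). Working with log weights and the log-sum-exp trick avoids underflow when γ is tiny. Function names are our own.

```python
import math
import random


def log_sum_exp(vals):
    hi = max(vals)
    return hi + math.log(sum(math.exp(v - hi) for v in vals))


def log_normal(x, mean, var):
    return -0.5 * (math.log(2.0 * math.pi * var) + (x - mean) ** 2 / var)


def snis(log_gamma, sample_q, log_q, phi, n_samples, seed=0):
    """Self-normalized importance sampling, Equations (13.8)-(13.12), in log space."""
    rng = random.Random(seed)
    zs = [sample_q(rng) for _ in range(n_samples)]
    log_w = [log_gamma(z) - log_q(z) for z in zs]            # log w^i, Equation (13.9)
    lse = log_sum_exp(log_w)
    weights = [math.exp(lw - lse) for lw in log_w]           # W^i, Equation (13.11)
    z_hat = math.exp(lse - math.log(n_samples))              # Z_hat, Equation (13.10)
    estimate = sum(w * phi(z) for w, z in zip(weights, zs))  # Equation (13.12)
    return estimate, weights, z_hat


est, W, z_hat = snis(
    log_gamma=lambda z: -0.5 * (z - 1.0) ** 2,   # unnormalized N(1, 1)
    sample_q=lambda rng: rng.gauss(0.0, 2.0),    # proposal N(0, 4): standard deviation 2
    log_q=lambda z: log_normal(z, 0.0, 4.0),
    phi=lambda z: z,                             # estimate the target mean
    n_samples=20000,
)
```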


The problem with importance sampling when applied in the context of sequential models is that the dimensionality of the state space is very large, and increases with t. This makes it very hard to define a good proposal that covers the high probability regions, resulting in most samples getting negligible weight. In the sections below, we discuss solutions to this problem.
13.2.2 Sequential importance sampling


In this section, we discuss sequential importance sampling or SIS, in which the proposal has the following autoregressive structure:

$$q_t(z_{1:t}) = q_{t-1}(z_{1:t-1})\, q_t(z_t|z_{1:t-1}) \tag{13.14}$$

We can obtain samples from q_{t-1}(z_{1:t-1}) by reusing the z^i_{1:t-1} samples, which we then extend by one step by sampling from the conditional q_t(z_t|z^i_{1:t-1}). We can think of this as "growing" the chain (sequence of states). The unnormalized weights can be computed recursively as follows:

$$w_t(z_{1:t}) = \frac{\gamma_t(z_{1:t})}{q_t(z_{1:t})} = \frac{\gamma_{t-1}(z_{1:t-1})}{\gamma_{t-1}(z_{1:t-1})} \frac{\gamma_t(z_{1:t})}{q_t(z_t|z_{1:t-1})\, q_{t-1}(z_{1:t-1})} \tag{13.15}$$
$$= \frac{\gamma_{t-1}(z_{1:t-1})}{q_{t-1}(z_{1:t-1})} \frac{\gamma_t(z_{1:t})}{\gamma_{t-1}(z_{1:t-1})\, q_t(z_t|z_{1:t-1})} \tag{13.16}$$
$$= w_{t-1}(z_{1:t-1})\, \frac{\gamma_t(z_{1:t})}{\gamma_{t-1}(z_{1:t-1})\, q_t(z_t|z_{1:t-1})} \tag{13.17}$$

The ratio factors are sometimes called the incremental importance weights:

$$\alpha_t(z_{1:t}) = \frac{\gamma_t(z_{1:t})}{\gamma_{t-1}(z_{1:t-1})\, q_t(z_t|z_{1:t-1})} \tag{13.18}$$


See Algorithm 32 for pseudocode for the resulting SIS algorithm. (In practice we compute the weights 
in log-space, and convert back using the log-sum-exp trick.) 

Note that, in the special case of state space models, the weight computation can be further 
simplified. In particular, suppose we have 


$$\gamma_t(z_{1:t}) = p(z_{1:t}, y_{1:t}) = p(y_t|z_{1:t})\, p(z_t|z_{1:t-1})\, p(z_{1:t-1}, y_{1:t-1}) \tag{13.19}$$
$$= p(y_t|z_{1:t})\, p(z_t|z_{1:t-1})\, \gamma_{t-1}(z_{1:t-1}) \tag{13.20}$$

Then the incremental weight is given by

$$\alpha_t = \frac{p(y_t|z_{1:t})\, p(z_t|z_{1:t-1})\, \gamma_{t-1}(z_{1:t-1})}{\gamma_{t-1}(z_{1:t-1})\, q_t(z_t|z_{1:t-1})} = \frac{p(y_t|z_{1:t})\, p(z_t|z_{1:t-1})}{q_t(z_t|z_{1:t-1})} \tag{13.21}$$


Unfortunately SIS suffers from a problem known as weight degeneracy or particle impover- 
ishment, in which most of the weights become very small (near zero), so the posterior ends up being 
approximated by a single particle. This is illustrated in Figure 13.2a, where we apply SIS to the 
non-Markovian example in Equation (13.6) using N_s = 5 particles. The reason for degeneracy is that each particle has to "explain" (generate) the entire sequence of observations. Each sequence of guessed states becomes increasingly improbable over time, due to the product of likelihood terms, and the differences between the weights of each hypothesis will grow exponentially. Of course, there
has to be a best sequence amongst the set of candidates, so when we normalize the weights, the best 
one will get weight 1 and the rest will get weight 0. But this is a waste of most of the particles. We 
discuss a solution to this in Section 13.2.3. 
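The sketch below runs Algorithm 32 on the model of Equation (13.6), assuming the dynamical prior as the proposal (so by Equation (13.21) the incremental weight is α_t = p(y_t | z_{1:t})) and z_1 ~ N(0, q). The data are simulated with a hypothetical seed. Inspecting `max(w)` at each step shows how the largest normalized weight evolves; we do not claim this reproduces Figure 13.2a. Function names are our own.

```python
import math
import random


def log_normal(x, mean, var):
    return -0.5 * (math.log(2.0 * math.pi * var) + (x - mean) ** 2 / var)


def normalize_log_weights(log_w):
    hi = max(log_w)
    w = [math.exp(lw - hi) for lw in log_w]
    total = sum(w)
    return [wi / total for wi in w]


def sis(ys, n_particles, phi=0.9, q=10.0, beta=0.5, r=1.0, seed=0):
    """SIS (Algorithm 32) for Equation (13.6), proposing from the dynamics.

    Returns the normalized weights W_t for every step t.
    """
    rng = random.Random(seed)
    z = [0.0] * n_particles        # current state z_t^i (z_0 = 0 gives z_1 ~ N(0, q))
    m = [0.0] * n_particles        # running observation mean sum_s beta^(t-s) z_s^i
    log_w = [0.0] * n_particles    # log w_t^i, never reset in SIS
    history = []
    for y in ys:
        for i in range(n_particles):
            z[i] = phi * z[i] + math.sqrt(q) * rng.gauss(0.0, 1.0)  # sample z_t^i
            m[i] = beta * m[i] + z[i]
            log_w[i] += log_normal(y, m[i], r)                      # w_t = w_{t-1} * alpha_t
        history.append(normalize_log_weights(log_w))
    return history


# Simulate T = 6 observations from the same model (hypothetical seed).
data_rng = random.Random(42)
true_z, true_m, ys = 0.0, 0.0, []
for _ in range(6):
    true_z = 0.9 * true_z + math.sqrt(10.0) * data_rng.gauss(0.0, 1.0)
    true_m = 0.5 * true_m + true_z
    ys.append(true_m + data_rng.gauss(0.0, 1.0))

weights = sis(ys, n_particles=5)
max_weight_per_step = [max(w) for w in weights]
```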
Algorithm 32: Sequential importance sampling (SIS)
Initialization: z^i_1 ~ q_1(z_1), w^i_1 = γ_1(z^i_1)/q_1(z^i_1), W^i_1 = w^i_1 / Σ_j w^j_1, π̂_1(z_1) = Σ_{i=1}^{N_s} W^i_1 δ(z_1 - z^i_1)
for t = 2 : T do
    for i = 1 : N_s do
        Sample z^i_t ~ q_t(z_t|z^i_{1:t-1})
        Compute incremental weight α^i_t = γ_t(z^i_{1:t}) / (γ_{t-1}(z^i_{1:t-1}) q_t(z^i_t|z^i_{1:t-1}))
        Compute unnormalized weight w^i_t = w^i_{t-1} α^i_t
    Compute normalized weights W^i_t = w^i_t / Σ_j w^j_t for i = 1 : N_s
    Compute MC posterior π̂_t(z_{1:t}) = Σ_{i=1}^{N_s} W^i_t δ(z_{1:t} - z^i_{1:t})


[Figure 13.2 image; panel (b) shows one cycle of the bootstrap filter: the prior p(z_{t-1}|y_{1:t-1}) represented by particles {z^i_{t-1}}, resampling, the proposal q(z_t|z_{t-1}), weighting by p(y_t|z_t), and the resulting posterior p(z_t|y_{1:t}).]


Figure 13.2: (a) Illustration of weight degeneracy for SIS applied to the model in Equation (13.6) with parameters (φ, q, β, r) = (0.9, 10.0, 0.5, 1.0). We use T = 6 steps and N_s = 5 samples. We see that as t increases, almost all the probability mass concentrates on particle 3. Generated by sis_vs_smc.ipynb. Adapted from Figure 2 of [NLS19]. (b) Illustration of the bootstrap particle filtering algorithm.


13.2.3 Sequential importance sampling with resampling


In this section, we describe sequential importance sampling with resampling (SISR). The basic idea is this: instead of "growing" all of the old particle sequences by one step, we first select the N_s "fittest" particles, by sampling from the old posterior, and then we let these survivors grow by one step.

In more detail, at step t, we sample from

$$q_t^{SISR}(z_{1:t}) = \hat{\pi}_{t-1}(z_{1:t-1})\, q_t(z_t|z_{1:t-1}) \tag{13.22}$$

where π̂_{t-1}(z_{1:t-1}) is the previous weighted posterior approximation. By contrast, in SIS, we sample from

$$q_t^{SIS}(z_{1:t}) = q_{t-1}(z_{1:t-1})\, q_t(z_t|z_{1:t-1}) \tag{13.23}$$




Algorithm 33: Sequential importance sampling with resampling (SISR)
Initialization: z^i_1 ~ q_1(z_1), w^i_1 = γ_1(z^i_1)/q_1(z^i_1), W^i_1 = w^i_1 / Σ_j w^j_1, π̂_1(z_1) = Σ_{i=1}^{N_s} W^i_1 δ(z_1 - z^i_1)
for t = 2 : T do
    Compute ancestors a^{1:N_s}_{t-1} = resample(w^{1:N_s}_{t-1})
    Select z^{1:N_s}_{1:t-1} = permute(a^{1:N_s}_{t-1}, z^{1:N_s}_{1:t-1})
    Reset unnormalized weights w^i_{t-1} = 1/N_s
    for i = 1 : N_s do
        Sample z^i_t ~ q_t(z_t|z^i_{1:t-1})
        Compute unnormalized weight w^i_t = α^i_t = γ_t(z^i_{1:t}) / (γ_{t-1}(z^i_{1:t-1}) q_t(z^i_t|z^i_{1:t-1}))
    Compute normalized weights W^i_t = w^i_t / Σ_j w^j_t for i = 1 : N_s
    Compute MC posterior π̂_t(z_{1:t}) = Σ_{i=1}^{N_s} W^i_t δ(z_{1:t} - z^i_{1:t})


We can sample from Equation (13.22) in two steps. First we resample $N_s$ samples from $\hat{\pi}_{t-1}(z_{1:t-1})$ to get a uniformly weighted set of new samples $z_{1:t-1}^i$. (See Section 13.2.4 for details on how to do this.) Then we extend each sample using $z_t^i \sim q_t(z_t | z_{1:t-1}^i)$, and concatenate $z_t^i$ to $z_{1:t-1}^i$.

After making a proposal, we compute the unnormalized weights. We use the standard SNIS method, except we “pretend” that the proposal is given by $\tilde{\gamma}_{t-1}(z_{1:t-1}^i)\, q(z_t^i | z_{1:t-1}^i)$ even though we used $\hat{\pi}_{t-1}(z_{1:t-1}^i)\, q(z_t^i | z_{1:t-1}^i)$. The intuitive reason why this is valid is that the previous weighted approximation, $\hat{\pi}_{t-1}(z_{1:t-1})$, was an unbiased estimate of the previous target distribution, $\pi_{t-1}(z_{1:t-1})$. (See e.g., [CP20b] for more theoretical details.) We then compute the unnormalized weights, which are the same as the incremental weights, since the resampling step resets $w_{t-1}^i$ to a constant, which cancels when we normalize. We then normalize these weights and compute the new approximation to the target posterior $\hat{\pi}_t(z_{1:t})$. See Algorithm 33 for the pseudocode.


13.2.3.1 Bootstrap filter 


We now consider a special case of SISR, in which the model is an SSM, and the proposal distribution 
is equal to the dynamical prior: 


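$$q_t(z_t | z_{1:t-1}) = p(z_t | z_{1:t-1}) \tag{13.24}$$

In this case, the corresponding incremental weight in Equation (13.21) simplifies to

$$\alpha_t^i = \frac{\tilde{\gamma}_t(z_{1:t}^i)}{\tilde{\gamma}_{t-1}(z_{1:t-1}^i)\, q_t(z_t^i | z_{1:t-1}^i)} = \frac{p(y_t | z_t^i)\, p(z_t^i | z_{1:t-1}^i)}{p(z_t^i | z_{1:t-1}^i)} = p(y_t | z_t^i) \tag{13.25}$$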


This special case is called the bootstrap filter [Gor93] or the survival of the fittest algorithm 
[KKR95]. (In the computer vision literature, this is called the condensation algorithm, which 
stands for “conditional density propagation” [IB98].) See Figure 13.2b for an illustration of how this 
algorithm works, and Figure 13.1b for some sample results on real data. 
Figure 13.3: (a) Illustration of diversity of samples in SMC applied to the model in Equation (13.6). (b)
Illustration of the path degeneracy problem. Generated by sis_vs_smc.ipynb. Adapted from Figure 3 of 
[NLS19]. 


The bootstrap filter is useful for models where we can sample from the dynamics, but cannot evaluate the transition model pointwise. This occurs in certain implicit dynamical models, such as those defined using differential equations (see e.g., [IBK06]); such models are often used in epidemiology. However, in general it is much more efficient to use proposals that take the current evidence $y_t$ into account. We discuss ways to approximate such “locally optimal” proposals in Section 13.3.


13.2.3.2 Path degeneracy problem 


In Figure 13.3a we show that particle filtering can result in a much more diverse set of active particles, with more balanced weights, when applied to the non-Markovian example in Equation (13.6). Although resampling avoids the weight degeneracy seen with SIS, particle filtering suffers from another problem known as path degeneracy. This refers to the fact that the number of particles that “survive” (have non-negligible weight) over many steps may drop rapidly over time, resulting in a loss of diversity when we try to represent the distribution over the past. We illustrate this in Figure 13.3b, where we only include arrows for samples that have been resampled at each step up until the final step. We see that we have $N_s = 5$ identical copies of $z_1^1$ in the final set of surviving sequences. (The time at which all the paths meet at a common ancestor, when tracing backwards in time, is known as the coalescence time.) We discuss some ways to ameliorate this issue in Section 13.2.4 and Section 13.2.5.
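Path degeneracy can be measured directly by tracing ancestor indices backwards from the final particles, as in Figure 13.3b. The helper below is a small sketch (its name and the array layout are assumptions) that counts how many distinct particles at each step still have descendants at the final step; the latest step at which this count equals 1, if any, marks the coalescence time.

```python
import numpy as np

def surviving_ancestors(ancestors):
    """Number of distinct step-s particles with a descendant at the final step.

    ancestors[t, i] is the 0-based index of the parent (at step t) of particle i
    at step t + 1, so the array has shape (T - 1, N) for T steps. Returns an
    array of length T, which is non-decreasing in s.
    """
    T = ancestors.shape[0] + 1
    lineage = np.arange(ancestors.shape[1])       # the final particles
    counts = np.zeros(T, dtype=int)
    counts[-1] = len(lineage)
    for t in range(T - 2, -1, -1):
        lineage = ancestors[t, lineage]           # step back one generation
        counts[t] = len(np.unique(lineage))
    return counts

# Equal weights, so multinomial resampling picks parents uniformly at random.
rng = np.random.default_rng(0)
print(surviving_ancestors(rng.integers(0, 5, size=(5, 5))))   # T = 6, N = 5
```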
13.2.3.3 Estimating the normalizing constant

We can use particle filtering to approximate the normalization constant $Z_T = p(y_{1:T}) = \prod_{t=1}^{T} p(y_t | y_{1:t-1})$ as follows:

$$\hat{Z}_T = \prod_{t=1}^{T} \widehat{Z_t/Z_{t-1}} \tag{13.26}$$

where, from Equation (13.10), we have

$$\widehat{Z_t/Z_{t-1}} = \frac{\sum_{i=1}^{N_s} w_t^i}{\sum_{i=1}^{N_s} w_{t-1}^i} \tag{13.27}$$

Since $w_t^i = w_{t-1}^i \alpha_t^i$, this can equivalently be written in terms of the normalized weights from the previous step:

$$\widehat{Z_t/Z_{t-1}} = \sum_{i=1}^{N_s} W_{t-1}^i \alpha_t^i \tag{13.28}$$

[Figure 13.4 image: staircase plot of the empirical CDF, x-axis ticks 0 to 4.]

Figure 13.4: Illustration of how to sample from the empirical CDF $P(x) = \sum_{n=1}^{N} W^n \mathbb{I}(x \geq n)$ shown in black. The height of step n is $W^n$. If $U^m$ picks step n, then we set the ancestor of m to be n, i.e., $A^m = n$. In this example, $A^{1:3} = (1, 2, 2)$. Adapted from Figure 9.3 of [CP20b].


This estimate of the marginal likelihood is very useful for tasks such as parameter estimation. 
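In practice the weights are usually stored as logarithms, since products of many likelihood terms quickly underflow. Below is a minimal sketch of one factor of $\hat{Z}_T$ in log space, using the identity $\sum_i w_t^i / \sum_i w_{t-1}^i = \sum_i W_{t-1}^i \alpha_t^i$, which follows from $w_t^i = w_{t-1}^i \alpha_t^i$; summing these log factors over t gives $\log \hat{Z}_T$. The function names are hypothetical.

```python
import numpy as np

def logsumexp(a):
    """log(sum(exp(a))), computed stably."""
    m = np.max(a)
    return m + np.log(np.sum(np.exp(a - m)))

def log_z_ratio(log_w_prev, log_alpha):
    """log of the estimate of Z_t / Z_{t-1} = sum_i W_{t-1}^i alpha_t^i.

    log_w_prev: log unnormalized weights w_{t-1}^i (a common offset cancels)
    log_alpha:  log incremental weights alpha_t^i
    """
    log_W_prev = log_w_prev - logsumexp(log_w_prev)   # log normalized weights
    return logsumexp(log_W_prev + log_alpha)

# After a resampling step the previous weights are equal, and the factor
# reduces to the average incremental weight: log(mean([1, 2, 3, 6])) = log(3).
print(log_z_ratio(np.zeros(4), np.log([1.0, 2.0, 3.0, 6.0])))
```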


13.2.4 Resampling methods 


In this section, we discuss various resampling methods, which can help reduce weight degeneracy and 
path degeneracy. 


13.2.4.1 Multinomial resampling 


The simplest approach to resampling is known as multinomial resampling. This works as follows. First we form the cumulative distribution from the weights $W_{t-1}^{1:N_s}$, as illustrated by the staircase in Figure 13.4. Then we sample $N_s$ uniform random variables, $u^i \sim \mathrm{Unif}(0, 1)$. Finally, we see which bin (interval) $u^i$ lands in; if it falls in bin a, we set the new (resampled) $z_{t-1}^i$ to be a copy of the old $z_{t-1}^a$. We say that a is the ancestor of i. More precisely, we say a is the ancestor of sample i if

$$\sum_{j=1}^{a-1} W_{t-1}^j \leq u^i < \sum_{j=1}^{a} W_{t-1}^j \tag{13.29}$$

See Listing 13.1 for some Python code.


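Listing 13.1: Multinomial resampling
import numpy as np

def multinomial_resampling(w):
    # w: normalized weights of shape (N,); returns 0-based ancestor indices
    N = w.shape[0]
    u = np.random.rand(N)              # N independent Unif(0, 1) draws
    bins = np.cumsum(w)
    ancestors = np.digitize(u, bins)   # index of the bin each u lands in (Eq. 13.29)
    return ancestors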

Although this is a simple method, it can introduce a lot of variance into the representation of the distribution. For example, suppose all the weights are equal, $W^n = 1/N$. Let $O^n = \sum_{m=1}^{N} \mathbb{I}(A^m = n)$ be the number of “offspring” for particle n (i.e., the number of times this particle is chosen in the resampling step). We have $O^n \sim \mathrm{Bin}(N, 1/N)$, so $P(O^n = 0) = (1 - 1/N)^N \approx e^{-1} \approx 0.37$ for large N. So there is a 37% chance that any given particle will disappear even though they all had the same initial weight. In the sections below, we discuss some low variance resampling methods.
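The 37% figure is easy to check numerically. The sketch below (hypothetical function name) simulates multinomial resampling with equal weights many times and compares the fraction of trials in which one particular particle receives no offspring with $(1 - 1/N)^N$ and $e^{-1}$.

```python
import numpy as np

def prob_no_offspring(n, n_trials=20000, seed=0):
    """Monte Carlo estimate of the probability that particle 0 receives no
    offspring under multinomial resampling with equal weights 1/n."""
    rng = np.random.default_rng(seed)
    ancestors = rng.integers(0, n, size=(n_trials, n))   # n draws per trial
    return np.mean(np.all(ancestors != 0, axis=1))       # particle 0 never picked

n = 100
print(prob_no_offspring(n), (1 - 1 / n) ** n, np.exp(-1))
```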


13.2.4.2 Stratified resampling 


A simple approach to improve on multinomial resampling is to use stratified resampling, in which we divide the unit interval into $N_s$ strata, $(0, 1/N_s)$, $(1/N_s, 2/N_s)$, up to $(1 - 1/N_s, 1)$. We then generate $u^i \sim \mathrm{Unif}((i-1)/N_s, i/N_s)$ and derive the corresponding ancestor indexes using Equation (13.29). See Listing 13.2 for some Python code.


Listing 13.2: Stratified resampling 
def stratified_resampling(w):
    N = w.shape[0]
    u = (np.arange(N) + np.random.rand(N)) / N   # one uniform in each stratum
    bins = np.cumsum(w)
    ancestors = np.digitize(u, bins)
    return ancestors


13.2.4.3 Systematic resampling 


We can further reduce the variance by forcing all the samples from $\hat{\pi}_{t-1}$ to be deterministically generated from a shared random source, $u \sim \mathrm{Unif}(0, 1)$, by computing

$$u^i = \frac{i-1}{N_s} + \frac{u}{N_s} \tag{13.30}$$

We derive the corresponding ancestor indexes using Equation (13.29). See Listing 13.3 for some Python code. (The only difference from Listing 13.2 is the use of np.random.rand() instead of np.random.rand(N).)


Listing 13.3: Systematic resampling 
def systematic_resampling(w):
    N = w.shape[0]
    u = (np.arange(N) + np.random.rand()) / N   # one shared uniform for all strata
    bins = np.cumsum(w)
    ancestors = np.digitize(u, bins)
    return ancestors


13.2.4.4 Comparison

It can be proved that all of the above methods are unbiased. It can also be proved that stratified resampling has lower variance than multinomial resampling. Empirically it seems that systematic resampling has lower variance than the other methods [HSG06]. A more complex resampling scheme, which is guaranteed to converge and also has low variance, is described in [GCW19].
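These claims can be probed empirically. The sketch below, written for illustration only (all names are hypothetical), reimplements the three schemes from Listings 13.1 to 13.3 with NumPy's Generator API, counts the offspring of each particle, and estimates the variance of those counts over repeated resampling steps for a fixed weight vector. For systematic resampling each count is always $\lfloor N W^n \rfloor$ or $\lceil N W^n \rceil$, since the $u^i$ form a regular grid with spacing $1/N$.

```python
import numpy as np

def resample_ancestors(w, method, rng):
    """0-based ancestor indices for normalized weights w (cf. Listings 13.1-13.3)."""
    N = len(w)
    if method == "multinomial":
        u = rng.random(N)                           # N independent uniforms
    elif method == "stratified":
        u = (np.arange(N) + rng.random(N)) / N      # one uniform per stratum
    elif method == "systematic":
        u = (np.arange(N) + rng.random()) / N       # one shared uniform
    else:
        raise ValueError(f"unknown method: {method}")
    # Equation (13.29); the clip guards against round-off making cumsum(w)[-1] < 1.
    return np.minimum(np.digitize(u, np.cumsum(w)), N - 1)

def offspring_counts(w, method, rng):
    """Number of offspring of each particle after one resampling step."""
    return np.bincount(resample_ancestors(w, method, rng), minlength=len(w))

def offspring_variance(w, method, n_reps=5000, seed=0):
    """Variance of the offspring counts across repetitions, averaged over particles."""
    rng = np.random.default_rng(seed)
    counts = np.array([offspring_counts(w, method, rng) for _ in range(n_reps)])
    return counts.var(axis=0).mean()

w = np.random.default_rng(1).dirichlet(np.ones(20))
for method in ["multinomial", "stratified", "systematic"]:
    print(method, offspring_variance(w, method))
```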
13.2.5 Adaptive resampling

The resampling step can result in loss of diversity, since each ancestor may generate multiple children, and some may generate no children, since the ancestor indices $A_t^i$ are sampled independently; this is the path degeneracy problem mentioned above. On the other hand, if we never resample, we end up with SIS, which suffers from weight degeneracy (particles with negligible weight). A compromise is to use adaptive resampling, in which we resample whenever the effective sample size or ESS drops below some minimum, such as N/2. A common way to define the ESS is as follows:

$$\mathrm{ESS}(W^{1:N}) = \frac{1}{\sum_{n=1}^{N} (W^n)^2} \tag{13.31}$$

Alternatively we can compute the ESS using the unnormalized weights:

$$\mathrm{ESS}(w^{1:N}) = \frac{\left(\sum_{n=1}^{N} w^n\right)^2}{\sum_{n=1}^{N} (w^n)^2} \tag{13.32}$$

Note that if we have k weights with $w^n = 1$ and $N - k$ weights with $w^n = 0$, then the ESS is k; thus the ESS is between 1 and N.
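A sketch of Equation (13.32), computed from log unnormalized weights, which is how weights are usually stored in practice; the resampling test in the last line uses the $N/2$ threshold mentioned above. The function name is hypothetical.

```python
import numpy as np

def ess_from_log_weights(log_w):
    """ESS of Equation (13.32), computed from log unnormalized weights.

    Subtracting the max before exponentiating avoids overflow; the ESS is
    invariant to rescaling all weights, so this does not change the result.
    """
    w = np.exp(log_w - np.max(log_w))
    return np.sum(w) ** 2 / np.sum(w ** 2)

log_w = np.log(np.array([0.1, 0.2, 0.3, 0.4]))
N = len(log_w)
needs_resampling = ess_from_log_weights(log_w) < N / 2
```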

The pseudocode for SISR with adaptive resampling is given in Algorithm 34. (We use the notation 
of [Law+22, App. B], in which we first sample new extensions of the sequences, and then optionally 
resample the sequences at the end of each step.) 


Algorithm 34: SISR with adaptive resampling (generic SMC) 


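1 Initialization: $w_0^{1:N_s} = 1$, $\hat{Z}_0 = 1$
2 for t = 1 : T do
3   for i = 1 : N_s do
4     Sample particle $z_t^i \sim q_t(z_t | z_{1:t-1}^i)$
5     Compute incremental weight $\alpha_t^i = \frac{\tilde{\gamma}_t(z_{1:t}^i)}{\tilde{\gamma}_{t-1}(z_{1:t-1}^i)\, q_t(z_t^i | z_{1:t-1}^i)}$
6     Compute unnormalized weight $w_t^i = w_{t-1}^i \alpha_t^i$
7   Estimate normalization constant: $\widehat{Z_t/Z_{t-1}} = \frac{\sum_{i=1}^{N_s} w_t^i}{\sum_{i=1}^{N_s} w_{t-1}^i}$, $\hat{Z}_t = \hat{Z}_{t-1}\, \widehat{Z_t/Z_{t-1}}$
8   if $\mathrm{ESS}(w_t^{1:N_s}) < \mathrm{ESS}_{\min}$ then
9     Compute ancestors $a_t^{1:N_s} = \mathrm{resample}(w_t^{1:N_s})$
10    Select $z_{1:t}^{1:N_s} = \mathrm{permute}(a_t^{1:N_s}, z_{1:t}^{1:N_s})$
11    Reset unnormalized weights $w_t^{1:N_s} = 1/N_s$
12  Compute normalized weights $W_t^i = w_t^i / \sum_j w_t^j$ for $i = 1 : N_s$
13  Compute MC posterior $\hat{\pi}_t(z_{1:t}) = \sum_{i=1}^{N_s} W_t^i \delta(z_{1:t} - z_{1:t}^i)$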
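The following is a minimal NumPy sketch of Algorithm 34. It assumes the user supplies two hypothetical callbacks: one that samples $z_t^i \sim q_t(z_t | z_{1:t-1}^i)$ and one that returns $\log \alpha_t^i$. It keeps whole paths (so path degeneracy remains visible), works with log weights, applies multinomial resampling when the ESS falls below a fraction of $N_s$, and accumulates $\log \hat{Z}_t$ as in line 7. It is an illustration under these assumptions, not a reference implementation.

```python
import numpy as np

def _logsumexp(a):
    m = np.max(a)
    return m + np.log(np.sum(np.exp(a - m)))

def adaptive_smc(T, n, propose, log_incr_weight, ess_frac=0.5, seed=0):
    """SISR with adaptive resampling, following the steps of Algorithm 34.

    propose(t, paths, rng) -> z_t of shape (n,): z_t^i ~ q_t(z_t | z_{1:t-1}^i)
    log_incr_weight(t, paths, z_t) -> log alpha_t^i of shape (n,)
    paths has shape (n, t) and holds z_{1:t-1}^i (t is 0-based here).
    Returns the paths z_{1:T}^i, normalized weights W_T^i and log Z_T-hat.
    """
    rng = np.random.default_rng(seed)
    paths = np.zeros((n, 0))
    log_w = np.zeros(n)                           # w_0^i = 1
    log_Z = 0.0                                   # Z_0-hat = 1
    for t in range(T):
        z_t = propose(t, paths, rng)
        new_log_w = log_w + log_incr_weight(t, paths, z_t)
        paths = np.column_stack([paths, z_t])
        # Line 7: log of sum_i w_t^i / sum_i w_{t-1}^i
        log_Z += _logsumexp(new_log_w) - _logsumexp(log_w)
        log_w = new_log_w
        W = np.exp(log_w - _logsumexp(log_w))     # normalized weights
        if 1.0 / np.sum(W ** 2) < ess_frac * n:   # ESS check (line 8)
            ancestors = rng.choice(n, size=n, p=W)
            paths = paths[ancestors]
            log_w = np.full(n, -np.log(n))        # reset w_t^i = 1/N_s
            W = np.full(n, 1.0 / n)
    return paths, W, log_Z
```

For an SSM with the dynamical prior as proposal, `propose` would sample from $p(z_t | z_{t-1}^i)$ and `log_incr_weight` would return $\log p(y_t | z_t^i)$, which gives a bootstrap filter with adaptive resampling.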

1. Note that the ESS used in SMC is different from the ESS used in MCMC (Section 12.6.3); the latter takes into account the auto-correlation of the MCMC samples.
13.3 Proposal distributions


The efficiency of PF is crucially dependent on the quality of the proposal distribution. We discuss 
some options below. 


13.3.1 Locally optimal proposal 


We define the (one-step) locally optimal proposal distribution $q_t^*(z_t | z_{1:t-1})$ to be the one that minimizes

$$\begin{aligned}
&D_{\mathrm{KL}}\left(\pi_{t-1}(z_{1:t-1})\, q_t(z_t | z_{1:t-1}) \,\|\, \pi_t(z_{1:t})\right) && (13.33)\\
&= \mathbb{E}_{\pi_{t-1} q_t}\left[\log \left(\pi_{t-1}(z_{1:t-1})\, q_t(z_t | z_{1:t-1})\right) - \log \pi_t(z_{1:t})\right] && (13.34)\\
&= \mathbb{E}_{\pi_{t-1} q_t}\left[\log q_t(z_t | z_{1:t-1}) - \log \pi_t(z_t | z_{1:t-1})\right] + \mathrm{const} && (13.35)\\
&= \mathbb{E}_{\pi_{t-1}}\left[D_{\mathrm{KL}}\left(q_t(z_t | z_{1:t-1}) \,\|\, \pi_t(z_t | z_{1:t-1})\right)\right] + \mathrm{const} && (13.36)
\end{aligned}$$

The KL is minimized by choosing

$$q_t^*(z_t | z_{1:t-1}) = \pi_t(z_t | z_{1:t-1}) = \frac{\tilde{\gamma}_t(z_{1:t})}{\tilde{\gamma}_t(z_{1:t-1})} \tag{13.37}$$

where $\tilde{\gamma}_t(z_{1:t-1}) = \int \tilde{\gamma}_t(z_{1:t})\, dz_t$ is the probability of the past sequence under the current target distribution.

Note that the subscript t specifies the t'th distribution, so in the context of SSMs, we have $\pi_t(z_t | z_{1:t-1}) = p(z_t | z_{1:t-1}, y_{1:t})$. Thus we see that when proposing $z_t$, we should condition on all the data, including the most recent observation, $y_t$; this is called a guided particle filter, and will generally be better than the bootstrap filter, which proposes from the prior.

In general, it is intractable to compute the locally optimal proposal, so we consider various 
approximations below. 
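For intuition, the locally optimal proposal can be computed exactly in a simple special case. Assuming a 1-d model $z_t \sim \mathcal{N}(\phi z_{t-1}, q)$, $y_t \sim \mathcal{N}(z_t, r)$ (an illustrative choice, not a model used in the text), we have $\pi_t(z_t | z_{1:t-1}) = p(z_t | z_{t-1}, y_t) = \mathcal{N}(z_t | m, s)$ with $s = (1/q + 1/r)^{-1}$ and $m = s(\phi z_{t-1}/q + y_t/r)$. With this proposal the incremental weight becomes $\alpha_t^i = p(y_t | z_{t-1}^i) = \mathcal{N}(y_t | \phi z_{t-1}^i, q + r)$, which does not depend on the sampled $z_t^i$. A NumPy sketch (function name hypothetical):

```python
import numpy as np

def optimal_proposal_step(z_prev, y_t, phi, q, r, rng):
    """Sample z_t^i from the locally optimal proposal p(z_t | z_{t-1}^i, y_t)
    for the assumed model z_t ~ N(phi z_{t-1}, q), y_t ~ N(z_t, r).

    Returns the new particles and log alpha_t^i = log N(y_t | phi z_{t-1}^i, q + r).
    """
    s = 1.0 / (1.0 / q + 1.0 / r)                 # variance of p(z_t | z_{t-1}, y_t)
    m = s * (phi * z_prev / q + y_t / r)          # mean, one per particle
    z_t = m + np.sqrt(s) * rng.standard_normal(z_prev.shape)
    v = q + r                                     # Var[y_t | z_{t-1}]
    log_alpha = -0.5 * (y_t - phi * z_prev) ** 2 / v - 0.5 * np.log(2 * np.pi * v)
    return z_t, log_alpha

rng = np.random.default_rng(0)
z_t, log_alpha = optimal_proposal_step(np.zeros(5), y_t=1.0, phi=0.9, q=1.0, r=1.0, rng=rng)
```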


13.3.2 Proposals based on the extended and unscented Kalman filter 


One way to approximate the locally optimal proposal distribution is based on the extended Kalman filter (Section 8.5.2) or the unscented Kalman filter (Section 8.6), which give rise to the extended particle filter [DGA00] and unscented particle filter [Mer+00], respectively. To explain these methods, we follow the presentation of [NLS19, p36]. As usual, we assume the dynamical system can be written as $z_t = f(z_{t-1}) + q_t$ and $y_t = h(z_t) + r_t$, where $q_t$ is the system noise and $r_t$ is the observation noise. The EKF and UKF approximations assume that the joint distribution over neighboring time steps, given the i'th history, is Gaussian:

$$p(z_t, y_t | z_{1:t-1}^i) \approx \mathcal{N}\left(\begin{pmatrix} z_t \\ y_t \end{pmatrix} \Big| \mu^i, \Sigma^i\right) \tag{13.38}$$

where

$$\mu^i = \begin{pmatrix} \mu_z^i \\ \mu_y^i \end{pmatrix}, \quad \Sigma^i = \begin{pmatrix} \Sigma_z^i & \Sigma_{zy}^i \\ \Sigma_{yz}^i & \Sigma_y^i \end{pmatrix} \tag{13.39}$$
(See Section 8.7 for details.)

The EKF and UKF compute $\mu^i$ and $\Sigma^i$ differently. In the EKF, we linearize f and h, and assume the noise terms are Gaussian. We then compute $p(z_t, y_t | z_{1:t-1}^i)$ exactly for this linearized model (see Section 8.5.1). In the UKF, we propagate sigma points through f and h, and approximate the resulting means and covariances using the unscented transform, which can be more accurate (see Section 8.6). Once we have computed $\mu^i$ and $\Sigma^i$, we can use standard rules for Gaussian conditioning to compute the approximate proposal as follows:

$$q(z_t | z_{1:t-1}^i, y_t) \approx \mathcal{N}(z_t | \hat{\mu}^i, \hat{\Sigma}^i) \tag{13.40}$$

$$\hat{\mu}^i = \mu_z^i + \Sigma_{zy}^i (\Sigma_y^i)^{-1} (y_t - \mu_y^i) \tag{13.41}$$

$$\hat{\Sigma}^i = \Sigma_z^i - \Sigma_{zy}^i (\Sigma_y^i)^{-1} \Sigma_{yz}^i \tag{13.42}$$


Note that the linearization (or sigma point) approximation needs to be performed for each particle separately.
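The proposal moments above come from ordinary Gaussian conditioning. Here is a small NumPy sketch, assuming the joint moments have already been computed for one particle (by linearization or sigma points); the function name is hypothetical, and `np.linalg.solve` is used instead of an explicit inverse as a standard numerical choice. The example uses a linear-Gaussian model $z_t = F z_{t-1} + q_t$, $y_t = H z_t + r_t$ with assumed noise covariances Q and R, for which the joint moments are exact: $\mu_z = F z_{t-1}^i$, $\mu_y = H \mu_z$, $\Sigma_z = Q$, $\Sigma_{zy} = Q H^\mathsf{T}$, $\Sigma_y = H Q H^\mathsf{T} + R$.

```python
import numpy as np

def gaussian_condition(mu_z, mu_y, S_z, S_zy, S_y, y):
    """Mean and covariance of p(z | y) for a joint Gaussian over (z, y) with
    covariance blocks S_z, S_zy (= Sigma_zy) and S_y."""
    K = np.linalg.solve(S_y, S_zy.T).T        # Sigma_zy Sigma_y^{-1} (S_y symmetric)
    mean = mu_z + K @ (y - mu_y)
    cov = S_z - K @ S_zy.T                    # Sigma_yz = Sigma_zy^T
    return mean, cov

# Linear-Gaussian example (all values are illustrative assumptions).
F = np.array([[1.0, 0.1], [0.0, 1.0]])
H = np.array([[1.0, 0.0]])
Q = np.array([[0.5, 0.1], [0.1, 0.3]])
R = np.array([[0.2]])
z_prev = np.array([1.0, -1.0])
mu_z = F @ z_prev
mean, cov = gaussian_condition(mu_z, H @ mu_z, Q, Q @ H.T, H @ Q @ H.T + R,
                               y=np.array([0.7]))
```

For this linear case the result matches the information-form update, with precision $Q^{-1} + H^\mathsf{T} R^{-1} H$.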


13.3.3 Proposals based on the Laplace approximation 


To handle non-Gaussian likelihoods in an SSM, we can use the Laplace approximation (Section 7.4.3), as suggested in [DGA00]. In particular, consider an SSM with linear-Gaussian latent dynamics and a GLM likelihood. At each step t, we compute the maximum $\hat{z}_t = \arg\max_{z_t} \log p(y_t | z_t)$ (e.g., using Newton-Raphson), and then approximate the likelihood, as a function of $z_t$ and up to a multiplicative constant, using

$$p(y_t | z_t) \approx \mathcal{N}(z_t | \hat{z}_t, (-\hat{H}_t)^{-1}) \tag{13.43}$$

where $\hat{H}_t$ is the Hessian of the log-likelihood at the mode. We now compute $p(z_t | z_{t-1}^i, y_t)$ using the update step of the Kalman filter, using the same equations as in Section 13.3.2. This combination is called the Laplace Gaussian filter [Koy+10]. We give an example in Section 13.3.3.1.
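Here is a sketch of the Laplace step for a Poisson GLM likelihood, writing the log rates as $a + B z_t$ (so row k of B holds the coefficients for observation k). Newton-Raphson uses the gradient $B^\mathsf{T}(y_t - \lambda)$ and Hessian $-B^\mathsf{T} \mathrm{diag}(\lambda) B$; the step-halving safeguard is an addition of this sketch. The second function combines the resulting Gaussian pseudo-observation with a Gaussian prior $\mathcal{N}(z_t | m, P)$ (for example $m = F z_{t-1}^i$ and P the dynamics noise covariance) in information form. All names are hypothetical. The likelihood-only mode may be ill-defined, for example when many counts are zero or B has fewer rows than columns, in which case $-\hat{H}_t$ can be (near) singular.

```python
import numpy as np

def poisson_loglik(z, y, a, B):
    """sum_k log Poi(y_k | exp(a_k + B_k z)), dropping the log(y_k!) constant."""
    eta = a + B @ z
    return np.sum(y * eta - np.exp(eta))

def laplace_poisson_likelihood(y, a, B, z0, n_iter=50):
    """Mode z_hat of the Poisson log-likelihood and its Hessian H, found by
    Newton-Raphson with step halving."""
    z = np.array(z0, dtype=float)
    for _ in range(n_iter):
        lam = np.exp(a + B @ z)
        grad = B.T @ (y - lam)
        H = -(B.T * lam) @ B                  # -B^T diag(lam) B
        step = np.linalg.solve(-H, grad)      # Newton direction
        t, ll = 1.0, poisson_loglik(z, y, a, B)
        while poisson_loglik(z + t * step, y, a, B) < ll and t > 1e-8:
            t *= 0.5                          # backtrack until the objective does not drop
        z = z + t * step
    lam = np.exp(a + B @ z)
    return z, -(B.T * lam) @ B

def combine_with_prior(m, P, z_hat, H):
    """Gaussian approximation to p(z_t | z_{t-1}, y_t): prior N(z_t | m, P) times
    the pseudo-likelihood N(z_hat | z_t, (-H)^{-1}), in information form."""
    prec = np.linalg.inv(P) - H
    cov = np.linalg.inv(prec)
    mean = cov @ (np.linalg.solve(P, m) - H @ z_hat)
    return mean, cov
```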


13.3.3.1 Example: neural decoding 


In this section, we give an example where we apply the Laplace approximation to an SSM with linear-Gaussian dynamics and a Poisson likelihood. The application arises from neuroscience. In particular, assume we record the neural spike trains as a monkey moves its hand around in space. Let $z_t \in \mathbb{R}^6$ represent the 3d location and velocity of the hand. We model the dynamics of the hand using a simple Brownian random walk model [CP20b, p157]:

$$\begin{pmatrix} z_t(i) \\ z_t(i+3) \end{pmatrix} \Big| z_{t-1} \sim \mathcal{N}\left(\begin{pmatrix} 1 & \Delta \\ 0 & 1 \end{pmatrix} \begin{pmatrix} z_{t-1}(i) \\ z_{t-1}(i+3) \end{pmatrix}, \sigma^2 Q\right), \quad i = 1:3 \tag{13.44}$$

where the covariance of the noise is given by the following, assuming a discretization step of $\Delta$:

$$Q = \begin{pmatrix} \Delta^3/3 & \Delta^2/2 \\ \Delta^2/2 & \Delta \end{pmatrix} \tag{13.45}$$

We assume the k'th observation at time t is the number of spikes for neuron k in this sensing interval:

$$p(y_t(k) | z_t) = \mathrm{Poi}(\lambda_k(z_t)) \tag{13.46}$$

$$\log \lambda_k(z_t) = \alpha_k + \beta_k^{\mathsf{T}} z_t \tag{13.47}$$
Figure 13.5: Effective sample size at each step for the bootstrap particle filter and a guided particle filter for a Gaussian SSM with Poisson likelihood. Adapted from Figure 10.4 of [CP20b]. Generated by pf_guided_neural_decoding.ipynb.


Our goal is to compute $p(z_t | y_{1:t})$, which lets us infer the position of the hand from the neural code.
(Apart from its value for furthering basic science, this can be useful for applications such as helping 
disabled people control their arms using “mind control”.) 

To illustrate this, we sample a synthetic dataset from the model, to simulate a “monkey” moving 
its arm for T = 25 time steps; this generates K = 50 neuronal counts per time step. We then apply 
particle filtering to this dataset (using the true model), using either the bootstrap filter (i.e., proposal 
is the random walk prior) or the guided filter (i.e., proposal is the Laplace approximation mentioned 
above). In Figure 13.5, we see that the effective sample size of the guided filter is much higher than 
for the bootstrap filter. 


13.3.4 Proposals based on SMC (nested SMC) 


It is possible to use SMC as a subroutine to compute a proposal distribution for SMC: at each step t, for each particle i, we run an SMC algorithm where the target distribution is the optimal proposal, $p(z_t | z_{1:t-1}^i, y_{1:t})$. This is called nested SMC [NLS15; NLS19].

This method can approximate the locally optimal proposal arbitrarily well, since it does not make 
any limiting parametric assumptions. However, the method can be slow, although the inner SMC 
algorithm can be run in parallel for each outer sample [NLS15; NLS19]. 


13.4 Rao-Blackwellised particle filtering (RBPF) 


In some models, we can partition the hidden variables into two kinds, $m_t$ and $z_t$, such that we can analytically integrate out $z_t$ provided we know the values of $m_{1:t}$. This means we only have to sample $m_{1:t}$, and can represent $p(z_t | m_{1:t}, y_{1:t})$ parametrically. These hybrid particles are sometimes called distributional particles or collapsed particles [KF09a, Sec 12.4]. This combines techniques from particle filtering (Section 13.2) with deterministic methods such as Kalman filtering (Section 8.3.2).

The advantage of this approach is that we reduce the dimensionality of the space in which we are sampling, which reduces the variance of our estimate. This technique is known as Rao-Blackwellised particle filtering or RBPF for short. (See Section 11.6.2 for more details on Rao-Blackwellisation.) In Section 13.4.1 we give an example of RBPF for inference in a switching linear dynamical system. In Section 13.4.3 we illustrate RBPF for inference in the SLAM model for a mobile robot.


13.4.1 Mixture of Kalman filters 


In this section, we consider the application of RBPF to a switching linear dynamical system (Section 29.9). This model has both continuous and discrete latent variables. It can be used to track a system that switches between discrete modes or operating regimes, represented by the discrete variable $m_t$.

For notational simplicity, we ignore the control inputs $u_t$. Thus the model is given by

$$p(z_t | z_{t-1}, m_t = k) = \mathcal{N}(z_t | F_k z_{t-1}, Q_k) \tag{13.48}$$

$$p(y_t | z_t, m_t = k) = \mathcal{N}(y_t | H_k z_t, R_k) \tag{13.49}$$

$$p(m_t = k | m_{t-1} = j) = A_{jk} \tag{13.50}$$

We let $\theta_k = (F_k, H_k, Q_k, R_k, A_{:,k})$ represent all the parameters for state k.

Exact inference is intractable, but if we sample the discrete variables, we can infer the continuous variables conditioned on the discretes exactly, making this a good candidate for RBPF. In particular, if we sample trajectories $m_{1:t}^n$, we can apply a Kalman filter to each particle. This can be thought of as a mixture of Kalman filters [CL00]. The resulting belief state is represented by

$$p(z_t, m_{1:t} | y_{1:t}) \approx \sum_{n=1}^{N} W_t^n\, \delta(m_{1:t} - m_{1:t}^n)\, \mathcal{N}(z_t | \mu_t^n, \Sigma_t^n) \tag{13.51}$$


To derive the filtering algorithm, note that the full posterior at time t can be written as follows:

$$p(m_{1:t}, z_{1:t} | y_{1:t}) = p(z_{1:t} | m_{1:t}, y_{1:t})\, p(m_{1:t} | y_{1:t}) \tag{13.52}$$

The second term is given by the following:

$$\begin{aligned}
p(m_{1:t} | y_{1:t}) &\propto p(y_t | m_{1:t}, y_{1:t-1})\, p(m_{1:t} | y_{1:t-1}) && (13.53)\\
&= p(y_t | m_{1:t}, y_{1:t-1})\, p(m_t | m_{1:t-1}, y_{1:t-1})\, p(m_{1:t-1} | y_{1:t-1}) && (13.54)\\
&= p(y_t | m_{1:t}, y_{1:t-1})\, p(m_t | m_{t-1})\, p(m_{1:t-1} | y_{1:t-1}) && (13.55)
\end{aligned}$$

Note that, unlike the case of standard particle filtering, we cannot write $p(y_t | m_{1:t}, y_{1:t-1}) = p(y_t | m_t)$, since $m_t$ does not d-separate the past observations from $y_t$, as is evident from Figure 29.25a.
Suppose we use the following recursive proposal distribution:

$$q(m_{1:t} | y_{1:t}) = q(m_t | m_{1:t-1}, y_{1:t})\, q(m_{1:t-1} | y_{1:t-1}) \tag{13.56}$$

Then we get the unnormalized importance weights

$$w_t^n \propto w_{t-1}^n \frac{p(y_t | m_t^n, m_{1:t-1}^n, y_{1:t-1})\, p(m_t^n | m_{t-1}^n)}{q(m_t^n | m_{1:t-1}^n, y_{1:t})} \tag{13.57}$$
As a special case, suppose we propose from the prior, $q(m_t | m_{t-1}^n, y_{1:t}) = p(m_t | m_{t-1}^n)$. If we sample discrete state k, the weight update becomes

$$w_t^n \propto w_{t-1}^n\, p(y_t | m_t = k, m_{1:t-1}^n, y_{1:t-1}) = w_{t-1}^n L_{tk}^n \tag{13.58}$$

where

$$L_{tk}^n = p(y_t | m_t = k, m_{1:t-1}^n, y_{1:t-1}) = \int p(y_t | m_t = k, z_t)\, p(z_t | m_t = k, y_{1:t-1}, m_{1:t-1}^n)\, dz_t \tag{13.59}$$

The quantity $L_{tk}^n$ is the predictive density for the new observation $y_t$ conditioned on $m_t = k$ and the history of previous latents, $m_{1:t-1}^n$. In the case of SLDS models, this can be computed using the normalization constant of the Kalman filter, Equation (8.80). The resulting algorithm is shown in Algorithm 35. The step marked “KFupdate” refers to the Kalman filter update equations in Section 8.3.2, and is applied to each particle separately.


Algorithm 35: One step of RBPF for SLDS using prior as proposal 


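1 for n = 1 : N do
2   $k \sim p(m_t | m_{t-1}^n)$
3   $m_t^n := k$
4   $(\mu_t^n, \Sigma_t^n, L_{tk}^n) = \mathrm{KFupdate}(\mu_{t-1}^n, \Sigma_{t-1}^n, y_t, \theta_k)$
5   $w_t^n = w_{t-1}^n L_{tk}^n$
6 Compute ESS = ESS($w_t^{1:N}$)
7 if ESS < ESS_min then
8   $a_t^{1:N}$ = Resample($w_t^{1:N}$)
9   $(m_t^{1:N}, \mu_t^{1:N}, \Sigma_t^{1:N})$ = permute($a_t^{1:N}$, $m_t^{1:N}$, $\mu_t^{1:N}$, $\Sigma_t^{1:N}$)
10  $w_t^{1:N} = 1/N$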
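Algorithm 35 can be sketched in NumPy as follows. This is an illustration under stated assumptions (hypothetical names; multinomial resampling; log weights; `kf_update` implemented as one Kalman predict-and-update step that also returns $\log L_{tk}^n$ as the Gaussian predictive density of $y_t$), not the code used for the figures. Discrete states are 0-based, and `A[j, k]` holds $p(m_t = k | m_{t-1} = j)$.

```python
import numpy as np

def kf_update(mu, Sigma, y, F, H, Q, R):
    """One Kalman predict + update step for a single mode.

    Returns the new mean and covariance, and log N(y | H F mu, H P H^T + R),
    the log predictive density of y (log L_tk^n).
    """
    m_pred = F @ mu
    P_pred = F @ Sigma @ F.T + Q
    S = H @ P_pred @ H.T + R                   # innovation covariance
    e = y - H @ m_pred                         # innovation
    K = np.linalg.solve(S, H @ P_pred).T       # gain P_pred H^T S^{-1}
    mu_new = m_pred + K @ e
    Sigma_new = P_pred - K @ S @ K.T
    _, logdet = np.linalg.slogdet(2 * np.pi * S)
    log_lik = -0.5 * (logdet + e @ np.linalg.solve(S, e))
    return mu_new, Sigma_new, log_lik

def rbpf_step(m, mu, Sigma, log_w, y, params, A, rng, ess_frac=0.5):
    """One RBPF step for an SLDS, proposing m_t from the prior (cf. Algorithm 35).

    m: (N,) int modes; mu: (N, d) means; Sigma: (N, d, d) covariances;
    log_w: (N,) log weights; params[k] = (F_k, H_k, Q_k, R_k).
    """
    m, mu, Sigma, log_w = m.copy(), mu.copy(), Sigma.copy(), log_w.copy()
    N = len(m)
    for n in range(N):
        k = rng.choice(len(A), p=A[m[n]])      # k ~ p(m_t | m_{t-1}^n)
        m[n] = k
        mu[n], Sigma[n], log_L = kf_update(mu[n], Sigma[n], y, *params[k])
        log_w[n] += log_L                      # w_t^n = w_{t-1}^n L_tk^n
    W = np.exp(log_w - log_w.max())
    W /= W.sum()
    if 1.0 / np.sum(W ** 2) < ess_frac * N:    # ESS below threshold
        a = rng.choice(N, size=N, p=W)         # multinomial resampling
        m, mu, Sigma = m[a], mu[a], Sigma[a]
        log_w = np.full(N, -np.log(N))         # reset w_t^n = 1/N
    return m, mu, Sigma, log_w
```

With a single mode, every particle runs the same Kalman filter, which gives a simple consistency check for a sketch like this.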


13.4.1.1 Improvements 


An improved version of the algorithm can be developed based on the fact that we are sampling a discrete state space. At each step, we propagate each of the N old particles through all K possible transition models. We then compute the weight for all NK new particles, and sample from this to get the final set of N particles. This latter step can be done using the optimal resampling method of [FC03], which will stochastically select the particles with the largest weight, while also ensuring the result is an unbiased approximation. In addition, this approach ensures that we do not have duplicate particles, which is wasteful and unnecessary when the state space is discrete.


13.4.2 Example: tracking a maneuvering object 


In this section we give an example of RBPF for an SLDS from [DGK01]. Our goal is to track an object that has the following motion model:

$$p(z_t | z_{t-1}, m_t = k) = \mathcal{N}(z_t | F z_{t-1} + b_k, Q) \tag{13.60}$$
[Figure 13.6 panels: (a) Data; (b) RBPF, MSE: 7.66; (c) Bootstrap Filter, MSE: 10.95.]


Figure 13.6: Illustration of state estimation for a switching linear model. (a) Black dots are observations, 
hollow circles are the true location, colors represent the discrete state. (b) Estimate from RBPF. Generated 
by rbpf_maneuver.ipynb. (c) Estimate from bootstrap filter. Generated by bootstrap_filter_maneuver.ipynb. 


where $z_t = (z_{1,t}, \dot{z}_{1,t}, z_{2,t}, \dot{z}_{2,t})$ contains the 2d position and velocity. We define the observation matrix by $H = I$ and the observation covariance by $R = 10\, \mathrm{diag}(2, 1, 2, 1)$. We define the dynamics matrix by

$$F = \begin{pmatrix} 1 & \Delta & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 1 & \Delta \\ 0 & 0 & 0 & 1 \end{pmatrix} \tag{13.61}$$

where $\Delta = 0.1$. We set the noise covariance to $Q = 0.2 I$ and the input bias vectors for each state to $b_1 = (0, 0, 0, 0)$, $b_2 = (-1.225, -0.35, 1.225, 0.35)$ and $b_3 = (1.225, 0.35, -1.225, -0.35)$. Thus the system will turn in different directions depending on the discrete state. The discrete state transition matrix is given by

$$A = \begin{pmatrix} 0.8 & 0.1 & 0.1 \\ 0.1 & 0.8 & 0.1 \\ 0.1 & 0.1 & 0.8 \end{pmatrix} \tag{13.62}$$
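To make the setup concrete, here is a sketch that encodes the parameters above in NumPy and simulates a trajectory from the model of Equation (13.60). The initial state $z_0 = 0$ and initial mode are assumptions (the text does not specify them), modes are 0-based (so `b[0]` is $b_1$), and the function name is hypothetical.

```python
import numpy as np

delta = 0.1
F = np.array([[1.0, delta, 0.0, 0.0],
              [0.0, 1.0, 0.0, 0.0],
              [0.0, 0.0, 1.0, delta],
              [0.0, 0.0, 0.0, 1.0]])
H = np.eye(4)
Q = 0.2 * np.eye(4)
R = 10.0 * np.diag([2.0, 1.0, 2.0, 1.0])
b = np.array([[0.0, 0.0, 0.0, 0.0],            # b_1 (mode 0)
              [-1.225, -0.35, 1.225, 0.35],    # b_2 (mode 1)
              [1.225, 0.35, -1.225, -0.35]])   # b_3 (mode 2)
A = np.array([[0.8, 0.1, 0.1],
              [0.1, 0.8, 0.1],
              [0.1, 0.1, 0.8]])

def simulate_maneuver(T, seed=0):
    """Sample modes m_t, states z_t and observations y_t for t = 1..T.

    Assumes z_0 = 0 and m_0 = 0 (not specified in the text).
    """
    rng = np.random.default_rng(seed)
    z, m = np.zeros(4), 0
    zs, ms, ys = [], [], []
    for _ in range(T):
        m = rng.choice(3, p=A[m])                                  # switch mode
        z = F @ z + b[m] + rng.multivariate_normal(np.zeros(4), Q)
        y = H @ z + rng.multivariate_normal(np.zeros(4), R)
        zs.append(z)
        ms.append(m)
        ys.append(y)
    return np.array(zs), np.array(ms), np.array(ys)

zs, ms, ys = simulate_maneuver(100)
```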


Figure 13.6a shows some observations, and the true state of the system, from a sample run, for 
100 steps. The colors denote the discrete state, and the location of the symbol denotes the (x, y) 
location. The small dots represent noisy observations. Figure 13.6b shows the estimate of the state 
computed using RBPF with the optimal proposal with 1000 particles. In Figure 13.6c, we show the 
analogous estimate using the bootstrap filter, which does much worse.

In Figure 13.7a and Figure 13.7b, we show the posterior marginals of the (x, y) locations over time. 
In Figure 13.7c we show the true discrete state, and in Figure 13.7d we show the posterior marginal 
over discrete states. The overall state classification error rate is 29%, but it seems that occasionally 
misclassifying isolated time steps does not significantly hurt estimation of the continuous states, as 
we can see from Figure 13.6b. 


Figure 13.7: Visualizing the posterior from the RBPF algorithm. Top row: Posterior marginals of the location of the object over time, derived from the mixture of Gaussian representation for (a) x location (dimension 0), (b) y location (dimension 2). Bottom row: visualization of the true (c) and predicted (d) discrete states. Generated by rbpf_maneuver.ipynb.


13.4.3 Example: FastSLAM


Consider a robot moving around an environment, such as a maze or indoor office environment. It needs to learn a map of the environment, and keep track of its location (pose) within that map. This problem is known as simultaneous localization and mapping, or SLAM for short. SLAM is widely used in mobile robotics (see e.g., [SC86; CN01; TBF06] for details). It is also useful in augmented reality, where the task is to recursively estimate the 3d pose of a handheld camera with respect to a set of 2d visual landmarks (this is known as visual SLAM, [TUI17; SMT18; Cza+20; DH22]).
[Figure 13.8 image: graphical models, panels (a) and (b).]


Figure 13.8: Graphical model representing the SLAM problem. $l_t^k$ is the location of landmark k at time t, $r_t$ is the location of the robot at time t, and $y_t$ is the observation vector. In the model on the left, the landmarks are static (so they act like global shared parameters); on the right, their location can change over time. The robot's observations are based on the distance to the nearest landmarks from the current state, denoted $f(r_t, l_t^k)$. The number of observations per time step is variable, depending on how many landmarks are within the range of the sensor. Adapted from Figure 15.A.8 of [KF09a].


Let us assume we can represent the map as the 2d locations of a set of K landmarks, denoted by $l^1, \ldots, l^K$ (each is a vector in $\mathbb{R}^2$). (We can use data association to figure out which landmark generated each observation, as discussed in Section 29.9.3.2.) Let $r_t$ represent the unknown location of the robot at time t. Let $z_t = (r_t, l_t^{1:K})$ be the combined state space. We can then perform online inference so that the robot can update its estimate of its own location, and the landmark locations.

The state transition model is defined as

$$p(z_t | z_{t-1}, u_t) = p(r_t | r_{t-1}, l_{t-1}^{1:K}, u_t) \prod_{k=1}^{K} p(l_t^k | l_{t-1}^k) \tag{13.63}$$

where $p(r_t | r_{t-1}, l_{t-1}^{1:K}, u_t)$ specifies how the robot moves given the control signal $u_t$ and the location of the obstacles $l_{t-1}^{1:K}$. (Note that in this section, we assume that a human is joysticking the robot through the environment, so $u_{1:t}$ is given as input, i.e., we do not address the decision-theoretic issue of choosing where to move.)

If the obstacles (landmarks) are static, we can define $p(l_t^k | l_{t-1}^k) = \delta(l_t^k - l_{t-1}^k)$, which is equivalent to treating the map as an unknown parameter that is shared globally across all time steps. More generally, we can let the landmark locations evolve over time [Mur00].

The observations $y_t$ measure the distance from $r_t$ to the set of closest landmarks. Figure 13.8 shows the corresponding graphical model for the case where K = 2, and where on the first step it sees landmarks 1 and 2, then just landmark 2, then just landmark 1, etc.

If all the CPDs are linear-Gaussian, then we can use a Kalman filter to maintain our belief state about the location of the robot and the location of the landmarks, $p(z_t | y_{1:t}, u_{1:t})$. In the more general case of a nonlinear model, we can use the EKF (Section 8.5.2) or UKF (Section 8.6.2).
[Figure 13.9 image, panels (a) and (b); panel (a) is labeled “Robot pose”.]


Figure 13.9: Illustration of the SLAM problem. (a) A robot starts at the top left and moves clockwise in a circle back to where it started. We see how the posterior uncertainty about the robot's location increases and then decreases as it returns to a familiar location, closing the loop. If we performed smoothing, this new information would propagate backwards in time to disambiguate the entire trajectory. (b) We show the precision matrix, representing sparse correlations between the landmarks, and between the landmarks and the robot's position (pose). The conditional independencies encoded by the sparse precision matrix can be visualized as a Gaussian graphical model, as shown on the right. From Figure 15.A.3 of [KF09a]. Used with kind permission of Daphne Koller.


Over time, the uncertainty in the robot's location will increase, due to wheel slippage etc., but when the robot returns to a familiar location, its uncertainty will decrease again. This is called closing the loop, and is illustrated in Figure 13.9(a), where we see the uncertainty ellipses, representing $\mathrm{Cov}[z_t | y_{1:t}, u_{1:t}]$, grow and then shrink.


In addition to visualizing the uncertainty of the robot's location, we can visualize the uncertainty about the map. To do this, consider the posterior precision matrix, $\Lambda_t = \Sigma_t^{-1}$. Zeros in the precision matrix correspond to absent edges in the corresponding undirected Gaussian graphical model (GGM, see Section 4.3.5). Initially all the beliefs about landmark locations are uncorrelated (by assumption), so the GGM is a disconnected graph, and $\Lambda_t$ is diagonal. However, as the robot moves about, it will induce correlation between nearby landmarks. Intuitively this is because the robot is estimating its position based on distance to the landmarks, but the landmarks' locations are being estimated based on the robot's position, so they all become interdependent. This can be seen more clearly from the graphical model in Figure 13.8: $l^1$ and $l^2$ are not d-separated by $y_{1:t}$, because there is a path between them via the unknown sequence of $r_{1:t}$ nodes. Consequently, the precision matrix becomes denser over time, and each inference step takes $O(K^2)$ time. This prevents the method from being applied to large maps.

One way to speed this up is based on the following observation: conditional on knowing the robot's path, $r_{1:t}$, the landmark locations are independent, i.e., $p(l_t | r_{1:t}, y_{1:t}) = \prod_{k=1}^{K} p(l_t^k | r_{1:t}, y_{1:t})$. This can be seen by looking at the DGM in Figure 13.8. We can therefore sample the trajectory using some proposal, and apply (2d) Kalman filtering to each landmark independently. This is an example of RBPF, and reduces the inference cost to $O(NK)$ per step, where N is the number of particles and K is the number of landmarks. Fortunately, the number of particles N needed for good performance is often quite small, so the algorithm is essentially linear in the number of landmarks, making it quite scalable. This idea was first suggested in [Mur00], who applied it to occupancy grids (and used the HMM filter for each particle). It was subsequently extended to landmark-based maps in [Thr+04], using the Kalman filter for each particle; they called the technique FastSLAM.


13.5 Extensions of the particle filter 
There are many extensions to the basic particle filtering algorithm, such as the following: 


• We can increase particle diversity by applying one or more steps of MCMC sampling (Section 12.2) at each PF step using $\pi_t(z_t)$ as the target distribution. This is called the resample-move algorithm [DJ11]. It is also possible to use SMC instead of MCMC to diversify the samples [GM17].

• We can extend PF to the case of offline inference; this is called particle smoothing (see e.g., [Kla+06]).

• We can extend PF to inference in general graphical models (not just chains) by combining PF with loopy belief propagation (Section 9.3); this is called non-parametric BP or particle BP (see e.g., [Sud+03; Isa03; Sud+10; Pac+14]).

• We can extend PF to perform inference in static models (e.g., for parameter inference), as we discuss in Section 13.6.


13.6 SMC samplers 


In this section, we discuss SMC samplers (sequential Monte Carlo samplers), which are a way to apply particle filters to sample from a generic target distribution, $\pi(z) = \tilde{\gamma}(z)/Z$, rather than requiring the model to be an SSM. Thus SMC is an alternative to MCMC.

The advantages of SMC samplers over MCMC are as follows: we can estimate the normalizing 
constant Z; we can more easily develop adaptive versions that tune the transition kernel using the 
current set of samples; and the method is easier to parallelize (see e.g., [CCS22; Gre+22]). 

The method works by defining a sequence of intermediate distributions, $\pi_t(z_t)$, which we expand
to a sequence of distributions over all the past variables, $\tilde{\pi}_t(z_{1:t})$. We then use the particle filtering
algorithm to sample from each of these intermediate distributions. By marginalizing all but the final
state, we recover samples from the target distribution, $\pi(z) = \sum_{z_{1:T-1}} \tilde{\pi}_T(z_{1:T})$, as we explain below.
(For more details, see e.g., [Dai+20a; CP20b].)


13.6.1 Ingredients of an SMC sampler 


To define an SMC sampler, we need to specify several ingredients: 
• A sequence of distributions defined on the same state space, $\pi_t(z_t) = \tilde{\gamma}_t(z_t)/Z_t$, for $t = 0:T$;


• A forwards kernel $M_t(z_t|z_{t-1})$ (often written as $M_t(z_{t-1}, z_t)$), which satisfies $\sum_{z_t} M_t(z_t|z_{t-1}) = 1$.
This can be used to propose new samples from our current estimate when we apply particle
filtering.


• A backwards kernel $L_t(z_t|z_{t+1})$ (often written as $L_t(z_t, z_{t+1})$), which satisfies $\sum_{z_t} L_t(z_t|z_{t+1}) = 1$.
This allows us to create a sequence of variables by working backwards in time from the final
target value to the first time step. In particular, we create the following joint distribution:

$$\tilde{\pi}_t(z_{1:t}) = \pi_t(z_t) \prod_{s=1}^{t-1} L_s(z_s|z_{s+1}) \qquad (13.64)$$

This satisfies $\sum_{z_{1:t-1}} \tilde{\pi}_t(z_{1:t}) = \pi_t(z_t)$, so if we apply particle filtering to this for $t = 1:T$, then
samples from the “end” of such sequences will be from the target distribution $\pi_T$.


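With the above ingredients, we can compute the incremental weight at step t using

$$\alpha_t = \frac{\tilde{\gamma}_t(z_{1:t})}{\tilde{\gamma}_{t-1}(z_{1:t-1})\, M_t(z_t|z_{t-1})} = \frac{\tilde{\gamma}_t(z_t)\, L_{t-1}(z_{t-1}|z_t)}{\tilde{\gamma}_{t-1}(z_{t-1})\, M_t(z_t|z_{t-1})} \qquad (13.65)$$

where $\tilde{\gamma}_t(z_{1:t}) = Z_t \tilde{\pi}_t(z_{1:t})$ is the unnormalized version of Equation (13.64). The second equality holds because all the backwards kernel terms except $L_{t-1}$ cancel.

This can be plugged into the generic SMC algorithm, Algorithm 34.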

We still have to specify the forwards and backwards kernels. We will assume the forwards kernel
$M_t$ is an MCMC kernel that leaves $\pi_t$ invariant. We can then define the backwards kernel to be the
time reversal of the forwards kernel. More precisely, suppose we define $L_{t-1}$ so it satisfies

$$\pi_t(z_t)\, L_{t-1}(z_{t-1}|z_t) = \pi_t(z_{t-1})\, M_t(z_t|z_{t-1}) \qquad (13.66)$$
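In this case, the incremental weight simplifies as follows:

$$\alpha_t = \frac{Z_t\, \pi_t(z_t)\, L_{t-1}(z_{t-1}|z_t)}{Z_{t-1}\, \pi_{t-1}(z_{t-1})\, M_t(z_t|z_{t-1})} \qquad (13.67)$$

$$= \frac{Z_t\, \pi_t(z_{t-1})\, M_t(z_t|z_{t-1})}{Z_{t-1}\, \pi_{t-1}(z_{t-1})\, M_t(z_t|z_{t-1})} \qquad (13.68)$$

$$= \frac{\tilde{\gamma}_t(z_{t-1})}{\tilde{\gamma}_{t-1}(z_{t-1})} \qquad (13.69)$$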


We can use any kind of MCMC kernel for $M_t$. For example, if the parameters are real valued and
unconstrained, we can use a Markov kernel that corresponds to K steps of a random walk Metropolis-Hastings (RWMH)
sampler. We can set the covariance of the proposal to $\delta^2 \Sigma_{t-1}$, where $\Sigma_{t-1}$ is the empirical
covariance of the weighted samples from the previous step, $(W_{t-1}^{1:N}, z_{t-1}^{1:N})$, and $\delta = 2.38 D^{-1/2}$, where D is the
dimension of the state space (this is the asymptotically optimal scaling parameter for RWMH). In high dimensional problems, we can use gradient
based Markov kernels, such as HMC [BCJ20] and NUTS [Dev+21]. For binary state spaces, we can
use the method of [SC13].
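The adaptive RWMH kernel described above can be sketched as follows. The sketch assumes the particles are stored as an (N, D) NumPy array and that the log target density is vectorized (the hypothetical `log_target`). It also adds a small diagonal jitter, our own choice, so that the Cholesky factorization never fails. The proposal covariance is $\delta^2 \Sigma_{t-1}$ with $\delta = 2.38 D^{-1/2}$, as in the text.

```python
import numpy as np

def adaptive_rwmh(thetas, weights, log_target, rng, n_steps=10):
    """n_steps of RWMH with proposal covariance (2.38 / sqrt(D))^2 * Sigma,
    where Sigma is the weighted empirical covariance of the current particles."""
    n, d = thetas.shape
    w = weights / weights.sum()
    centered = thetas - w @ thetas
    sigma = (w[:, None] * centered).T @ centered            # weighted covariance
    scale = 2.38 / np.sqrt(d)
    chol = np.linalg.cholesky(scale**2 * sigma + 1e-9 * np.eye(d))  # jitter for stability
    theta = thetas.copy()
    lp = log_target(theta)
    n_accept = 0
    for _ in range(n_steps):
        prop = theta + rng.standard_normal((n, d)) @ chol.T   # N(theta, scale^2 Sigma)
        lp_prop = log_target(prop)
        accept = np.log(rng.uniform(size=n)) < lp_prop - lp
        theta[accept] = prop[accept]
        lp[accept] = lp_prop[accept]
        n_accept += accept.sum()
    return theta, n_accept / (n * n_steps)

rng = np.random.default_rng(0)
log_target = lambda th: -0.5 * (th[:, 0] ** 2 + th[:, 1] ** 2 / 4.0)  # N(0, diag(1, 4))
thetas = rng.standard_normal((2000, 2)) * np.array([1.0, 2.0])
theta, acc_rate = adaptive_rwmh(thetas, np.ones(2000), log_target, rng)
print(acc_rate)
```

Because the proposal is shaped by $\Sigma_{t-1}$, the same step-size rule works for the narrow and the wide coordinate. The acceptance rate stays well away from both 0 and 1.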
13.6. SMC SAMPLERS


13.6.2 Likelihood tempering (geometric path) 


There are many ways to specify the intermediate target distributions. In the geometric path
method, we specify the intermediate distributions to be

$$\tilde{\gamma}_t(z) = \tilde{\gamma}_0(z)^{1-\lambda_t}\, \tilde{\gamma}(z)^{\lambda_t} \qquad (13.70)$$

where $0 = \lambda_0 < \lambda_1 < \cdots < \lambda_T = 1$ are inverse temperature parameters, and $\tilde{\gamma}_0$ is the initial
proposal. If we apply particle filtering to this model, but “turn off” the resampling step, the method
becomes equivalent to annealed importance sampling (Section 11.5.4).
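As a sketch of the annealed importance sampling special case (geometric path, no resampling), the code below estimates the normalizing constant of a 1d unnormalized Gaussian whose true Z is known, so the estimate can be checked. The linear temperature schedule, the RWMH kernel and all constants are illustrative assumptions.

```python
import numpy as np

def log_gamma0(z):
    """Normalized N(0, 1) log density, so Z_0 = 1."""
    return -0.5 * z**2 - 0.5 * np.log(2 * np.pi)

def log_gamma(z):
    """Unnormalized target exp(-(z - 3)^2 / (2 * 0.5^2)); its true Z is 0.5 * sqrt(2 pi)."""
    return -0.5 * ((z - 3.0) / 0.5) ** 2

def ais(n=2000, n_temps=200, n_mh=5, step=0.8, seed=0):
    rng = np.random.default_rng(seed)
    lambdas = np.linspace(0.0, 1.0, n_temps + 1)
    z = rng.standard_normal(n)                   # exact samples from gamma_0
    log_w = np.zeros(n)
    for lam_prev, lam in zip(lambdas[:-1], lambdas[1:]):
        # Incremental weight gamma_t(z) / gamma_{t-1}(z) along the geometric path.
        log_w += (lam - lam_prev) * (log_gamma(z) - log_gamma0(z))
        def log_t(x):                            # gamma_0^(1 - lam) * gamma^lam, in log space
            return (1.0 - lam) * log_gamma0(x) + lam * log_gamma(x)
        for _ in range(n_mh):                    # MH moves that leave gamma_t invariant
            prop = z + step * rng.standard_normal(n)
            accept = np.log(rng.uniform(size=n)) < log_t(prop) - log_t(z)
            z = np.where(accept, prop, z)
    # No resampling: log Z_hat is the log of the average importance weight.
    log_z_hat = np.logaddexp.reduce(log_w) - np.log(n)
    return log_z_hat, z

log_z_hat, samples = ais()
print(log_z_hat, np.log(0.5 * np.sqrt(2 * np.pi)))   # estimate vs. exact log Z
```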

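In the context of Bayesian parameter inference, we often denote the latent variable z by θ, define
$\tilde{\gamma}_0(\theta) \propto \pi_0(\theta)$ as the prior, and define $\tilde{\gamma}(\theta) = \pi_0(\theta)p(\mathcal{D}|\theta)$ as the unnormalized posterior. We can then define the
intermediate distributions to be

$$\tilde{\gamma}_t(\theta) = \pi_0(\theta)^{1-\lambda_t}\, \pi_0(\theta)^{\lambda_t}\, p(\mathcal{D}|\theta)^{\lambda_t} = \pi_0(\theta) \exp[-\lambda_t \mathcal{E}(\theta)] \qquad (13.71)$$

where $\mathcal{E}(\theta) = -\log p(\mathcal{D}|\theta)$ is the energy (potential) function. The incremental weights are given by

$$\alpha_t(\theta) = \frac{\pi_0(\theta) \exp[-\lambda_t \mathcal{E}(\theta)]}{\pi_0(\theta) \exp[-\lambda_{t-1} \mathcal{E}(\theta)]} = \exp[-\delta_t \mathcal{E}(\theta)] \qquad (13.72)$$

where $\lambda_t = \lambda_{t-1} + \delta_t$.

For this method to work well, it is important to choose the $\lambda_t$ so that the successive distributions
are “equidistant”; choosing them automatically in this way is called adaptive tempering. In the case of a Gaussian prior and Gaussian
energy, one can show [CP20b] that this can be achieved by picking $\lambda_t = (1+\gamma)^{t+1} - 1$, where $\gamma > 0$
is some constant. Thus we should increase λ slowly at first, and then make bigger and bigger steps.

In practice, we can choose $\lambda_t$ adaptively by setting $\lambda_t = \lambda_{t-1} + \delta_t$, where

$$\delta_t = \operatorname*{argmin}_{\delta \in [0, 1-\lambda_{t-1}]} \left( \mathrm{ESSLW}(\{-\delta\, \mathcal{E}(\theta^n)\}) - \mathrm{ESS}_{\min} \right)^2 \qquad (13.73)$$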


where $\mathrm{ESSLW}(\{l_n\}) = \mathrm{ESS}(\{e^{l_n}\})$ computes the ESS (Equation (13.32)) from the log weights,
$l_n = \log \tilde{w}^n$. This ensures that the ESS after each step is close to the desired minimum ESS,
typically 0.5N. (If there is no solution for δ in the interval, we set $\delta_t = 1 - \lambda_{t-1}$.) See Algorithm 36
for the overall algorithm.
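Equation (13.73) can be solved numerically. The sketch below finds $\delta_t$ by bisection, which assumes that the ESS decreases as δ grows (the usual situation in practice, though not guaranteed). The function names are hypothetical. The energies come from the bimodal example of Section 13.6.2.1, evaluated on samples from the prior.

```python
import numpy as np

def ess_from_log_weights(log_w):
    """ESS({e^{l_n}}) = (sum_n w_n)^2 / sum_n w_n^2, computed stably."""
    w = np.exp(log_w - log_w.max())
    return w.sum() ** 2 / np.sum(w**2)

def next_delta(energy, lam_prev, ess_min, tol=1e-10):
    """Find delta in [0, 1 - lam_prev] with ESS({-delta * E(theta_n)}) close to ess_min,
    by bisection (assumes the ESS decreases as delta grows)."""
    hi = 1.0 - lam_prev
    if ess_from_log_weights(-hi * energy) >= ess_min:
        return hi                          # no root in the interval: jump to lambda = 1
    lo = 0.0                               # ESS at delta = 0 equals N >= ess_min
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if ess_from_log_weights(-mid * energy) >= ess_min:
            lo = mid
        else:
            hi = mid
    return lo

rng = np.random.default_rng(0)
theta = rng.standard_normal(1000)          # particles from the N(0, 1) prior
energy = 5.0 * (theta**2 - 1.0) ** 2       # E(theta) from the bimodal example, c = 5
delta = next_delta(energy, lam_prev=0.0, ess_min=500.0)
print(delta, ess_from_log_weights(-delta * energy))
```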


13.6.2.1 Example: sampling from a 1d bimodal distribution 


Consider the simple distribution

$$p(\theta) \propto \mathcal{N}(\theta|0, I) \exp(-\mathcal{E}(\theta)) \qquad (13.74)$$


where $\mathcal{E}(\theta) = c(\|\theta\|^2 - 1)^2$. We plot this in 1d in Figure 13.10a for c = 5; we see that it has a
bimodal shape, since the low energy states correspond to parameter vectors whose norm is close to 1.

SMC is particularly useful for sampling from multimodal distributions, which can be provably
hard to sample from efficiently using other methods, including HMC [MPS18], since gradients only
provide local information about the curvature. As an example, in Figure 13.11a and Figure 13.11b
we show the result of applying HMC (Section 12.5) and NUTS (Section 12.5.4.1) to this problem.
We see that both algorithms get stuck near the initial state of $\theta_0 = 1$.
Figure 13.10: (a) Illustration of a bimodal target distribution. (b) Tempered versions of the target at different
inverse temperatures, from $\lambda_T = 1$ down to $\lambda_1 = 0$. Generated by smc_tempered_1d_bimodal.ipynb.
Figure 13.11: Sampling from the bimodal distribution in Figure 13.10a. (a) HMC. (b) NUTS. (c) Tempered
SMC with HMC kernel (single step). (d) Adaptive inverse temperature schedule. Generated by
smc_tempered_1d_bimodal.ipynb.


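Algorithm 36: SMC with adaptive tempering
$\lambda_{-1} = 0$, $t = -1$, $W_{-1}^n = 1$
while $\lambda_t < 1$ do
    $t = t + 1$
    if $t = 0$ then
        $\theta_0^n \sim \pi_0(\theta)$
    else
        $A_{t-1}^{1:N} = \mathrm{Resample}(W_{t-1}^{1:N})$
        $\theta_t^n \sim M_{\lambda_{t-1}}(\theta_{t-1}^{A_{t-1}^n}, \cdot)$
    Compute $\delta_t$ using Equation (13.73)
    $\lambda_t = \lambda_{t-1} + \delta_t$
    $\tilde{w}_t^n = \exp[-\delta_t \mathcal{E}(\theta_t^n)]$
    $W_t^n = \tilde{w}_t^n / \sum_{m=1}^N \tilde{w}_t^m$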
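Below is a minimal NumPy sketch of Algorithm 36 for the 1d bimodal target of Equation (13.74) with c = 5. It is not the code behind Figures 13.10 and 13.11. For brevity we make three substitutions:

- a random-walk MH kernel (10 steps, step size 0.5) replaces the HMC kernel;
- resampling happens at every step;
- a final resampling step makes the returned particles unweighted.

These choices are assumptions.

```python
import numpy as np

def energy(theta, c=5.0):
    return c * (theta**2 - 1.0) ** 2

def ess(log_w):
    w = np.exp(log_w - log_w.max())
    return w.sum() ** 2 / np.sum(w**2)

def next_delta(e, lam_prev, ess_min, tol=1e-10):
    # Bisection for Equation (13.73), assuming the ESS decreases in delta.
    hi = 1.0 - lam_prev
    if ess(-hi * e) >= ess_min:
        return hi
    lo = 0.0
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        lo, hi = (mid, hi) if ess(-mid * e) >= ess_min else (lo, mid)
    return lo

def smc_adaptive_tempering(n=2000, n_mh=10, step=0.5, seed=0):
    rng = np.random.default_rng(seed)
    lam, lambdas = 0.0, []
    theta = rng.standard_normal(n)               # t = 0: sample from the N(0, 1) prior
    W = np.full(n, 1.0 / n)
    while lam < 1.0:
        if lambdas:                              # t > 0: resample, then move
            theta = theta[rng.choice(n, size=n, p=W)]
            def log_pi(x):                       # pi_{lambda_{t-1}} up to a constant
                return -0.5 * x**2 - lam * energy(x)
            for _ in range(n_mh):                # RWMH stands in for the HMC kernel
                prop = theta + step * rng.standard_normal(n)
                accept = np.log(rng.uniform(size=n)) < log_pi(prop) - log_pi(theta)
                theta = np.where(accept, prop, theta)
        e = energy(theta)
        delta = next_delta(e, lam, 0.5 * n)
        lam = 1.0 if delta >= 1.0 - lam else lam + delta
        lambdas.append(lam)
        log_w = -delta * e                       # incremental weight exp(-delta * E)
        W = np.exp(log_w - log_w.max())
        W /= W.sum()
    theta = theta[rng.choice(n, size=n, p=W)]    # final resample: unweighted draws
    return theta, lambdas

samples, lambdas = smc_adaptive_tempering()
print(len(lambdas), np.round(lambdas, 3))
```

The number of tempering steps this sketch produces depends on the seed, N and the kernel. It need not match the 8 steps reported for the book's HMC-based version.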


In Figure 13.10b, we show tempered versions of the target distribution at 5 different temperatures,
chosen uniformly in the interval [0,1]. We see that at $\lambda_1 = 0$, the tempered target is equal to the
Gaussian prior (blue line), which is easy to sample from. Each subsequent distribution is close to the
previous one, so SMC can track the change until it ends up at the target distribution with $\lambda_T = 1$,
as shown in Figure 13.11c.

These SMC results were obtained using the adaptive tempering scheme described above. In
Figure 13.11d we see that the inverse temperature is small initially, and then it increases exponentially. The
algorithm takes 8 steps until $\lambda_t$ reaches 1.


13.6.3 Data tempering 


If we have a set of iid observations, we can define the t'th target to be

$$\tilde{\gamma}_t(\theta) = p(\theta)\, p(y_{1:t}|\theta) \qquad (13.75)$$


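We can now apply SMC to this model. From Equation (13.69), the incremental weight becomes

$$\alpha_t(\theta) = \frac{\tilde{\gamma}_t(\theta)}{\tilde{\gamma}_{t-1}(\theta)} = \frac{p(\theta)\, p(y_{1:t}|\theta)}{p(\theta)\, p(y_{1:t-1}|\theta)} = p(y_t|y_{1:t-1}, \theta) \qquad (13.76)$$

This can be plugged into the generic SMC algorithm in Algorithm 34.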

Unfortunately, sampling from the MCMC kernel will typically take O(t) time, since the MH
accept/reject step requires computing $p(\theta') \prod_{i=1}^t p(y_i|\theta')$ for any proposed θ′. Hence the total
cost is O(T²) if there are T observations. To reduce this, we can sample new parameters only at times
t when the ESS drops below a certain level; in the remaining steps, we just grow the sequence
deterministically by repeating the previously sampled value. This technique was proposed in [Cho02],
who called it the iterated batch importance sampling or IBIS algorithm.
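The following is a minimal IBIS sketch on a simpler toy problem than the one in Section 13.6.3.1. We assume $y_t \sim \mathcal{N}(\mu, 1)$ with known variance and prior $\mu \sim \mathcal{N}(0, 1)$, so the exact conjugate posterior mean is available for comparison. The incremental weight is Equation (13.76), which for iid data reduces to $p(y_t|\mu)$. Resample-move is triggered only when the ESS drops below N/2, and each MH evaluation costs O(t) because it uses all the data seen so far.

```python
import numpy as np

def ibis_gaussian_mean(y, n=2000, n_mh=5, seed=0):
    """IBIS for y_t ~ N(mu, 1) with prior mu ~ N(0, 1) (a toy setup, not the book's example)."""
    rng = np.random.default_rng(seed)
    mu = rng.standard_normal(n)                  # particles from the prior
    log_w = np.zeros(n)
    n_moves = 0
    for t in range(len(y)):
        log_w += -0.5 * (y[t] - mu) ** 2         # log p(y_t | mu), up to a constant
        w = np.exp(log_w - log_w.max())
        if w.sum() ** 2 / np.sum(w**2) < n / 2:  # ESS below N/2: resample and move
            mu = mu[rng.choice(n, size=n, p=w / w.sum())]
            log_w = np.zeros(n)
            ys = y[: t + 1]
            def log_post(m):                     # log p(mu) + sum_{i <= t} log p(y_i | mu): O(t)
                return -0.5 * m**2 - 0.5 * np.sum((ys[:, None] - m[None, :]) ** 2, axis=0)
            step = 1.0 / np.sqrt(t + 2)          # roughly the current posterior std
            for _ in range(n_mh):
                prop = mu + step * rng.standard_normal(n)
                accept = np.log(rng.uniform(size=n)) < log_post(prop) - log_post(mu)
                mu = np.where(accept, prop, mu)
            n_moves += 1
    w = np.exp(log_w - log_w.max())
    return np.sum(w * mu) / w.sum(), n_moves

rng = np.random.default_rng(1)
y = 3.14 + rng.standard_normal(30)
post_mean, n_moves = ibis_gaussian_mean(y)
exact_mean = y.sum() / (1 + len(y))              # conjugate posterior mean
print(post_mean, exact_mean, n_moves)
```

Because the MCMC moves run only when the ESS collapses, `n_moves` is much smaller than the number of observations.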
Figure 13.12: Illustration of IBIS applied to 30 samples from $\mathcal{N}(\mu = 3.14, \sigma = 1)$. (a) Posterior approximation
after t = 1 and t = 29 observations. (b) Effective sample size over time. The sudden jumps up occur whenever
resampling is triggered, which happens when the ESS drops below 500. Generated by smc_ibis_1d.ipynb.


13.6.3.1 Example: IBIS for a 1d Gaussian 


In this section, we give a simple example of IBIS applied to data from a 1d Gaussian, $y_t \sim \mathcal{N}(\mu = 3.14, \sigma = 1)$
for t = 1 : 30. The unknowns are θ = (μ, σ). The prior is $p(\theta) = \mathcal{N}(\mu|0, 1)\,\mathrm{Ga}(\sigma|a = 1, b = 1)$. We use IBIS with an adaptive RWMH kernel. We use N = 20 particles, each updated for
K = 50 MCMC steps, so we collect 1000 samples per time step.

Figure 13.12a shows the approximate posterior after t = 1 and t = 29 time steps. We see that the
posterior concentrates on the true values of μ = 3.14 and σ = 1.

Figure 13.12b plots the ESS vs time. The number of particles is 1000, and resampling (and MCMC
moves) is triggered whenever this drops below 500. We see that we only need to invoke MCMC
updates 3 times.


13.6.4 Sampling rare events and extrema


Suppose we want to sample values from $\pi_0(\theta)$ conditioned on the event that $S(\theta) > \lambda^*$, where S
is some score or “fitness” function. If λ* is in the tail of the score distribution, this corresponds to
sampling a rare event, which can be hard.


One approach is to use SMC to sample from a sequence of distributions with gradually increasing
thresholds:

$$\tilde{\gamma}_t(\theta) = \mathbb{I}(S(\theta) > \lambda_t)\, \pi_0(\theta) \qquad (13.77)$$

with $\lambda_0 < \cdots < \lambda_T = \lambda^*$. We can then use likelihood tempering, where the “likelihood” is the
function

$$G_t(\theta_t) = \mathbb{I}(S(\theta_t) > \lambda_t) \qquad (13.78)$$


We can use SMC to generate samples from the final distribution $\pi_T$. We may also be interested in
estimating

$$Z_T = p(S(\theta) > \lambda_T) \qquad (13.79)$$


where the probability is taken wrt $\pi_0(\theta)$.

We can adaptively set the thresholds $\lambda_t$ as follows: at each step, sort the samples by their score,
and set $\lambda_t$ to the α'th highest quantile. For example, if we set α = 0.5, we keep the top 50% fittest
particles. This ensures the ESS equals the minimum threshold at each step. For details, see [Cér+12].

Note that this method is very similar to the cross-entropy method (Section 6.9.5). The difference 
is that CEM fits a parametric distribution (e.g., a Gaussian) to the particles at each step and samples 
from that, rather than using a Markov kernel. 
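The adaptive-threshold scheme can be sketched on a case with a known answer: θ ∼ N(0, 1), S(θ) = θ and λ* = 4, so $Z_T = P(\theta > 4) \approx 3.2 \times 10^{-5}$. At each level the sketch:

1. keeps the top half of the particles (α = 0.5);
2. multiplies the running estimate of $Z_T$ by the surviving fraction;
3. rejuvenates the particles with RWMH moves restricted to $S(\theta) > \lambda_t$.

This is a sketch under these assumptions (essentially adaptive multilevel splitting), not the exact algorithm of [Cér+12].

```python
import numpy as np
from math import erfc, log, sqrt

def rare_event_smc(score, threshold, n=2000, keep=0.5, n_mh=20, step=0.5, seed=0):
    """Estimate log P(S(theta) > threshold) for theta ~ N(0, 1) using adaptive levels."""
    rng = np.random.default_rng(seed)
    theta = rng.standard_normal(n)                     # samples from pi_0 = N(0, 1)
    log_z = 0.0
    while True:
        s = score(theta)
        lam = np.quantile(s, 1.0 - keep)               # keep the top `keep` fraction
        if lam >= threshold:                           # final level reached
            return log_z + np.log(np.mean(s > threshold))
        survivors = theta[s > lam]
        log_z += np.log(survivors.size / n)            # Z_t / Z_{t-1} is about `keep`
        theta = survivors[rng.integers(survivors.size, size=n)]
        for _ in range(n_mh):                          # MH on pi_0 restricted to S > lam
            prop = theta + step * rng.standard_normal(n)
            log_ratio = -0.5 * (prop**2 - theta**2)
            accept = (np.log(rng.uniform(size=n)) < log_ratio) & (score(prop) > lam)
            theta = np.where(accept, prop, theta)

log_p = rare_event_smc(lambda th: th, threshold=4.0)
exact = log(0.5 * erfc(4.0 / sqrt(2.0)))               # log P(theta > 4), about -10.36
print(log_p, exact)
```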


13.6.5 SMC-ABC and likelihood-free inference 


The term likelihood-free inference refers to estimating the parameters θ of a black box from
which we can sample data, $y \sim p(\cdot|\theta)$, but where we cannot evaluate $p(y|\theta)$ pointwise. Such models
are called simulators, so this approach to inference is also called simulation-based inference (see
e.g., [Nea+08; CBL20; Gou+96]). These models are also called implicit models (see Section 26.1).

If we want to approximate the posterior of a model with no known likelihood, we can use
Approximate Bayesian Computation or ABC (see e.g., [Bea19; SFB18; Gut+14; Pes+21]).
In this setting, we sample both parameters θ and synthetic data y such that the synthetic data
(generated from θ) is sufficiently close to the observed data y*, as judged by some distance score,
$d(y, y^*) < \epsilon$. (For high dimensional problems, we typically require $d(s(y), s(y^*)) < \epsilon$, where s(y) is
a low-dimensional summary statistic of the data.)

In SMC-ABC, we gradually decrease the discrepancy ε to get a series of distributions as follows:

$$\pi_t(\theta, y) = \frac{1}{Z_t}\, \pi_0(\theta)\, p(y|\theta)\, \mathbb{I}(d(y, y^*) < \epsilon_t) \qquad (13.80)$$

where $\epsilon_0 > \epsilon_1 > \cdots$. This is similar to the rare event SMC samplers in Section 13.6.4, except that
we can't directly evaluate the quality of a candidate θ; instead, we must first convert it to data space
and make the comparison there. For details, see [DMDJ12].

Although SMC-ABC is popular in some fields, such as genetics and epidemiology, this method 
is quite slow and does not scale to high dimensional problems. In such settings, a more efficient 
approach is to train a generative model to emulate the simulator; if this model is parametric with a 
tractable likelihood (e.g., a flow model), we can use the usual methods for posterior inference of its 
parameters (including gradient based methods like HMC). See e.g., [Bre+20a] for details. 


13.6.6 SMC²


We have seen how SMC can be a useful alternative to MCMC. However, it requires that we can
efficiently evaluate the likelihood ratio terms $\tilde{\gamma}_t(\theta_t)/\tilde{\gamma}_{t-1}(\theta_t)$. In cases where this is not possible (e.g., for
latent variable models), we can use SMC (specifically the estimate $\hat{Z}_t$ in Equation (13.10)) as a
subroutine to approximate these likelihoods. This is called SMC². For details, see [CP20b, Ch. 18].

13.6.7 Variational filtering SMC 


One way to improve SMC is to learn a proposal distribution (e.g., using a neural network) such that
the approximate posterior, $\hat{\pi}_T(z_{1:T}; \phi, \theta)$, is close to the target posterior, $\pi_T(z_{1:T}; \theta)$, where θ are
the model parameters, and φ are the proposal parameters (which may depend on θ). One can show
[Nae+18] that the KL divergence between these distributions can be bounded as follows:


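$$0 \le D_{\mathbb{KL}}\left(\mathbb{E}[\hat{\pi}_T(z_{1:T})] \,\|\, \pi_T(z_{1:T})\right) \le -\mathbb{E}\left[\log \frac{\hat{Z}_T(\theta, \phi)}{Z_T(\theta)}\right] \qquad (13.81)$$
where

$$Z_T(\theta) = p_\theta(y_{1:T}) = \int p_\theta(z_{1:T}, y_{1:T})\, dz_{1:T} \qquad (13.82)$$
Hence

$$\mathbb{E}\left[\log \hat{Z}_T(\theta, \phi)\right] \le \log \mathbb{E}\left[\hat{Z}_T(\theta, \phi)\right] = \log Z_T(\theta) \qquad (13.83)$$

Thus we can use SMC sampling to compute an unbiased approximation to $\mathbb{E}[\log \hat{Z}_T(\theta, \phi)]$, which is
a lower bound on the evidence (log marginal likelihood).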

We can now maximize this lower bound wrt θ and φ using SGD, as a way to learn both proposals
and the model. Unfortunately, computing the gradient of the bound is tricky, since the resampling 
step is non-differentiable. However, in practice one can ignore the dependence of the resampling 
operator on the parameters, or one can use differentiable approximations (see e.g., [Ros+22]). This 
overall approach was independently proposed in several papers: the FIVO (filtering variational 
objective) paper [Mad+17], the variational SMC paper [Nae+18] and the auto-encoding SMC 
paper [Le+18]. 


13.6.8 Variational smoothing SMC


The methods in Section 13.6.7 use SMC in which the target distributions are defined to be the
filtered distributions, $\pi_t(z_{1:t}) = p_\theta(z_{1:t}|y_{1:t})$; this is called filtering SMC. Unfortunately, this can
work poorly when fitting models to offline sequence data, since at time t, all future observations are
ignored in the objective, no matter how good the proposal. This can create situations where future
observations are unlikely given the current set of sampled trajectories, which can result in particle
impoverishment and high variance in the estimate of the lower bound.

Recently, [Law+22] proposed a method called SIXO (smoothing inference with twisted objectives), which
uses the smoothing distributions as targets, $\pi_t(z_{1:t}) = p_\theta(z_{1:t}|y_{1:T})$, to create a
much lower variance variational lower bound. Of course, it is impossible to compute this
posterior directly, but we can approximate it using twisted particle filters [WL14a; AL+16]. In this
approach, we approximate the (unnormalized) posterior using

$$p_\theta(z_{1:t}, y_{1:T}) = p_\theta(z_{1:t}, y_{1:t})\, p_\theta(y_{t+1:T}|z_{1:t}, y_{1:t}) \qquad (13.84)$$
$$= p_\theta(z_{1:t}, y_{1:t})\, p_\theta(y_{t+1:T}|z_t) \qquad (13.85)$$
$$\approx p_\theta(z_{1:t}, y_{1:t})\, r_\psi(y_{t+1:T}, z_t) \qquad (13.86)$$

where the second line uses the Markov property of the SSM, and $r_\psi(y_{t+1:T}, z_t) \approx p_\theta(y_{t+1:T}|z_t)$ is the twisting function, which acts as a “lookahead function”.
One way to approximate the twisting function is to note that

$$p_\theta(y_{t+1:T}|z_t) = \frac{p_\theta(z_t|y_{t+1:T})\, p_\theta(y_{t+1:T})}{p_\theta(z_t)} \propto \frac{p_\theta(z_t|y_{t+1:T})}{p_\theta(z_t)} \qquad (13.87)$$

where we drop terms that are independent of $z_t$, since such terms will cancel out when we normalize
the sampling weights. We can approximate the density ratio using the binary classifier method of
Section 2.7.5. To do this, we define one distribution to be $p_1 = p_\theta(z_t, y_{t+1:T})$ and the other to be
$p_2 = p_\theta(z_t)\, p_\theta(y_{t+1:T})$, so that $p_1/p_2 = p_\theta(z_t|y_{t+1:T})/p_\theta(z_t)$. We can easily draw a sample $(z_{1:T}, y_{1:T}) \sim p_\theta$
using ancestral sampling, from which we can compute $(z_t, y_{t+1:T}) \sim p_1$ by marginalization. We can
also sample a fresh sequence $(\tilde{z}_{1:T}, \tilde{y}_{1:T}) \sim p_\theta$, from which we can compute $(\tilde{z}_t, y_{t+1:T}) \sim p_2$
by marginalization. We then use $(z_t, y_{t+1:T})$ as a positive example and $(\tilde{z}_t, y_{t+1:T})$ as a negative
example when training the binary classifier, $r_\psi(y_{t+1:T}, z_t)$.

Once we have updated the twisting parameters ψ, we can rerun SMC to get a tighter lower bound
on the log marginal likelihood, which we can then optimize wrt the model parameters θ and proposal
parameters φ. Thus the overall method is a stochastic variational EM-like method for optimizing
the bound

$$\mathcal{L}_{\mathrm{SIXO}}(\theta, \phi, \psi, y_{1:T}) = \mathbb{E}\left[\log \hat{Z}_{\mathrm{SIXO}}(\theta, \phi, \psi, y_{1:T})\right] \qquad (13.88)$$

$$\le \log \mathbb{E}\left[\hat{Z}_{\mathrm{SIXO}}(\theta, \phi, \psi, y_{1:T})\right] = \log p_\theta(y_{1:T}) \qquad (13.89)$$


In [Law+22] they prove the following: suppose the true model p* is an SSM in which the optimal
proposal function for the model satisfies $p^*(z_t|z_{1:t-1}, y_{1:T}) \in \mathcal{Q}$, and the optimal lookahead function
for the model satisfies $p^*(y_{t+1:T}|z_t) \in \mathcal{R}$. Furthermore, assume the SIXO objective has a unique
maximizer. Then, at the optimum, the learned proposal $q_{\phi^*}(z_t|z_{1:t-1}, y_{1:T}) \in \mathcal{Q}$ is
equal to the optimal proposal, the learned twisting function $r_{\psi^*}(y_{t+1:T}, z_t) \in \mathcal{R}$ is equal to the
optimal lookahead, and the lower bound is tight (i.e., $\mathcal{L}_{\mathrm{SIXO}}(\theta^*, \phi^*, \psi^*) = \log p^*(y_{1:T})$) for any number
of samples $N \ge 1$ and for any kind of SSM p*. (This is in contrast to the FIVO bound, where the
bound does not usually become tight.)
PART III


Prediction 


14 Predictive models: an overview


14.1 Introduction 


The vast majority of machine learning is concerned with tackling a single problem, namely learning
to predict outputs y from inputs x using some function f that is estimated from a labeled training
set $\mathcal{D} = \{(x_n, y_n) : n = 1:N\}$, for $x_n \in \mathcal{X} \subseteq \mathbb{R}^D$ and $y_n \in \mathcal{Y} \subseteq \mathbb{R}^C$. We can model our uncertainty
about the correct output for a given input using a conditional probability model of the form $p(y|f(x))$.
When Y is a discrete set of labels, this is called (in the ML literature) a discriminative model, 
since it lets us discriminate (distinguish) between the different possible values of y. If the output 
is real-valued, Y = R, this is called a regression model. (In the statistics literature, the term 
“regression model” is used in both cases, even if Y is a discrete set.) We will use the more generic 
term “predictive model” to refer to such models. 

A predictive model can be considered as a special case of a conditional generative model (discussed 
in Chapter 20). In a predictive model, the output is usually low dimensional, and there is a single 
best answer that we want to predict. However, in most generative models, the output is usually high 
dimensional, such as images or sentences, and there may be many correct outputs for any given input. 
We will discuss a variety of types of predictive model in Section 14.1.1, but we defer the details to 
subsequent chapters. The rest of this chapter then discusses issues that are relevant to all types of 
predictive model, regardless of the specific form, such as evaluation. 


14.1.1 Types of model 


There are many different kinds of predictive model p(y|x). The biggest distinction is between 
parametric models, that have a fixed number of parameters independent of the size of the training 
set, and non-parametric models that have a variable number of parameters that grows with the 
size of the training set. Non-parametric models are usually more flexible, but can be slower to use 
for prediction. Parametric models are usually less flexible, but are faster to use for prediction. 

Most non-parametric models are based on comparing a test input x to some or all of the stored
training examples $\{x_n, n = 1:N\}$, using some form of similarity, $s_n = \mathcal{K}(x, x_n) \ge 0$, and then
predicting the output using some weighted combination of the training labels, such as $\hat{y} = \sum_{n=1}^N s_n y_n$.
A typical example is a Gaussian process, which we discuss in Chapter 18. Other examples, such as
K-nearest neighbor models, are discussed in the prequel to this book, [Mur22].
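To illustrate the similarity-weighted prediction above, the following sketch uses a Gaussian (RBF) similarity and normalizes the weights $s_n$ to sum to one. This gives the Nadaraya-Watson kernel regression estimator. The kernel, the bandwidth and the synthetic data are assumptions for illustration.

```python
import numpy as np

def kernel_regression(x_test, x_train, y_train, bandwidth=0.3):
    """Nadaraya-Watson estimator with RBF similarity s_n = K(x, x_n)."""
    d2 = (x_test[:, None] - x_train[None, :]) ** 2
    s = np.exp(-0.5 * d2 / bandwidth**2)
    s /= s.sum(axis=1, keepdims=True)        # normalize so the weights sum to one
    return s @ y_train                       # weighted combination of training labels

rng = np.random.default_rng(0)
x_train = np.sort(rng.uniform(-3, 3, 200))
y_train = np.sin(x_train) + 0.1 * rng.standard_normal(200)
x_test = np.linspace(-2, 2, 5)
y_hat = kernel_regression(x_test, x_train, y_train)
print(np.round(y_hat, 3))
```

Note that the model stores all N training points and touches each of them at prediction time. This is the cost of non-parametric flexibility mentioned above.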

Most parametric models have the form $p(y|x) = p(y|f(x; \theta))$, where f is some kind of function
that predicts the parameters (e.g., the mean, or logits) of the output distribution (e.g., Gaussian
or categorical). There are many kinds of function we can use. If f is a linear function of θ (i.e.,
$f(x; \theta) = \theta^\top \phi(x)$ for some fixed feature transformation φ), then the model is called a generalized
linear model or GLM, which we discuss in Chapter 15. If f is a non-linear, but differentiable, function
of θ (e.g., $f(x; \theta) = \theta_2^\top \phi(x; \theta_1)$ for some learnable function $\phi(x; \theta_1)$), then it is common to represent
f using a neural network (Chapter 16). Other types of predictive model, such as decision trees and
random forests, are discussed in the prequel to this book, [Mur22].


14.1.2 Model fitting using ERM, MLE and MAP 


In this section, we briefly discuss some methods used for fitting (parametric) models. The most 
common approach is to use maximum likelihood estimation or MLE, which amounts to solving 
the following optimization problem: 


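$$\hat{\theta} = \operatorname*{argmax}_{\theta \in \Theta} p(\mathcal{D}|\theta) = \operatorname*{argmax}_{\theta \in \Theta} \log p(\mathcal{D}|\theta) \qquad (14.1)$$

If the dataset is N iid data samples, the likelihood decomposes into a product of terms, $p(\mathcal{D}|\theta) = \prod_{n=1}^N p(y_n|x_n, \theta)$. Thus we can instead minimize the following (scaled) negative log likelihood:

$$\hat{\theta} = \operatorname*{argmin}_{\theta \in \Theta} \frac{1}{N} \sum_{n=1}^N \left[-\log p(y_n|x_n, \theta)\right] \qquad (14.2)$$

We can generalize this by replacing the log loss $\ell_n(\theta) = -\log p(y_n|x_n, \theta)$ with a more general
loss function to get

$$\hat{\theta} = \operatorname*{argmin}_{\theta \in \Theta} r(\theta) \qquad (14.3)$$

where r(θ) is the empirical risk

$$r(\theta) = \frac{1}{N} \sum_{n=1}^N \ell_n(\theta) \qquad (14.4)$$

This approach is called empirical risk minimization or ERM.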


ERM can easily result in overfitting, so it is common to add a penalty or regularizer term to get

$$\hat{\theta} = \operatorname*{argmin}_{\theta \in \Theta} r(\theta) + \lambda C(\theta) \qquad (14.5)$$

where λ ≥ 0 controls the degree of regularization, and C(θ) is some complexity measure. If we use
log loss, define $C(\theta) = -\log \pi_0(\theta)$, where $\pi_0(\theta)$ is some prior distribution, and use λ = 1/N
(so that, after multiplying the objective by N, the data term is the full negative log likelihood), we recover the MAP estimate

$$\hat{\theta} = \operatorname*{argmax}_{\theta \in \Theta} \log p(\mathcal{D}|\theta) + \log \pi_0(\theta) \qquad (14.6)$$

This can be solved using standard optimization methods (see Chapter 6).
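To make the ERM/MAP connection concrete, here is a sketch for linear regression. It assumes Gaussian noise with known variance σ² and a Gaussian prior $\mathcal{N}(0, \tau^2 I)$ on the weights. Minimizing the negative log posterior by gradient descent gives the same answer as ridge regression with penalty σ²/τ², which is the closed-form MAP estimate in this model.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D, sigma, tau = 50, 3, 0.5, 1.0
X = rng.standard_normal((N, D))
y = X @ np.array([1.0, -2.0, 0.5]) + sigma * rng.standard_normal(N)

def neg_log_posterior_grad(w):
    # Gradient of -log p(D|w) - log pi_0(w), dropping constants.
    return -X.T @ (y - X @ w) / sigma**2 + w / tau**2

w = np.zeros(D)
for _ in range(5000):                        # plain gradient descent on Equation (14.6)
    w -= 1e-3 * neg_log_posterior_grad(w)

# Closed form: ridge regression with penalty sigma^2 / tau^2.
w_ridge = np.linalg.solve(X.T @ X + (sigma**2 / tau**2) * np.eye(D), X.T @ y)
print(w, w_ridge)
```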
14.1. INTRODUCTION


14.1.3 Model fitting using Bayes, VI and generalized Bayes 


Another way to prevent overfitting is to estimate a probability distribution over parameters, q(θ),
instead of a point estimate. That is, we can try to minimize the empirical risk in expectation:

$$\hat{q} = \operatorname*{argmin}_{q \in \mathcal{P}(\Theta)} \mathbb{E}_{q(\theta)}[r(\theta)] \qquad (14.7)$$


If $\mathcal{P}(\Theta)$ is the space of all probability distributions over parameters, then the solution will be
a delta function that puts all its probability on the minimizer of r(θ) (the MLE, in the case of log loss). Thus this approach, on its own, will
not prevent overfitting. However, we can regularize the problem by preventing the distribution from
moving too far from the prior. If we measure the divergence between q and the prior using KL
divergence, we get


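$$\hat{q} = \operatorname*{argmin}_{q \in \mathcal{P}(\Theta)} \mathbb{E}_{q(\theta)}[r(\theta)] + \frac{1}{\lambda} D_{\mathbb{KL}}(q \,\|\, \pi_0) \qquad (14.8)$$

The solution to this problem is known as the Gibbs posterior, and is given by the following:

$$\hat{q}(\theta) = \frac{e^{-\lambda r(\theta)}\, \pi_0(\theta)}{\int e^{-\lambda r(\theta')}\, \pi_0(\theta')\, d\theta'} \qquad (14.9)$$

This is widely used in the PAC-Bayes community (see e.g., [Alq21]).
Now suppose we use log loss, and set λ = N, to get

$$\hat{q}(\theta) = \frac{e^{\sum_{n=1}^N \log p(y_n|x_n, \theta)}\, \pi_0(\theta)}{\int e^{\sum_{n=1}^N \log p(y_n|x_n, \theta')}\, \pi_0(\theta')\, d\theta'} \qquad (14.10)$$

Then the resulting distribution is equivalent to the Bayes posterior:

$$\hat{q}(\theta) = \frac{p(\mathcal{D}|\theta)\, \pi_0(\theta)}{\int p(\mathcal{D}|\theta')\, \pi_0(\theta')\, d\theta'} \qquad (14.11)$$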


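Often computing the Bayes posterior is intractable. We can simplify the problem by restricting
attention to a limited family of distributions, $\mathcal{Q}(\Theta) \subset \mathcal{P}(\Theta)$. This gives rise to the following objective:

$$\hat{q} = \operatorname*{argmin}_{q \in \mathcal{Q}(\Theta)} \mathbb{E}_{q(\theta)}[-\log p(\mathcal{D}|\theta)] + D_{\mathbb{KL}}(q \,\|\, \pi_0) \qquad (14.12)$$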


This is known as variational inference; see Chapter 10 for details. (See also Section 6.7, where we 
discuss the Bayesian learning rule.) 

We can generalize this by replacing the negative log likelihood with a general risk, r(θ). Furthermore,
we can replace the KL with a general divergence, $D(q \| \pi_0)$, which we can weight using a general λ.
This gives rise to the following objective:

$$\hat{q} = \operatorname*{argmin}_{q \in \mathcal{Q}(\Theta)} \mathbb{E}_{q(\theta)}[r(\theta)] + \lambda D(q \| \pi_0) \qquad (14.13)$$


This is called generalized Bayesian inference [BHW16; KJD19; KJD21]. 
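The claim that the Gibbs posterior with log loss and λ = N equals the Bayes posterior can be checked numerically. The sketch below assumes:

- a Bernoulli likelihood with a uniform prior;
- a small hypothetical dataset;
- a grid approximation to the normalizing integral in Equation (14.9).

The exact Bayes posterior is then Beta(1 + k, 1 + N − k), where k is the number of ones.

```python
import numpy as np
from math import lgamma

y = np.array([1, 0, 1, 1, 0, 1, 1, 1, 0, 1])      # hypothetical coin flips
N, k = len(y), int(y.sum())
theta = np.linspace(0.001, 0.999, 999)           # grid over the Bernoulli parameter
dtheta = theta[1] - theta[0]
log_prior = np.zeros_like(theta)                 # uniform prior pi_0 on [0, 1]

# Empirical risk with log loss: r(theta) = (1/N) * sum_n -log p(y_n | theta).
risk = -(k * np.log(theta) + (N - k) * np.log(1.0 - theta)) / N

def gibbs_posterior(lam):
    """q(theta) proportional to exp(-lam * r(theta)) * pi_0(theta), normalized on the grid."""
    log_q = -lam * risk + log_prior
    q = np.exp(log_q - log_q.max())
    return q / (q.sum() * dtheta)

# Exact Bayes posterior under the uniform prior: Beta(1 + k, 1 + N - k).
a, b = 1 + k, 1 + N - k
log_norm = lgamma(a) + lgamma(b) - lgamma(a + b)
bayes = np.exp((a - 1) * np.log(theta) + (b - 1) * np.log(1.0 - theta) - log_norm)
print(np.max(np.abs(gibbs_posterior(N) - bayes)))   # near 0 when lam = N
```

Larger values of λ concentrate the Gibbs posterior around the empirical risk minimizer, and smaller values keep it closer to the prior.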
14.2 Evaluating predictive models


In this section we discuss how to evaluate the quality of a trained discriminative model. 


14.2.1 Proper scoring rules 


It is common to measure performance of a predictive model using a proper scoring rule [GR07a],
which is defined as follows. Let $S(p_\theta, (y, x))$ be the score for predictive distribution $p_\theta(y|x)$ when
given an event $y|x \sim p^*(y|x)$, where p* is the true conditional distribution. (If we want to evaluate
a Bayesian model, where we marginalize out θ rather than condition on it, we just replace $p_\theta(y|x)$
with $p(y|x) = \int p_\theta(y|x)\, p(\theta|\mathcal{D})\, d\theta$.) The expected score is defined by

$$S(p_\theta, p^*) = \int\!\!\int p^*(x)\, p^*(y|x)\, S(p_\theta, (y, x))\, dy\, dx \qquad (14.14)$$

A proper scoring rule is one where $S(p_\theta, p^*) \le S(p^*, p^*)$, with equality iff $p_\theta(y|x) = p^*(y|x)$. Thus
maximizing such a proper scoring rule will force the model to match the true probabilities.

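The log-likelihood, $S(p_\theta, (y, x)) = \log p_\theta(y|x)$, is a proper scoring rule. This follows from Gibbs'
inequality:

$$S(p_\theta, p^*) = \mathbb{E}_{p^*(x)p^*(y|x)}[\log p_\theta(y|x)] \le \mathbb{E}_{p^*(x)p^*(y|x)}[\log p^*(y|x)] = S(p^*, p^*) \qquad (14.15)$$

Therefore minimizing the NLL (aka log loss) should result in well-calibrated probabilities. However,
in practice, log-loss can over-emphasize tail probabilities [QC+06].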
A common alternative is to use the Brier score [Bri50], which is defined as follows:

$$S(p_\theta, (y, x)) = \frac{1}{C} \sum_{c=1}^C \left(p_\theta(y = c|x) - \mathbb{I}(y = c)\right)^2 \qquad (14.16)$$

This is just the squared error of the predictive distribution $p = p(1:C|x)$ compared to the one-hot
label distribution y. Since it is based on squared error, the Brier score is less sensitive to extremely
rare or extremely common classes. The Brier score is also a proper scoring rule (it is a loss, so its negation satisfies the inequality above).
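Both scoring rules are easy to compute from a matrix of predicted probabilities. The sketch below (function names are our own) implements the log score and the Brier score of Equation (14.16). It then checks propriety numerically for a binary event with true probability p* = 0.7. Over a grid of reported probabilities q, the expected log score is maximized, and the expected Brier score is minimized, at q = p*.

```python
import numpy as np

def log_score(probs, y):
    """Average log p(y_n | x_n) (higher is better)."""
    return np.mean(np.log(probs[np.arange(len(y)), y]))

def brier_score(probs, y):
    """Average of (1/C) * sum_c (p_c - 1[y = c])^2 (lower is better)."""
    C = probs.shape[1]
    onehot = np.eye(C)[y]
    return np.mean(np.sum((probs - onehot) ** 2, axis=1) / C)

probs = np.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]])
y = np.array([0, 1])
print(log_score(probs, y), brier_score(probs, y))

# Propriety check for a binary event with true probability p* = 0.7.
p_star, qs = 0.7, np.linspace(0.01, 0.99, 99)
exp_log = [p_star * np.log(q) + (1 - p_star) * np.log(1 - q) for q in qs]
exp_brier = [p_star * brier_score(np.array([[1 - q, q]]), np.array([1]))
             + (1 - p_star) * brier_score(np.array([[1 - q, q]]), np.array([0]))
             for q in qs]
print(qs[np.argmax(exp_log)], qs[np.argmin(exp_brier)])   # both at q = p*
```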


14.2.2 Calibration


A model whose predicted probabilities match the empirical frequencies is said to be calibrated
[Daw82; NMC05; Guo+17]. For example, if a classifier predicts $p(y = c|x) = 0.9$, then we expect this
to be the true label about 90% of the time. A well-calibrated model is useful to avoid making the
wrong decision when the outcome is too uncertain. In the sections below, we discuss some ways to
measure and improve calibration.


14.2.2.1 Expected calibration error


To assess calibration, we divide the predicted probabilities into a finite set of bins or buckets, and then
assess the discrepancy between the empirical probability and the predicted probability by counting.
More precisely, suppose we have B bins. Let $\mathcal{B}_b$ be the set of indices of samples whose prediction
Figure 14.1: Reliability diagrams for the ResNet CNN image classifier [He+16b] applied to CIFAR-100
dataset. ECE is the expected calibration error, and measures the size of the red gap. Methods from left to 
right: original probabilities; after temperature scaling; after histogram binning; after isotonic regression. From 
Figure 4 of [Guo+17]. Used with kind permission of Chuan Guo. 


confidence falls into the interval $I_b = \left(\frac{b-1}{B}, \frac{b}{B}\right]$. Here we use uniform bin widths, but we could also
define the bins so that we can get an equal number of samples in each one.
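Let $f(x)_c = p(y = c|x)$, $\hat{y}_n = \operatorname{argmax}_{c \in \{1,\ldots,C\}} f(x_n)_c$, and $\hat{p}_n = \max_{c \in \{1,\ldots,C\}} f(x_n)_c$. The
accuracy within bin b is defined as

$$\mathrm{acc}(\mathcal{B}_b) = \frac{1}{|\mathcal{B}_b|} \sum_{n \in \mathcal{B}_b} \mathbb{I}(\hat{y}_n = y_n) \qquad (14.17)$$

The average confidence within this bin is defined as

$$\mathrm{conf}(\mathcal{B}_b) = \frac{1}{|\mathcal{B}_b|} \sum_{n \in \mathcal{B}_b} \hat{p}_n \qquad (14.18)$$

If we plot accuracy vs confidence, we get a reliability diagram, as shown in Figure 14.1. The
gap between the accuracy and confidence is shown in the red bars. We can measure this using the
expected calibration error (ECE) [NCH15]:

$$\mathrm{ECE}(f) = \sum_{b=1}^B \frac{|\mathcal{B}_b|}{N} \left|\mathrm{acc}(\mathcal{B}_b) - \mathrm{conf}(\mathcal{B}_b)\right| \qquad (14.19)$$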

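In the multiclass case, the ECE only looks at the error of the MAP (top label) prediction. We
can extend the metric to look at all the classes using the marginal calibration error, proposed in
[KLM19]:

$$\mathrm{MCE} = \sum_{c=1}^C w_c\, \mathbb{E}\left[\left(p(Y = c|f(x)_c) - f(x)_c\right)^2\right] \qquad (14.20)$$
$$\approx \sum_{c=1}^C w_c \sum_{b=1}^B \frac{|\mathcal{B}_{b,c}|}{N} \left(\mathrm{acc}(\mathcal{B}_{b,c}) - \mathrm{conf}(\mathcal{B}_{b,c})\right)^2 \qquad (14.21)$$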


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 




where $\mathcal{B}_{b,c}$ is the b'th bin for class c, and $w_c \in [0, 1]$ denotes the importance of class c. (We can set
$w_c = 1/C$ if all classes are equally important.) In [Nix+19], they call this metric static calibration
error; they show that certain methods that have good ECE may have poor MCE. Other multi-class
calibration metrics are discussed in [WLZ19].
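The following is a direct implementation of Equations (14.17) to (14.19). It assumes equal-width bins $I_b = ((b-1)/B, b/B]$ and a probability matrix with one row per example. The two synthetic classifiers are hypothetical.

```python
import numpy as np

def expected_calibration_error(probs, y, n_bins=10):
    """ECE of Equation (14.19) with B equal-width bins ((b - 1)/B, b/B]."""
    conf = probs.max(axis=1)                     # p_hat_n
    pred = probs.argmax(axis=1)                  # y_hat_n
    correct = (pred == y).astype(float)
    bins = np.clip(np.ceil(conf * n_bins).astype(int) - 1, 0, n_bins - 1)
    ece = 0.0
    for b in range(n_bins):
        in_bin = bins == b
        if in_bin.any():                         # empty bins contribute nothing
            gap = abs(correct[in_bin].mean() - conf[in_bin].mean())
            ece += in_bin.sum() / len(y) * gap   # |B_b| / N * |acc - conf|
    return ece

# Overconfident: always 90% confident, but right only 70% of the time.
probs = np.tile([0.1, 0.9], (10, 1))
y = np.array([1] * 7 + [0] * 3)
print(expected_calibration_error(probs, y))      # about 0.2

# Calibrated: the top label is correct with probability equal to its confidence.
rng = np.random.default_rng(0)
p = rng.uniform(0.5, 1.0, size=20000)
probs_cal = np.stack([1 - p, p], axis=1)
y_cal = (rng.uniform(size=20000) < p).astype(int)
print(expected_calibration_error(probs_cal, y_cal))   # close to 0
```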


14.2.2.2 Improving calibration 


In principle, training a classifier so it optimizes a proper scoring rule (such as NLL) should automatically
result in a well-calibrated classifier. In practice, however, unbalanced datasets can result
in poorly calibrated predictions. Below we discuss various ways for improving the calibration of
probabilistic classifiers, following [Guo+17].


14.2.2.3 Platt scaling 


Let z be the log-odds, or logit, and p = σ(z), produced by a probabilistic binary classifier. We wish
to convert this to a more calibrated value q. The simplest way to do this is known as Platt scaling,
and was proposed in [Pla00]. The idea is to compute q = σ(az + b), where a and b are estimated via
maximum likelihood on a validation set.

In the multiclass case, we can extend Platt scaling by using matrix scaling: q = softmax(Wz + b),
where we estimate W and b via maximum likelihood on a validation set. Since W has K × K
parameters, where K is the number of classes, this method can easily overfit, so in practice we restrict
W to be diagonal.
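Platt scaling is a logistic regression with a single feature (the logit), so it can be fit with a few Newton steps. The sketch below assumes a validation set of logits z and binary labels y. The synthetic data come from a hypothetical overconfident classifier whose true calibration map is σ(0.5z + 0.3), so the fitted (a, b) should be close to (0.5, 0.3).

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def platt_scaling(z, y, n_iter=50):
    """Fit q = sigmoid(a * z + b) by maximum likelihood with Newton's method."""
    X = np.stack([z, np.ones_like(z)], axis=1)   # features: the logit and a bias column
    params = np.zeros(2)                         # start from q = 0.5 everywhere
    for _ in range(n_iter):
        q = sigmoid(X @ params)
        grad = X.T @ (y - q)                     # gradient of the log likelihood
        hess = (X * (q * (1 - q))[:, None]).T @ X  # negative Hessian
        params += np.linalg.solve(hess, grad)
    return params                                # (a, b)

rng = np.random.default_rng(0)
z = 3.0 * rng.standard_normal(5000)              # hypothetical validation logits
y = (rng.uniform(size=5000) < sigmoid(0.5 * z + 0.3)).astype(float)
a, b = platt_scaling(z, y)
print(a, b)
```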


14.2.2.4 Nonparametric (histogram) methods 


Platt scaling makes a strong assumption about the shape of the calibration curve. A more
flexible, nonparametric, method is to partition the predicted probabilities into bins, $p_m$, and to
estimate an empirical probability $q_m$ for each such bin; we then replace $p_m$ with $q_m$; this is known
as histogram binning [ZE01a]. We can regularize this method by requiring that q = f(p) be a
piecewise constant, monotonically non-decreasing function; this is known as isotonic regression
[ZE01a]. An alternative approach, known as the scaling-binning calibrator, is to apply a scaling
method (such as Platt scaling), and then to apply histogram binning to that. This has the advantage
of using the average of the scaled probabilities in each bin instead of the average of the observed
binary labels (see Figure 14.2). In [KLM19], they prove that this results in better calibration, due to
the lower variance of the estimator.
In the multiclass case, z is the vector of logits, and p = softmax(z) is the vector of probabilities.
We wish to convert this to a better calibrated version, q. [ZE01b] propose to extend histogram
binning and isotonic regression to this case by applying the above binary method to each of the K
one-vs-rest problems, where K is the number of classes. However, this requires K separate calibration
models, and results in an unnormalized probability distribution.


14.2.2.5 Temperature scaling


In [Guo+17], they noticed empirically that the diagonal version of Platt scaling, when applied to
a variety of DNNs, often ended up learning a vector of the form w = (c, c, ..., c), for some constant c.
This suggests a simpler form of scaling, which they call temperature scaling: q = softmax(z/T),


Figure 14.2: Visualization of 3 different approaches to calibrating a binary probabilistic classifier. Black crosses 
are the observed binary labels, red lines are the calibrated outputs. (a) Platt scaling. (b) Histogram binning 
with 3 bins. The output in each bin is the average of the binary labels in each bin. (c) The scaling-binning 
calibrator. This first applies Platt scaling, and then computes the average of the scaled points (gray circles) in 
each bin. From Figure 1 of [KLM19]. Used with kind permission of Ananya Kumar. 
Figure 14.3: Softmax distribution softmax(a/T), where a = (3, 0, 1), at temperatures of T = 100, T = 2
and T = 1. When the temperature is high (left), the distribution is nearly uniform, whereas when the temperature
is low (right), the distribution is “spiky”, with most of its mass on the largest element. Generated by
softmax_plot.ipynb.


where T > 0 is a temperature parameter, which can be estimated by maximum likelihood on the
validation set. Increasing T makes the distribution less peaky (and decreasing it makes the distribution
more peaky), as shown in Figure 14.3. [Guo+17] show empirically that this method produces the lowest
ECE on a variety of DNN classification problems (see Figure 14.1 for a visualization). Furthermore, it
is much simpler and faster than the other methods.

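Note that temperature scaling does not change the identity of the most probable class label, since
dividing all the logits by the same T > 0 preserves their order; hence it has no impact on classification
accuracy. (Binary Platt scaling with a > 0 likewise preserves the ranking of examples, although the
offset b can move the 0.5 decision threshold.) However, these methods do improve calibration
performance. A more recent multi-class calibration method is discussed in [Kul+19].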
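A minimal sketch of fitting T by maximum likelihood, in plain Python. Because the validation NLL is convex in β = 1/T, a one-dimensional golden-section search over β is enough. The search range, iteration count, function names, and the synthetic 3-class data (whose logits are twice as large as they should be) are illustrative assumptions.

```python
import math
import random

def log_softmax(logits):
    # Stable log-softmax via the log-sum-exp trick.
    m = max(logits)
    lse = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - lse for v in logits]

def temperature_nll(T, logit_rows, labels):
    # Average NLL of the labels under q = softmax(z / T).
    total = 0.0
    for z, y in zip(logit_rows, labels):
        total -= log_softmax([v / T for v in z])[y]
    return total / len(labels)

def fit_temperature(logit_rows, labels, t_min=0.05, t_max=20.0, iters=60):
    # The NLL is convex in beta = 1/T, so golden-section search over beta finds its minimum.
    f = lambda beta: temperature_nll(1.0 / beta, logit_rows, labels)
    lo, hi = 1.0 / t_max, 1.0 / t_min
    g = (math.sqrt(5.0) - 1.0) / 2.0
    c, d = hi - g * (hi - lo), lo + g * (hi - lo)
    fc, fd = f(c), f(d)
    for _ in range(iters):
        if fc < fd:          # minimum lies in [lo, d]
            hi, d, fd = d, c, fc
            c = hi - g * (hi - lo)
            fc = f(c)
        else:                # minimum lies in [c, hi]
            lo, c, fc = c, d, fd
            d = lo + g * (hi - lo)
            fd = f(d)
    return 2.0 / (lo + hi)   # T = 1 / beta at the midpoint of the final bracket

# Hypothetical 3-class model whose logits are twice as large as they should be.
rng = random.Random(0)
rows, labels = [], []
for _ in range(1000):
    z = [rng.gauss(0.0, 3.0) for _ in range(3)]
    true_probs = [math.exp(v) for v in log_softmax([v / 2.0 for v in z])]
    rows.append(z)
    labels.append(rng.choices(range(3), weights=true_probs)[0])
T = fit_temperature(rows, labels)
print(f"T = {T:.2f}")  # should come out near 2
```

Since every logit is divided by the same positive T, the arg max of each row, and hence the accuracy, is unchanged; only the confidence is rescaled.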


14.2.2.6 Label smoothing 


When training classifiers, we usually represent the true target label as a one-hot vector, say y = (0, 1, 0)
to represent class 2 out of 3. We can improve results if we “spread” some of the probability mass
across all the classes. For example, we may use y = (0.1, 0.8, 0.1). This is called label smoothing and
often results in better-calibrated models [MKH19].
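One common way to build the smoothed target (an assumption here, since the text does not fix a formula) is to mix the one-hot vector with the uniform distribution, (1 − ε) y + ε/K. With K = 3 and ε = 0.3 this reproduces the example above. The sketch also shows the effect on the loss: against a smoothed target, a very confident prediction scores worse than one that matches the target. Function names are hypothetical.

```python
import math

def smooth_labels(label, num_classes, eps):
    # (1 - eps) * one_hot(label) + eps / K: moves eps of the mass onto the uniform distribution.
    return [(1.0 - eps) * (1.0 if c == label else 0.0) + eps / num_classes
            for c in range(num_classes)]

def cross_entropy(target, probs):
    # Training loss against a (possibly soft) target vector.
    return -sum(t * math.log(p) for t, p in zip(target, probs) if t > 0)

target = smooth_labels(label=1, num_classes=3, eps=0.3)
print(target)  # approximately [0.1, 0.8, 0.1], the example in the text
print(cross_entropy(target, [0.1, 0.8, 0.1]), cross_entropy(target, [0.01, 0.98, 0.01]))
```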


14.2.2.7 Bayesian methods 


Bayesian approaches to fitting classifiers often result in more calibrated predictions, since they 
represent uncertainty in the parameters. See Section 17.3.8 for an example. However, [Ova+19] 
shows that well-calibrated models (even Bayesian ones) often become mis-calibrated when applied to 
inputs that come from a different distribution (see Section 19.2 for details). 


14.2.3 Beyond evaluating marginal probabilities 


Calibration (Section 14.2.2) focuses on assessing properties of the marginal predictive distribution
p(y|x). But this can sometimes be insufficient to distinguish between a good and bad model, especially
in the context of online learning and sequential decision making, as pointed out in [Lu+22; Osb+21;
WSG21; KKG22]. For example, consider two learning agents who observe a sequence of coin tosses.
Let the outcome at time t be $Y_t \sim \mathrm{Ber}(\theta)$, where θ is the unknown parameter. Agent 1 believes
θ = 2/3, whereas agent 2 believes either θ = 0 or θ = 1, but is not sure which, and puts probabilities
1/3 and 2/3 on these events. Thus both agents, despite having different models, make identical
predictions for the next outcome: $p^i(Y_1 = 0) = 1/3$ for agents i = 1, 2. However, the predictions of
the two agents about a sequence of τ future outcomes are very different. In particular, agent 1 predicts
each individual coin toss is a random Bernoulli event, where the probability is due to irreducible
noise or aleatoric uncertainty:


$$p^1(Y_1 = 0, \ldots, Y_\tau = 0) = \left(\frac{1}{3}\right)^{\tau} \tag{14.22}$$


By contrast, agent 2 predicts that the sequence will either be all heads or all tails, where the
probability is induced by epistemic uncertainty about the true parameters:


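$$p^2(Y_1 = y_1, \ldots, Y_\tau = y_\tau) = \begin{cases} 1/3 & \text{if } y_1 = \cdots = y_\tau = 0 \\ 2/3 & \text{if } y_1 = \cdots = y_\tau = 1 \\ 0 & \text{otherwise} \end{cases} \tag{14.23}$$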


The difference in beliefs between these agents will impact their behavior. For example, in a casino,
agent 1 incurs little risk by repeatedly betting on heads in the long run, but for agent 2, this would
be a very unwise strategy, and some initial information gathering (exploration) would be worthwhile.
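To make the contrast concrete, here is a small sketch that computes both agents' joint predictive probabilities for a sequence of binary outcomes (1 = heads). The function names are hypothetical; the numbers come directly from Eqs. (14.22) and (14.23).

```python
from itertools import product

def agent1_joint(seq, theta=2.0 / 3.0):
    # Agent 1: i.i.d. Bernoulli(theta) tosses, so the uncertainty is purely aleatoric.
    p = 1.0
    for y in seq:
        p *= theta if y == 1 else 1.0 - theta
    return p

def agent2_joint(seq, prior=((0.0, 1.0 / 3.0), (1.0, 2.0 / 3.0))):
    # Agent 2: a mixture over theta in {0, 1}, so the uncertainty is epistemic.
    return sum(w * agent1_joint(seq, theta) for theta, w in prior)

tau = 5
print(agent1_joint((0,)), agent2_joint((0,)))              # both 1/3: same marginal
print(agent1_joint((0,) * tau), agent2_joint((0,) * tau))  # (1/3)**5 versus 1/3
print(agent2_joint((0, 1, 0, 0, 0)))                       # 0.0: mixed sequences are impossible
```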

Based on the above, we see that it is useful to evaluate joint predictive distributions when assessing
predictive models. In [Lu+22; Osb+21] they propose to evaluate the posterior predictive distributions
over τ outcomes $y = Y_{T+1:T+\tau}$, given a set of τ inputs $x = X_{T:T+\tau-1}$, and the past T data samples,
$\mathcal{D}_T = \{(X_t, Y_{t+1}) : t = 0, 1, \ldots, T-1\}$. The Bayes optimal predictive distribution is


$$P^B_T = P(y \mid x, \mathcal{D}_T) \tag{14.24}$$


This is usually intractable to compute. Instead, the agent will use an approximate distribution, known
as a belief state, which we denote by

$$Q_T = Q(y \mid x, \mathcal{D}_T) \tag{14.25}$$
14.2. EVALUATING PREDICTIVE MODELS


The natural performance metric is the KL between these distributions. Since this depends on the
inputs x and $\mathcal{D}_T = (X_{0:T-1}, Y_{1:T})$, we will average the KL over these values, which are drawn iid
from the true data generating distribution, which we denote by


$$P(X, Y, \mathcal{E}) = P(X \mid \mathcal{E})\,P(Y \mid X, \mathcal{E})\,P(\mathcal{E}) \tag{14.26}$$


where ℰ is the true but unknown environment. Thus we define our metric as


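$$d^{KL}_{B,Q} = \mathbb{E}_{P(x, \mathcal{D}_T)}\left[D_{KL}\left(P^B_T(y \mid x, \mathcal{D}_T) \,\|\, Q(y \mid x, \mathcal{D}_T)\right)\right] \tag{14.27}$$

where

$$P(x, \mathcal{D}_T, \mathcal{E}) = \underbrace{\prod_{t=0}^{T-1} P(X_t \mid \mathcal{E})\,P(Y_{t+1} \mid X_t, \mathcal{E})}_{P(\mathcal{D}_T \mid \mathcal{E})}\;\underbrace{\prod_{t=T}^{T+\tau-1} P(X_t \mid \mathcal{E})}_{P(x \mid \mathcal{E})}\;P(\mathcal{E}) \tag{14.28}$$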


and $P(x, \mathcal{D}_T)$ marginalizes this over environments.
Unfortunately, it is usually intractable to compute the exact Bayes posterior, $P^B_T$, so we cannot
evaluate $d^{KL}_{B,Q}$. However, in Section 14.2.3.1, we show that


$$d^{KL}_{B,Q} = d^{KL}_{\mathcal{E},Q} - I(\mathcal{E}; y \mid \mathcal{D}_T, x) \tag{14.29}$$


where the second term is a constant wrt the agent, and the first term is given by 


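$$d^{KL}_{\mathcal{E},Q} = \mathbb{E}_{P(x, \mathcal{D}_T, \mathcal{E})}\left[D_{KL}\left(P(y \mid x, \mathcal{E}) \,\|\, Q(y \mid x, \mathcal{D}_T)\right)\right] \tag{14.30}$$

$$= \mathbb{E}_{P(y \mid x, \mathcal{E})\,P(x, \mathcal{D}_T, \mathcal{E})}\left[\log\frac{P(y \mid x, \mathcal{E})}{Q(y \mid x, \mathcal{D}_T)}\right] \tag{14.31}$$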


Hence if we rank agents in terms of $d^{KL}_{\mathcal{E},Q}$, we will get the same results as ranking them by $d^{KL}_{B,Q}$.

To compute $d^{KL}_{\mathcal{E},Q}$ in practice, we can use a Monte Carlo approximation: we just have to sample
J environments, $\mathcal{E}^j \sim P(\mathcal{E})$, sample a training set from each environment, $\mathcal{D}^j_T \sim P(\mathcal{D}_T \mid \mathcal{E}^j)$,
and then sample N data vectors of length τ, $(x^j_n, y^j_n) \sim P(X_{T:T+\tau-1}, Y_{T+1:T+\tau} \mid \mathcal{E}^j)$. We can then
compute


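$$\hat{d}^{KL}_{\mathcal{E},Q} = \frac{1}{JN}\sum_{j=1}^{J}\sum_{n=1}^{N}\left[\log P(y^j_n \mid x^j_n, \mathcal{E}^j) - \log Q(y^j_n \mid x^j_n, \mathcal{D}^j_T)\right] \tag{14.32}$$

where

$$p^j_n = P(y^j_n \mid x^j_n, \mathcal{E}^j) = \prod_{t=T}^{T+\tau-1} P(Y^j_{n,t+1} \mid X^j_{n,t}, \mathcal{E}^j) \tag{14.33}$$

$$q^j_n = Q(y^j_n \mid x^j_n, \mathcal{D}^j_T) = \int Q(y^j_n \mid x^j_n, \theta)\,Q(\theta \mid \mathcal{D}^j_T)\,d\theta \tag{14.34}$$

$$\approx \frac{1}{M}\sum_{m=1}^{M}\prod_{t=T}^{T+\tau-1} Q(Y^j_{n,t+1} \mid X^j_{n,t}, \theta^j_m) \tag{14.35}$$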
where $\theta^j_m \sim Q(\theta \mid \mathcal{D}^j_T)$ is a sample from the agent’s posterior over the environment.
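The estimator in Eqs. (14.32) to (14.35) can be sketched on the coin-tossing example above. Everything here is an illustrative assumption rather than the testbed of [Osb+21]: the environment is a coin bias θ drawn from a uniform Beta(1, 1) prior, there are no inputs x, each environment contributes N = 1 test sequence, and both agents are hypothetical. The Bayes agent uses the correct Beta posterior with M samples as in Eq. (14.35), so its score estimates the constant mutual-information term; the plug-in agent reports the same marginal predictive probability but treats the future tosses as independent.

```python
import math
import random

def log_lik(theta, k, tau):
    # log P(y | theta) for a particular sequence of tau tosses containing k heads.
    return k * math.log(theta) + (tau - k) * math.log1p(-theta)

def log_mean_exp(values):
    # Stable log((1/M) * sum_m exp(v_m)), used for the average in Eq. (14.35).
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values) / len(values))

def plugin_agent(k_train, T, k_test, tau, rng):
    # Point estimate (Laplace smoothed) that treats the tau future tosses as independent.
    theta_hat = (k_train + 1.0) / (T + 2.0)
    return log_lik(theta_hat, k_test, tau)

def bayes_agent(k_train, T, k_test, tau, rng, M=100):
    # Eq. (14.35): average the joint likelihood over M samples from the Beta posterior.
    samples = [log_lik(rng.betavariate(1.0 + k_train, 1.0 + T - k_train), k_test, tau)
               for _ in range(M)]
    return log_mean_exp(samples)

def estimate_d_EQ(agent, J=1000, T=5, tau=20, seed=0):
    # Monte Carlo estimate of Eq. (14.32) with N = 1 test sequence per environment.
    env_rng = random.Random(seed)        # environments and data, shared across agents
    agent_rng = random.Random(seed + 1)  # the agent's own posterior sampling
    total = 0.0
    for _ in range(J):
        theta = env_rng.betavariate(1.0, 1.0)                        # E^j ~ P(E)
        k_train = sum(env_rng.random() < theta for _ in range(T))    # D_T^j (head count)
        k_test = sum(env_rng.random() < theta for _ in range(tau))   # y^j (head count)
        log_p = log_lik(theta, k_test, tau)
        log_q = agent(k_train, T, k_test, tau, agent_rng)
        total += log_p - log_q
    return total / J

d_plugin = estimate_d_EQ(plugin_agent)
d_bayes = estimate_d_EQ(bayes_agent)
print(f"plug-in: {d_plugin:.3f} nats, Bayes: {d_bayes:.3f} nats")
```

For τ = 1 the two agents make identical predictions (the plug-in estimate (k + 1)/(T + 2) equals the Bayes predictive probability), so their scores agree up to Monte Carlo error; only the joint evaluation with τ > 1 separates them. Because both agents see the same environment samples, dropping the log_p term (which gives the NLL of Eq. (14.38)) ranks them in the same order.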

The above assumes that P(Y|X) is known; this will be the case if we use a synthetic data generator,
as in the “neural testbed” in [Osb+21]. If we just have J empirical distributions for P(X, Y),
we can replace the KL with the cross entropy, which only differs by an additive constant:


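$$d^{KL}_{\mathcal{E},Q} = \mathbb{E}_{P(x, \mathcal{D}_T, \mathcal{E})}\left[D_{KL}\left(P(y \mid x, \mathcal{E}) \,\|\, Q(y \mid x, \mathcal{D}_T)\right)\right] \tag{14.36}$$

$$= \underbrace{\mathbb{E}_{P(x, y, \mathcal{E})}\left[\log P(y \mid x, \mathcal{E})\right]}_{\text{const}}\;\underbrace{-\,\mathbb{E}_{P(x, y, \mathcal{D}_T, \mathcal{E})}\left[\log Q(y \mid x, \mathcal{D}_T)\right]}_{d^{CE}_{\mathcal{E},Q}} \tag{14.37}$$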


where the latter term is just the expected negative log likelihood (NLL) of the agent on samples
from the environment. Hence if we rank agents in terms of their NLL or cross entropy $d^{CE}_{\mathcal{E},Q}$, we will
get the same results as ranking them by $d^{KL}_{\mathcal{E},Q}$, which will in turn give the same results as ranking
them by $d^{KL}_{B,Q}$.

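In practice we can approximate the cross entropy as follows:

$$\hat{d}^{CE}_{\mathcal{E},Q} = -\frac{1}{JN}\sum_{j=1}^{J}\sum_{n=1}^{N}\log Q(y^j_n \mid x^j_n, \mathcal{D}^j_T) \tag{14.38}$$

where $\mathcal{D}^j_T \sim P^j$ and $(x^j_n, y^j_n) \sim P^j$, with $P^j$ denoting the j’th empirical distribution.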
An alternative to estimating the KL or NLL is to evaluate the joint predictive accuracy by using it
in a downstream task. In [Osb+21], they show that good predictive accuracy (for τ > 1) correlates
with good performance on a bandit problem (see Section 34.4). In [WSG21] they show that good
predictive accuracy (for τ > 1) results in good performance on a transductive active learning task.


14.2.3.1 Proof of claim
We now prove Equation (14.29), based on [Lu+21a]. First note that


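$$d^{KL}_{\mathcal{E},Q} = \mathbb{E}_{P(x, \mathcal{D}_T, \mathcal{E})\,P(y \mid x, \mathcal{E})}\left[\log\frac{P(y \mid x, \mathcal{E})}{Q(y \mid x, \mathcal{D}_T)}\right] \tag{14.39}$$

$$= \mathbb{E}\left[\log\frac{P(y \mid x, \mathcal{D}_T)}{Q(y \mid x, \mathcal{D}_T)}\right] + \mathbb{E}\left[\log\frac{P(y \mid x, \mathcal{E})}{P(y \mid x, \mathcal{D}_T)}\right] \tag{14.40}$$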
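For the first term in Equation (14.40) we have

$$\mathbb{E}\left[\log\frac{P(y \mid x, \mathcal{D}_T)}{Q(y \mid x, \mathcal{D}_T)}\right] = \sum_{x, y, \mathcal{D}_T} P(x, y, \mathcal{D}_T)\log\frac{P(y \mid x, \mathcal{D}_T)}{Q(y \mid x, \mathcal{D}_T)} \tag{14.41}$$

$$= \sum_{x, \mathcal{D}_T} P(x, \mathcal{D}_T)\sum_{y} P(y \mid x, \mathcal{D}_T)\log\frac{P(y \mid x, \mathcal{D}_T)}{Q(y \mid x, \mathcal{D}_T)} \tag{14.42}$$

$$= \mathbb{E}_{P(x, \mathcal{D}_T)}\left[D_{KL}\left(P(y \mid x, \mathcal{D}_T) \,\|\, Q(y \mid x, \mathcal{D}_T)\right)\right] = d^{KL}_{B,Q} \tag{14.43}$$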


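We now show that the second term in Equation (14.40) reduces to the mutual information. We
exploit the fact that

$$\frac{P(y \mid x, \mathcal{E})}{P(y \mid x, \mathcal{D}_T)} = \frac{P(y \mid \mathcal{D}_T, x, \mathcal{E})}{P(y \mid \mathcal{D}_T, x)} = \frac{P(\mathcal{E}, y \mid \mathcal{D}_T, x)}{P(y \mid \mathcal{D}_T, x)\,P(\mathcal{E} \mid \mathcal{D}_T, x)} \tag{14.44}$$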
14.3. CONFORMAL PREDICTION


Figure 14.4: Prediction set examples on ImageNet. We show three progressively more difficult examples of the
class fox squirrel and the prediction sets generated by conformal prediction. From Figure 1 of [AB21]. Used
with kind permission of Anastasios Angelopoulos.


since $\mathcal{D}_T$ has no new information beyond ℰ. From this we get


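$$\mathbb{E}\left[\log\frac{P(y \mid x, \mathcal{E})}{P(y \mid x, \mathcal{D}_T)}\right] = \mathbb{E}\left[\log\frac{P(\mathcal{E}, y \mid \mathcal{D}_T, x)}{P(y \mid \mathcal{D}_T, x)\,P(\mathcal{E} \mid \mathcal{D}_T, x)}\right] \tag{14.45}$$

$$= \sum_{\mathcal{D}_T, x} P(\mathcal{D}_T, x)\sum_{\mathcal{E}, y} P(\mathcal{E}, y \mid \mathcal{D}_T, x)\log\frac{P(\mathcal{E}, y \mid \mathcal{D}_T, x)}{P(y \mid \mathcal{D}_T, x)\,P(\mathcal{E} \mid \mathcal{D}_T, x)} \tag{14.46}$$

$$= I(\mathcal{E}; y \mid \mathcal{D}_T, x) \tag{14.47}$$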
Hence

$$d^{KL}_{\mathcal{E},Q} = d^{KL}_{B,Q} + I(\mathcal{E}; y \mid \mathcal{D}_T, x) \tag{14.48}$$

as claimed.


14.3 Conformal prediction 


In this section, we briefly discuss conformal prediction [VGS05; SV08; ZFV20; AB21; KSB21;
Man22]. This is a simple but effective way to create prediction intervals or sets with guaranteed
frequentist coverage probability from any predictive method p(y|x). This can be seen as a form
of distribution-free uncertainty quantification, since it works without making assumptions
(beyond exchangeability of the data) about the true data generating process or the form of the
model.¹ Our presentation is based on the excellent tutorial of [AB21].²

In conformal prediction, we start with some heuristic notion of uncertainty, such as the softmax
score for a classification problem, or the variance for a regression problem, and we use it to define
a conformal score $s(x, y) \in \mathbb{R}$, which measures how badly the output y “conforms” to x. (Large


1. The exchangeability assumption rules out time series data, which is serially correlated. However, extensions to 
conformal prediction have been developed for the time series case, see e.g., [Zaf+22]. The exchangeability assumption 
also rules out distribution shift, although this has also been partially addressed. 

2. See also the easy-to-use MAPIE Python library at https://mapie.readthedocs.io/en/latest/index.html, and
the list of papers at https://github.com/valeman/awesome-conformal-prediction.
values of the score are less likely, so it is better to think of it as a non-conformity score.) Next we
apply this score to a calibration set of n labeled examples, which was not used to train f, to get
$S = \{s_i = s(x_i, y_i) : i = 1:n\}$.³ The user specifies a desired error rate α, say 0.1, and we
then compute the (1 − α) quantile $\hat{q}$ of S. (In fact, we should replace 1 − α with $\lceil (n+1)(1-\alpha) \rceil / n$ to
account for the finite size of S.) Finally, given a new test input, $x_{n+1}$, we compute the prediction set
to be


$$\mathcal{T}(x_{n+1}) = \{y : s(x_{n+1}, y) \le \hat{q}\} \tag{14.49}$$


Intuitively, we include all the outputs y that are plausible given the input. See Figure 14.4 for an 
illustration. 
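The recipe above can be sketched in plain Python. The names conformal_quantile, prediction_set and empirical_coverage, and the toy score |x − y|, are hypothetical; the simulation draws i.i.d. (hence exchangeable) scores to check the coverage claim that follows.

```python
import math
import random

def conformal_quantile(scores, alpha):
    # The ceil((n + 1)(1 - alpha))-th smallest calibration score, i.e. the
    # finite-sample corrected (1 - alpha) quantile of S.
    n = len(scores)
    k = math.ceil((n + 1) * (1 - alpha))
    if k > n:
        return math.inf  # too few calibration points: every label must be included
    return sorted(scores)[k - 1]

def prediction_set(score_fn, x, candidate_labels, q_hat):
    # Eq. (14.49): keep every candidate whose non-conformity score is at most q_hat.
    return [y for y in candidate_labels if score_fn(x, y) <= q_hat]

def empirical_coverage(n=20, alpha=0.1, trials=2000, seed=0):
    # Simulate exchangeable (here i.i.d. uniform) scores and count how often
    # the test score s_{n+1} falls at or below q_hat.
    rng = random.Random(seed)
    hits = 0
    for _ in range(trials):
        q_hat = conformal_quantile([rng.random() for _ in range(n)], alpha)
        hits += rng.random() <= q_hat
    return hits / trials

print(conformal_quantile(list(range(1, 21)), alpha=0.1))  # 19 (the 19th smallest of 20)
print(prediction_set(lambda x, y: abs(x - y), 5, range(10), q_hat=1.5))  # [4, 5, 6]
print(empirical_coverage())  # close to 19/21, i.e. at least 1 - alpha = 0.9
```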
Remarkably, one can show the following general result 


$$1 - \alpha \le P^*(y_{n+1} \in \mathcal{T}(x_{n+1})) \le 1 - \alpha + \frac{1}{n+1} \tag{14.50}$$


where the probability is with respect to the true distribution $P^*$ of the calibration data and the test
point $(x_{n+1}, y_{n+1})$. We say that the prediction set has a coverage level of 1 − α. This holds for any
value of n ≥ 1 and α ∈ [0, 1]. The only assumption is that the values $(x_i, y_i)$ are exchangeable, and
hence the calibration scores $s_i$ are also exchangeable.

To see why this is true, let us sort the scores so $s_1 < \cdots < s_n$, so $\hat{q} = s_i$, where $i = \lceil (n+1)(1-\alpha) \rceil$. (We
assume the scores are distinct, for simplicity.) The score $s_{n+1}$ is equally likely to fall anywhere
between the calibration points $s_1, \ldots, s_n$, since the points are exchangeable. Hence

$$P^*(s_{n+1} \le s_k) = \frac{k}{n+1} \tag{14.51}$$


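for any $k \in \{1, \ldots, n+1\}$. The event $\{y_{n+1} \in \mathcal{T}(x_{n+1})\}$ is equivalent to $\{s_{n+1} \le \hat{q}\}$. Hence

$$P^*(y_{n+1} \in \mathcal{T}(x_{n+1})) = P^*(s_{n+1} \le \hat{q}) = \frac{\lceil (n+1)(1-\alpha) \rceil}{n+1} \ge 1 - \alpha \tag{14.52}$$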


For the proof of the upper bound, see [Lei+18].


Although this result may seem like a “free lunch”, it is worth noting that we can always achieve
a desired coverage level by defining the prediction set to be all possible labels. In this case, the
prediction set will be independent of the input, but it will trivially cover the true label at least 1 − α
of the time. To rule out such degenerate cases, we seek prediction sets that are as small as possible
(although we allow for the set to be larger for harder examples), while meeting the coverage
requirement. Achieving this goal requires that we define suitable conformal scores. Below we give
some examples of how to compute conformal scores s(x, y) for different kinds of problems.⁴


14.3.1 Conformalizing classification 


The simplest way to apply conformal prediction to multiclass classification is to derive the conformal
score from the softmax score assigned to the label using $s(x, y) = 1 - f(x)_y$, so large values are


3. Using a calibration set is called split conformal prediction. If we don’t have enough data to adopt this
splitting approach, we can use full conformal prediction [VGS05], which requires fitting the model n times using a
leave-one-out type procedure.

4. It is also possible to learn conformal scores in an end-to-end way, jointly with the predictive model, as discussed in
[Stu+22].
Figure 14.5: (a) Illustration of an adaptive prediction set. From Figure 5 of [AB21]. Used with kind permission
of Anastasios Angelopoulos. (b) Illustration of conformalized quantile regression. From Figure 6 of [AB21].
Used with kind permission of Anastasios Angelopoulos. (c) Illustration of the pinball loss function.


considered less likely than small values. We compute the threshold $\hat{q}$ as described above, and then we
define the prediction set to be $\mathcal{T}(x) = \{y : f(x)_y \ge 1 - \hat{q}\}$, which matches Equation (14.49). That is,
we take the set of all class labels whose softmax score is above the threshold $1 - \hat{q}$, as illustrated in Figure 14.4.
Although the above approach produces prediction sets with the smallest average size (as proved in
[SLW19]), the size of the set tends to be too large for easy examples and too small for hard examples.
We now present an improved method, known as adaptive prediction sets, due to [RSC20], which
addresses this problem. The idea is simple: we sort all the softmax scores, $f(x)_c$ for c = 1:C, in
decreasing order to get a permutation $\pi_{1:C}$, and then we define s(x, y) to be the cumulative sum of the
scores up until we reach label y: $s(x, y) = \sum_{j=1}^{k} f(x)_{\pi_j}$, where k is the position of y in the sorted
order (i.e., $\pi_k = y$). We now compute $\hat{q}$ as before, and define the prediction set $\mathcal{T}(x)$ to be the set of
all labels, sorted in order of decreasing probability, until we cover $\hat{q}$ of the probability mass. See
Figure 14.5a for an illustration. This uses all the softmax scores output by the model, rather than just
the score of the candidate label, which accounts for its improved performance.
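A minimal sketch of the adaptive prediction set score, assuming the non-randomized variant and defining the set exactly as in Eq. (14.49), i.e. {y : s(x, y) ≤ q̂}. (Reading “until we cover q̂ of the mass” as also including the label that crosses q̂ gives slightly larger, more conservative sets.) The function names and the softmax vector are hypothetical.

```python
def aps_score(probs, y):
    # s(x, y): total softmax mass of the labels ranked at or above y
    # (labels sorted by decreasing probability).
    order = sorted(range(len(probs)), key=lambda c: -probs[c])
    total = 0.0
    for c in order:
        total += probs[c]
        if c == y:
            return total
    raise ValueError("label out of range")

def aps_prediction_set(probs, q_hat):
    # Eq. (14.49) with the APS score: labels whose cumulative mass is at most q_hat,
    # listed in order of decreasing probability.
    order = sorted(range(len(probs)), key=lambda c: -probs[c])
    return [c for c in order if aps_score(probs, c) <= q_hat]

probs = [0.1, 0.6, 0.3]                          # hypothetical softmax output f(x)
print([aps_score(probs, y) for y in range(3)])   # roughly [1.0, 0.6, 0.9]
print(aps_prediction_set(probs, q_hat=0.95))     # [1, 2]
```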


14.3.2 Conformalizing regression 


In this section, we consider conformalized regression problems. Since now $y \in \mathbb{R}$, computing the
prediction set in Equation (14.49) is expensive, so instead we will compute a prediction interval,
specified by a lower and upper bound.


14.3.2.1 Conformalizing quantile regression 


In this section, we use quantile regression to compute the lower and upper bounds. We first fit
a function of the form $t_\gamma(x)$, which predicts the γ quantile of the conditional distribution P(Y|x). For
example, if we set γ = 0.5, we get the median. If we use γ = 0.05 and γ = 0.95, we can get an approximate
90% prediction interval using $[t_{0.05}(x), t_{0.95}(x)]$, as illustrated by the gray lines in Figure 14.5b. To fit
the quantile regression model, we just replace squared loss with the quantile loss, also called the
pinball loss, which is defined as

$$\ell_\gamma(y, \hat{t}) = (y - \hat{t})\,\gamma\,\mathbb{I}(y > \hat{t}) + (\hat{t} - y)(1 - \gamma)\,\mathbb{I}(y < \hat{t}) \tag{14.53}$$

where y is the true output and $\hat{t}$ is the predicted value at quantile γ. See Figure 14.5c for an
illustration.
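Why the pinball loss yields quantiles can be checked numerically: the constant that minimizes the average loss over a sample is the sample's γ quantile. This brute-force sketch (hypothetical names, data 1, ..., 101) is only a check of Eq. (14.53), not a quantile-regression fitter.

```python
def pinball_loss(y, t_hat, gamma):
    # Eq. (14.53): under-predictions cost gamma per unit, over-predictions cost (1 - gamma).
    if y > t_hat:
        return gamma * (y - t_hat)
    return (1.0 - gamma) * (t_hat - y)

def best_constant(data, gamma):
    # Brute-force minimiser of the average pinball loss over the observed values.
    return min(data, key=lambda t: sum(pinball_loss(y, t, gamma) for y in data))

data = list(range(1, 102))       # the values 1, 2, ..., 101
print(best_constant(data, 0.5))  # 51, the median
print(best_constant(data, 0.9))  # 91, the empirical 0.9 quantile
```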
The regression quantiles are only approximately a 90% interval because the model may be
mismatched to the true distribution. However, we can use conformal prediction to fix this. In
particular, let us define the conformal score to be

$$s(x, y) = \max\left(t_{\alpha/2}(x) - y,\; y - t_{1-\alpha/2}(x)\right) \tag{14.54}$$

In other words, s(x, y) is a positive measure of how far the value y is outside the prediction interval,
or is a negative measure if y is inside the prediction interval. We compute $\hat{q}$ as before, and define the
conformal prediction interval to be

$$\mathcal{T}(x) = [t_{\alpha/2}(x) - \hat{q},\; t_{1-\alpha/2}(x) + \hat{q}] \tag{14.55}$$

This makes the quantile regression interval wider if $\hat{q}$ is positive (if the base method was overconfident),
and narrower if $\hat{q}$ is negative (if the base method was underconfident). See Figure 14.5b for an
illustration. This approach is called conformalized quantile regression or CQR [RPC19].
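The CQR calibration step can be sketched as follows, assuming two already-fitted quantile regressors passed in as plain functions (the lower/upper functions and the calibration data below are hypothetical). The base interval in the example is too narrow, so $\hat{q}$ comes out positive and the interval is widened.

```python
import math

def conformal_quantile(scores, alpha):
    # ceil((n + 1)(1 - alpha))-th smallest score; infinite if n is too small.
    n = len(scores)
    k = math.ceil((n + 1) * (1 - alpha))
    return math.inf if k > n else sorted(scores)[k - 1]

def cqr_calibrate(lower, upper, xs, ys, alpha):
    # Eq. (14.54): positive when y lies outside [lower(x), upper(x)], negative inside.
    scores = [max(lower(x) - y, y - upper(x)) for x, y in zip(xs, ys)]
    return conformal_quantile(scores, alpha)

def cqr_interval(lower, upper, x, q_hat):
    # Eq. (14.55): widen (q_hat > 0) or shrink (q_hat < 0) the quantile-regression interval.
    return lower(x) - q_hat, upper(x) + q_hat

# Hypothetical fitted quantile regressors t_{alpha/2}(x) and t_{1-alpha/2}(x).
lower = lambda x: x - 1.0
upper = lambda x: x + 1.0
xs = list(range(20))
ys = [x + (0.2 * i - 2.0) for i, x in enumerate(xs)]  # residuals from -2.0 to 1.8
q_hat = cqr_calibrate(lower, upper, xs, ys, alpha=0.1)
print(q_hat)                                   # about 0.8: the base interval was too narrow
print(cqr_interval(lower, upper, 5.0, q_hat))  # about (3.2, 6.8)
```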


14.3.2.2 Conformalizing predicted variances 


There are many ways to define uncertainty scores u(x), such as the predicted standard deviation,
from which we can derive a prediction interval using

$$\mathcal{T}(x) = [f(x) - u(x)\hat{q},\; f(x) + u(x)\hat{q}] \tag{14.56}$$

Here $\hat{q}$ is derived from the quantiles of the following conformal scores


$$s(x, y) = \frac{|y - f(x)|}{u(x)} \tag{14.57}$$


The interval produced by this method tends to be wider than the one computed by CQR, since it
extends an equal amount above and below the predicted value f(x). In addition, the uncertainty
measure u(x) may not scale properly with x. Nevertheless, this is a simple post-hoc method that
can be applied to many regression methods without needing to retrain them.
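The post-hoc recipe of Eqs. (14.56) and (14.57) as a short sketch; the point predictor f, the uncertainty score u, and the calibration data are hypothetical stand-ins for an already-trained model.

```python
import math

def conformal_quantile(scores, alpha):
    # ceil((n + 1)(1 - alpha))-th smallest score; infinite if n is too small.
    n = len(scores)
    k = math.ceil((n + 1) * (1 - alpha))
    return math.inf if k > n else sorted(scores)[k - 1]

def calibrate_scaled_residuals(f, u, xs, ys, alpha):
    # Eq. (14.57): absolute residual measured in units of the model's own uncertainty u(x).
    return conformal_quantile([abs(y - f(x)) / u(x) for x, y in zip(xs, ys)], alpha)

def scaled_interval(f, u, x, q_hat):
    # Eq. (14.56): symmetric interval of half-width u(x) * q_hat around f(x).
    return f(x) - u(x) * q_hat, f(x) + u(x) * q_hat

# Hypothetical point predictor f and uncertainty score u (e.g. a predicted std. dev.).
f = lambda x: 2.0 * x
u = lambda x: 1.0 + 0.5 * x
xs = [float(i) for i in range(20)]
ys = [f(x) + u(x) * (0.1 * (i + 1)) for i, x in enumerate(xs)]  # scaled residuals 0.1 ... 2.0
q_hat = calibrate_scaled_residuals(f, u, xs, ys, alpha=0.1)
print(q_hat)                              # about 1.9
print(scaled_interval(f, u, 4.0, q_hat))  # about (2.3, 13.7)
```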
15 Generalized linear models


15.1 Introduction 


A generalized linear model or GLM [MN89] is a conditional version of an exponential family 
distribution (Section 2.3). More precisely, the model has the following form: 


$$p(y_n \mid x_n, w, \sigma^2) = \exp\left[\frac{y_n \eta_n - A(\eta_n)}{\sigma^2} + \log h(y_n, \sigma^2)\right] \tag{15.1}$$


where $\eta_n = w^\top x_n$ is the natural parameter for the distribution, $A(\eta_n)$ is the log normalizer, T(y) = y
is the sufficient statistic, and σ² is the dispersion term. Based on the results in Section 2.3.3, we can
show that the mean and variance of the response variable are as follows:


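$$\mu_n \triangleq \mathbb{E}[y_n \mid x_n, w, \sigma^2] = A'(\eta_n) \triangleq \ell^{-1}(\eta_n) \tag{15.2}$$

$$\mathbb{V}[y_n \mid x_n, w, \sigma^2] = A''(\eta_n)\,\sigma^2 \tag{15.3}$$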


We will denote the mapping from the linear inputs to the mean of the output using $\mu_n = \ell^{-1}(\eta_n)$,
where the function ℓ is known as the link function, and $\ell^{-1}$ is known as the mean function. This
relationship is usually written as follows:

$$\ell(\mu_n) = \eta_n = w^\top x_n \tag{15.4}$$


15.1.1 Examples 


In this section, we give some examples of widely used GLMs. 


15.1.1.1 Linear regression 


Recall that linear regression has the form

$$p(y_n \mid x_n, w, \sigma^2) = \frac{1}{\sqrt{2\pi\sigma^2}}\exp\left(-\frac{1}{2\sigma^2}(y_n - w^\top x_n)^2\right) \tag{15.5}$$

Hence

$$\log p(y_n \mid x_n, w, \sigma^2) = -\frac{1}{2\sigma^2}(y_n - \eta_n)^2 - \frac{1}{2}\log(2\pi\sigma^2) \tag{15.6}$$

where $\eta_n = w^\top x_n$. We can write this in GLM form as follows:

$$\log p(y_n \mid x_n, w, \sigma^2) = \frac{y_n \eta_n - \frac{\eta_n^2}{2}}{\sigma^2} - \frac{1}{2}\left(\frac{y_n^2}{\sigma^2} + \log(2\pi\sigma^2)\right) \tag{15.7}$$

We see that $A(\eta_n) = \eta_n^2/2$ and hence

$$\mathbb{E}[y_n] = \eta_n = w^\top x_n \tag{15.8}$$

$$\mathbb{V}[y_n] = \sigma^2 \tag{15.9}$$

See Section 15.2 for details on linear regression.

15.1.1.2 Binomial regression

If the response variable is the number of successes in $N_n$ trials, $y_n \in \{0, \ldots, N_n\}$, we can use
binomial regression, which is defined by

$$p(y_n \mid x_n, N_n, w) = \mathrm{Bin}(y_n \mid \sigma(w^\top x_n), N_n) \tag{15.10}$$

We see that binary logistic regression is the special case when $N_n = 1$.

The log pdf is given by

$$\log p(y_n \mid x_n, N_n, w) = y_n \log \mu_n + (N_n - y_n)\log(1 - \mu_n) + \log\binom{N_n}{y_n} \tag{15.11}$$

$$= y_n \log\left(\frac{\mu_n}{1 - \mu_n}\right) + N_n \log(1 - \mu_n) + \log\binom{N_n}{y_n} \tag{15.12}$$

where $\mu_n = \sigma(\eta_n)$. To rewrite this in GLM form, let us define

$$\eta_n \triangleq \log\left[\frac{\mu_n}{1 - \mu_n}\right] = \log\left[\frac{1}{1 + e^{-w^\top x_n}}\cdot\frac{1 + e^{-w^\top x_n}}{e^{-w^\top x_n}}\right] = \log\frac{1}{e^{-w^\top x_n}} = w^\top x_n \tag{15.13}$$

Hence we can write binomial regression in GLM form as follows

$$\log p(y_n \mid x_n, N_n, w) = y_n \eta_n - A(\eta_n) + h(y_n) \tag{15.14}$$

where $h(y_n) = \log\binom{N_n}{y_n}$ and

$$A(\eta_n) = -N_n \log(1 - \mu_n) = N_n \log(1 + e^{\eta_n}) \tag{15.15}$$

Hence

$$\mathbb{E}[y_n] = \frac{dA}{d\eta_n} = \frac{N_n e^{\eta_n}}{1 + e^{\eta_n}} = \frac{N_n}{1 + e^{-\eta_n}} = N_n \mu_n \tag{15.16}$$

and

$$\mathbb{V}[y_n] = \frac{d^2 A}{d\eta_n^2} = N_n \mu_n (1 - \mu_n) \tag{15.17}$$

See Supplementary Section 15.3 for an example of binomial regression.

15.1.1.3 Poisson regression


If the response variable is an integer count, $y_n \in \{0, 1, \ldots\}$, we can use Poisson regression, which
is defined by

$$p(y_n \mid x_n, w) = \mathrm{Poi}(y_n \mid \exp(w^\top x_n)) \tag{15.18}$$

where

$$\mathrm{Poi}(y \mid \mu) = e^{-\mu}\frac{\mu^y}{y!} \tag{15.19}$$

is the Poisson distribution. Poisson regression is widely used in bio-statistical applications, where
$y_n$ might represent the number of diseases of a given person or place, or the number of reads at a
genomic location in a high-throughput sequencing context (see e.g., [Kua+09]).

The log pdf is given by

$$\log p(y_n \mid x_n, w) = y_n \log \mu_n - \mu_n - \log(y_n!) \tag{15.20}$$

where $\mu_n = \exp(w^\top x_n)$. Hence in GLM form we have

$$\log p(y_n \mid x_n, w) = y_n \eta_n - A(\eta_n) + h(y_n) \tag{15.21}$$

where $\eta_n = \log(\mu_n) = w^\top x_n$, $A(\eta_n) = \mu_n = e^{\eta_n}$, and $h(y_n) = -\log(y_n!)$. Hence

$$\mathbb{E}[y_n] = \frac{dA}{d\eta_n} = e^{\eta_n} = \mu_n \tag{15.22}$$

and

$$\mathbb{V}[y_n] = \frac{d^2 A}{d\eta_n^2} = e^{\eta_n} = \mu_n \tag{15.23}$$
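A quick numerical check of the claim that derivatives of the log normalizer give the mean and variance, using central finite differences on the binomial and Poisson log normalizers. The values N = 10 and η = 0.7 and the helper name are arbitrary illustrative choices.

```python
import math

def num_derivs(A, eta, h=1e-4):
    # Central finite differences for A'(eta) and A''(eta).
    d1 = (A(eta + h) - A(eta - h)) / (2 * h)
    d2 = (A(eta + h) - 2 * A(eta) + A(eta - h)) / (h * h)
    return d1, d2

N = 10                                                  # hypothetical number of trials
A_binomial = lambda eta: N * math.log1p(math.exp(eta))  # Eq. (15.15)
A_poisson = math.exp                                    # A(eta) = e^eta, as in Eq. (15.21)

eta = 0.7
mu_bin = 1.0 / (1.0 + math.exp(-eta))
print(num_derivs(A_binomial, eta), (N * mu_bin, N * mu_bin * (1 - mu_bin)))  # Eqs. (15.16)-(15.17)
print(num_derivs(A_poisson, eta), (math.exp(eta), math.exp(eta)))           # Eqs. (15.22)-(15.23)
```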
15.1.1.4 Zero-inflated Poisson regression

In many forms of count data, the number of observed 0s is larger than what a model might expect,
even after taking into account the predictors. Intuitively, this is because there may be many ways
to produce no outcome. For example, consider predicting sales data for a product. If the sales are
0, does it mean the product is unpopular (so the demand is very low), or was it simply sold out
(implying the demand is high, but exceeds supply)? Similar problems arise in genomics, epidemiology,
etc.

To handle such situations, it is common to use a zero-inflated Poisson or ZIP model. The
likelihood for this model is a mixture of two distributions: a spike at 0, and a standard Poisson.
Formally, we define

$$\mathrm{ZIP}(y \mid \rho, \lambda) = \begin{cases} \rho + (1 - \rho)\exp(-\lambda) & \text{if } y = 0 \\ (1 - \rho)\dfrac{\lambda^y e^{-\lambda}}{y!} & \text{if } y > 0 \end{cases} \tag{15.24}$$

Here ρ is the prior probability of picking the spike, and λ is the rate of the Poisson. We see that
there are two “mechanisms” for generating a 0: either (with probability ρ) we choose the spike, or
(with probability 1 − ρ) we simply generate a zero count just because the rate of the Poisson is so
low. (This latter event has probability $\lambda^0 e^{-\lambda}/0! = e^{-\lambda}$.)
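The ZIP likelihood in Eq. (15.24) is easy to evaluate directly. In this sketch the parameter values ρ = 0.3 and λ = 2.5 are hypothetical; the check confirms that the probabilities sum to one and that the mean is (1 − ρ)λ, since the spike contributes only zeros.

```python
import math

def zip_pmf(y, rho, lam):
    # Eq. (15.24): spike at zero with probability rho, otherwise Poisson(lam).
    poisson = math.exp(-lam) * lam ** y / math.factorial(y)
    return (rho if y == 0 else 0.0) + (1.0 - rho) * poisson

rho, lam = 0.3, 2.5  # hypothetical parameters
print(zip_pmf(0, rho, lam))                              # rho + (1 - rho) * exp(-lam), about 0.357
print(sum(y * zip_pmf(y, rho, lam) for y in range(60)))  # mean (1 - rho) * lam = 1.75
```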
15.1.2 GLMs with non-canonical link functions


We have seen how the mean parameters of the output distribution are given by $\mu = \ell^{-1}(\eta)$, where the
function ℓ is the link function. There are several choices for this function, as we now discuss.

The canonical link function ℓ satisfies the property that θ = ℓ(μ), where θ are the canonical
(natural) parameters. Hence

$$\theta = \ell(\mu) = \ell(\ell^{-1}(\eta)) = \eta \tag{15.25}$$

This is what we have assumed so far. For example, for the Bernoulli distribution, the canonical
parameter is the log-odds η = log(μ/(1 − μ)), which is given by the logit transform

$$\eta = \ell(\mu) = \mathrm{logit}(\mu) = \log\left(\frac{\mu}{1 - \mu}\right) \tag{15.26}$$

The inverse of this is the sigmoid or logistic function

$$\mu = \ell^{-1}(\eta) = \sigma(\eta) = 1/(1 + e^{-\eta}) \tag{15.27}$$

However, we are free to use other kinds of link function. For example, in Section 15.4 we use

$$\eta = \ell(\mu) = \Phi^{-1}(\mu) \tag{15.28}$$

$$\mu = \ell^{-1}(\eta) = \Phi(\eta) \tag{15.29}$$

where Φ is the cdf of the standard normal. This is known as the probit link function.
Another link function that is sometimes used for binary responses is the complementary log-log
function
$$\eta = \ell(\mu) = \log(-\log(1 - \mu)) \tag{15.30}$$

This is used in applications where we either observe 0 events (denoted by y = 0) or one or more
(denoted by y = 1), where events are assumed to be governed by a Poisson distribution with rate λ.
Let E be the number of events. The Poisson assumption means p(E = 0) = exp(−λ) and hence

$$p(y = 0) = (1 - \mu) = p(E = 0) = \exp(-\lambda) \tag{15.31}$$

Thus λ = −log(1 − μ). When λ is a function of covariates, we need to ensure it is positive, so we use
$\lambda = e^{\eta}$, and hence

$$\eta = \log(\lambda) = \log(-\log(1 - \mu)) \tag{15.32}$$
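The complementary log-log link and its inverse in code, with a check of the Poisson derivation above: if events occur at rate λ, then μ = P(at least one event) = 1 − e^{−λ} and the link returns log λ. The helper names and λ = 2 are illustrative.

```python
import math

def cloglog(mu):
    # Eq. (15.30): eta = log(-log(1 - mu)).
    return math.log(-math.log1p(-mu))

def inv_cloglog(eta):
    # Mean function: mu = 1 - exp(-exp(eta)) = P(at least one Poisson(e^eta) event).
    return -math.expm1(-math.exp(eta))

lam = 2.0                          # hypothetical Poisson rate
mu = 1.0 - math.exp(-lam)          # P(E >= 1), from Eq. (15.31)
print(cloglog(mu), math.log(lam))  # both about 0.693, as in Eq. (15.32)
print(inv_cloglog(cloglog(0.25)))  # about 0.25 (round trip)
```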


15.1.3 Maximum likelihood estimation

GLMs can be fit using similar methods to those that we used to fit logistic regression. In particular,
the negative log-likelihood has the following form (ignoring constant terms):

$$\mathrm{NLL}(w) = -\log p(\mathcal{D} \mid w) = -\frac{1}{\sigma^2}\sum_{n=1}^{N} \ell_n \tag{15.33}$$
15.1. INTRODUCTION


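where

$$\ell_n \triangleq \eta_n y_n - A(\eta_n) \tag{15.34}$$

where $\eta_n = w^\top x_n$. For notational simplicity, we will assume σ² = 1.
We can compute the gradient for a single term as follows:

$$g_n \triangleq \frac{\partial \ell_n}{\partial w} = \frac{\partial \ell_n}{\partial \eta_n}\frac{\partial \eta_n}{\partial w} = (y_n - A'(\eta_n))\,x_n = (y_n - \mu_n)\,x_n \tag{15.35}$$

where $\mu_n = f(w^\top x_n)$, and f is the inverse link function that maps from canonical parameters to
mean parameters. (For example, in the case of logistic regression, we have $\mu_n = \sigma(w^\top x_n)$.)

The Hessian is given by

$$H = \frac{\partial^2}{\partial w\,\partial w^\top}\mathrm{NLL}(w) = -\sum_{n=1}^{N}\frac{\partial g_n}{\partial w^\top} \tag{15.36}$$

where

$$\frac{\partial g_n}{\partial w^\top} = \frac{\partial g_n}{\partial \mu_n}\frac{\partial \mu_n}{\partial w^\top} = -x_n f'(w^\top x_n)\,x_n^\top \tag{15.37}$$

Hence

$$H = \sum_{n=1}^{N} f'(\eta_n)\,x_n x_n^\top \tag{15.38}$$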


For example, in the case of logistic regression, $f(\eta_n) = \sigma(\eta_n) = \mu_n$, and $f'(\eta_n) = \mu_n(1 - \mu_n)$. In
general, we see that the Hessian is positive definite (provided the inputs $x_n$ span $\mathbb{R}^D$), since $f'(\eta_n) > 0$;
hence the negative log likelihood is strictly convex, so the MLE for a GLM, if it exists, is unique
(assuming $f'(\eta_n) > 0$ for all n).

For small datasets, we can use the iteratively reweighted least squares or IRLS algorithm,
which is a form of Newton’s method, to compute the MLE (see e.g., [Mur22, Sec 10.2.6]). For
large datasets, we can use SGD. (In practice it is often useful to combine SGD with methods that
automatically tune the step size, such as [Loi+21].)
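IRLS is Newton’s method with each step rewritten as a weighted least squares problem. The sketch below takes the Newton step directly, using the gradient from Eq. (15.35) and the Hessian from Eq. (15.38), for a Poisson GLM (where f = f′ = exp). It is plain Python with a tiny linear solver so it has no dependencies; the function names and the data are hypothetical. The data were chosen so that the MLE is known in closed form.

```python
import math

def solve(A, b):
    # Gaussian elimination with partial pivoting for a small dense system A x = b.
    n = len(b)
    M = [list(row) + [b[i]] for i, row in enumerate(A)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[piv] = M[piv], M[col]
        for r in range(col + 1, n):
            factor = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= factor * M[col][c]
    x = [0.0] * n
    for i in reversed(range(n)):
        x[i] = (M[i][n] - sum(M[i][j] * x[j] for j in range(i + 1, n))) / M[i][i]
    return x

def fit_poisson_glm(X, y, num_iters=50, tol=1e-12):
    # Newton's method for the Poisson GLM: the NLL gradient is sum_n (mu_n - y_n) x_n
    # (minus Eq. (15.35)), and the Hessian is sum_n f'(eta_n) x_n x_n^T (Eq. (15.38)).
    D = len(X[0])
    w = [0.0] * D
    for _ in range(num_iters):
        grad = [0.0] * D
        H = [[0.0] * D for _ in range(D)]
        for xn, yn in zip(X, y):
            mu = math.exp(sum(wi * xi for wi, xi in zip(w, xn)))  # f(eta) = f'(eta) = e^eta
            for i in range(D):
                grad[i] += (mu - yn) * xn[i]
                for j in range(D):
                    H[i][j] += mu * xn[i] * xn[j]
        step = solve(H, grad)  # Newton direction H^{-1} g
        w = [wi - si for wi, si in zip(w, step)]
        if max(abs(s) for s in step) < tol:
            break
    return w

# Hypothetical data: an intercept plus a binary covariate. For this design the MLE satisfies
# exp(w0) = mean of y in group 0 (= 2) and exp(w0 + w1) = mean of y in group 1 (= 6).
X = [[1.0, 0.0]] * 3 + [[1.0, 1.0]] * 3
y = [1, 2, 3, 4, 6, 8]
w_hat = fit_poisson_glm(X, y)
print(w_hat, [math.log(2.0), math.log(3.0)])
```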


15.1.4 Bayesian inference 


To perform Bayesian inference of the parameters, we first need to specify a prior. Choosing a suitable
prior depends on the form of link function. For example, a “flat” or “uninformative” prior on the
offset term $\alpha \in \mathbb{R}$ will not translate to an uninformative prior on the probability scale if we pass α
through a sigmoid, as we discuss in Section 15.3.3.

Once we have chosen the prior, we can compute the posterior using a variety of approximate
inference methods. For small datasets, HMC (Section 12.5) is the easiest to use, since you just need
to write down the log likelihood and log prior, and use autograd to compute derivatives and pass
them to the HMC engine. We give some examples in the following sections. For large datasets,
assumed density filtering (Section 8.9.3), stochastic variational inference (Section 10.3.1), or more
specialized algorithms (e.g., [HAB17]) are the best choice.
15.2 Linear regression


Linear regression is the simplest case of a GLM. We gave a detailed introduction to this model in 
the prequel to this book, [Mur22]. In this section, we discuss this model from a Bayesian perspective. 


15.2.1 Conjugate priors 


In this section, we derive the posterior for the parameters of the model when we use a conjugate prior.
We first consider the case where just w is unknown (so the observation noise variance parameter σ²
is fixed), and then we consider the general case, where both σ² and w are unknown.


15.2.1.1 Noise variance is known 


The conjugate prior for linear regression has the following form:

$$p(w) = \mathcal{N}(w \mid \breve{w}, \breve{\Sigma}) \tag{15.39}$$


We often use $\breve{w} = 0$ as the prior mean and $\breve{\Sigma} = \tau^2 I_D$ as the prior covariance. (We assume the bias
term is included in the weight vector, but often use a much weaker prior for it, since we typically do
not want to regularize the overall mean level of the output.)

To derive the posterior, let us first rewrite the likelihood in terms of an MVN as follows: 


$$\ell(w) = p(\mathcal{D} \mid w, \sigma^2) = \prod_{n=1}^{N}\mathcal{N}(y_n \mid w^\top x_n, \sigma^2) = \mathcal{N}(y \mid Xw, \sigma^2 I_N) \tag{15.40}$$


where $I_N$ is the N × N identity matrix. We can then use Bayes rule for Gaussians (Equation (2.82))
to derive the posterior, which is as follows:


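$$p(w \mid X, y, \sigma^2) \propto \mathcal{N}(w \mid \breve{w}, \breve{\Sigma})\,\mathcal{N}(y \mid Xw, \sigma^2 I_N) = \mathcal{N}(w \mid \hat{w}, \hat{\Sigma}) \tag{15.41}$$

$$\hat{w} \triangleq \hat{\Sigma}\left(\breve{\Sigma}^{-1}\breve{w} + \frac{1}{\sigma^2}X^\top y\right) \tag{15.42}$$

$$\hat{\Sigma} \triangleq \left(\breve{\Sigma}^{-1} + \frac{1}{\sigma^2}X^\top X\right)^{-1} \tag{15.43}$$

where $\hat{w}$ is the posterior mean, and $\hat{\Sigma}$ is the posterior covariance.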


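Now suppose $\breve{w} = 0$ and $\breve{\Sigma} = \tau^2 I$. In this case, the posterior mean becomes

$$\hat{w} = \frac{1}{\sigma^2}\hat{\Sigma}X^\top y = \left(\frac{\sigma^2}{\tau^2}I + X^\top X\right)^{-1}X^\top y \tag{15.44}$$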


If we define $\lambda = \sigma^2/\tau^2$, we see this is equivalent to ridge regression, which optimizes

$$L(w) = \mathrm{RSS}(w) + \frac{\lambda}{2}\|w\|_2^2 \tag{15.45}$$


where RSS is the residual sum of squares:

$$\mathrm{RSS}(w) = \frac{1}{2}\sum_{n=1}^{N}(y_n - w^\top x_n)^2 = \frac{1}{2}\|Xw - y\|_2^2 = \frac{1}{2}(Xw - y)^\top(Xw - y) \tag{15.46}$$
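The equivalence with ridge regression can be checked numerically. This sketch computes the posterior mean from Eq. (15.44) and verifies that it zeroes the gradient of the ridge objective, assuming the scaling ½‖Xw − y‖² + (λ/2)‖w‖² (the scaling under which λ = σ²/τ² matches Eq. (15.44)). For simplicity it puts the same prior on the bias weight, and the data, variances, and function names are hypothetical.

```python
def solve(A, b):
    # Gaussian elimination with partial pivoting for a small dense system A x = b.
    n = len(b)
    M = [list(row) + [b[i]] for i, row in enumerate(A)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[piv] = M[piv], M[col]
        for r in range(col + 1, n):
            factor = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= factor * M[col][c]
    x = [0.0] * n
    for i in reversed(range(n)):
        x[i] = (M[i][n] - sum(M[i][j] * x[j] for j in range(i + 1, n))) / M[i][i]
    return x

def posterior_mean(X, y, sigma2, tau2):
    # Eq. (15.44): (sigma^2 / tau^2 * I + X^T X)^{-1} X^T y, for the prior N(0, tau^2 I).
    D = len(X[0])
    lam = sigma2 / tau2
    A = [[sum(row[i] * row[j] for row in X) + (lam if i == j else 0.0) for j in range(D)]
         for i in range(D)]
    b = [sum(row[i] * yn for row, yn in zip(X, y)) for i in range(D)]
    return solve(A, b)

def ridge_gradient(X, y, w, lam):
    # Gradient of 0.5 * ||Xw - y||^2 + 0.5 * lam * ||w||^2.
    resid = [sum(wi * xi for wi, xi in zip(w, row)) - yn for row, yn in zip(X, y)]
    return [sum(r * row[i] for r, row in zip(resid, X)) + lam * w[i] for i in range(len(w))]

X = [[1.0, 0.0], [1.0, 1.0], [1.0, 2.0], [1.0, 3.0]]  # bias column plus one feature
y = [1.2, 2.9, 5.1, 7.0]
sigma2, tau2 = 1.0, 0.5                               # hypothetical noise and prior variances
w_hat = posterior_mean(X, y, sigma2, tau2)
print(w_hat)
print(ridge_gradient(X, y, w_hat, lam=sigma2 / tau2))  # approximately [0, 0]
```

As τ² grows the penalty vanishes and the posterior mean approaches the least squares solution.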
15.2. LINEAR REGRESSION


15.2.1.2 Noise variance is unknown 


In this section, we assume w and σ² are both unknown. The likelihood is given by

$$\ell(w, \sigma^2) = p(\mathcal{D} \mid w, \sigma^2) \propto (\sigma^2)^{-N/2}\exp\left(-\frac{1}{2\sigma^2}\sum_{n=1}^{N}(y_n - w^\top x_n)^2\right) \tag{15.47}$$


Since the regression weights now depend on σ² in the likelihood, the conjugate prior for w has the
form

$$p(w \mid \sigma^2) = \mathcal{N}(w \mid \breve{w}, \sigma^2\breve{\Sigma}) \tag{15.48}$$


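For the noise variance σ², the conjugate prior is based on the inverse Gamma distribution, which has
the form

$$\mathrm{IG}(\sigma^2 \mid \breve{a}, \breve{b}) = \frac{\breve{b}^{\breve{a}}}{\Gamma(\breve{a})}(\sigma^2)^{-(\breve{a}+1)}\exp\left(-\frac{\breve{b}}{\sigma^2}\right) \tag{15.49}$$

(See Section 2.2.3.4 for more details.) Putting these two together, we find that the joint conjugate
prior is the normal inverse Gamma distribution:

$$\mathrm{NIG}(w, \sigma^2 \mid \breve{w}, \breve{\Sigma}, \breve{a}, \breve{b}) \triangleq \mathcal{N}(w \mid \breve{w}, \sigma^2\breve{\Sigma})\,\mathrm{IG}(\sigma^2 \mid \breve{a}, \breve{b}) \tag{15.50}$$

$$= \frac{\breve{b}^{\breve{a}}}{(2\pi)^{D/2}|\breve{\Sigma}|^{1/2}\Gamma(\breve{a})}(\sigma^2)^{-(\breve{a} + D/2 + 1)}\exp\left(-\frac{(w - \breve{w})^\top\breve{\Sigma}^{-1}(w - \breve{w}) + 2\breve{b}}{2\sigma^2}\right) \tag{15.51}$$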

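This results in the following posterior:

$$p(w, \sigma^2 \mid \mathcal{D}) = \mathrm{NIG}(w, \sigma^2 \mid \hat{w}, \hat{\Sigma}, \hat{a}, \hat{b}) \tag{15.52}$$

$$\hat{w} = \hat{\Sigma}\left(\breve{\Sigma}^{-1}\breve{w} + X^\top y\right) \tag{15.53}$$

$$\hat{\Sigma} = \left(\breve{\Sigma}^{-1} + X^\top X\right)^{-1} \tag{15.54}$$

$$\hat{a} = \breve{a} + N/2 \tag{15.55}$$

$$\hat{b} = \breve{b} + \frac{1}{2}\left(\breve{w}^\top\breve{\Sigma}^{-1}\breve{w} + y^\top y - \hat{w}^\top\hat{\Sigma}^{-1}\hat{w}\right) \tag{15.56}$$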


The expressions for $\hat{w}$ and $\hat{\Sigma}$ are similar to the case where σ² is known. The expression for $\hat{a}$ is also
intuitive, since it just updates the counts. The expression for $\hat{b}$ can be interpreted as follows: it is
the prior sum of squares, $\breve{b}$, plus the empirical sum of squares, $y^\top y$, plus a term due to the error in
the prior on w.

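The posterior marginals are as follows. For the variance, we have

$$p(\sigma^2 \mid \mathcal{D}) = \int p(w \mid \sigma^2, \mathcal{D})\,p(\sigma^2 \mid \mathcal{D})\,dw = \mathrm{IG}(\sigma^2 \mid \hat{a}, \hat{b}) \tag{15.57}$$

For the regression weights, it can be shown that

$$p(w \mid \mathcal{D}) = \int p(w \mid \sigma^2, \mathcal{D})\,p(\sigma^2 \mid \mathcal{D})\,d\sigma^2 = \mathcal{T}\left(w \mid \hat{w}, \frac{\hat{b}}{\hat{a}}\hat{\Sigma}, 2\hat{a}\right) \tag{15.58}$$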
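To make Eqs. (15.53) to (15.56) concrete, here is a sketch specialised to a single feature with no intercept (D = 1, a simplifying assumption), so every matrix is a scalar. It also checks an equivalent form of Eq. (15.56): $\hat{b}$ equals $\breve{b}$ plus half of the residual sum of squares at $\hat{w}$ plus the prior penalty on $\hat{w}$. The data, prior values, and function name are hypothetical.

```python
def nig_posterior_1d(x, y, w0, s0, a0, b0):
    # Eqs. (15.53)-(15.56) with D = 1: s0 plays the role of the prior covariance.
    sxx = sum(xi * xi for xi in x)
    sxy = sum(xi * yi for xi, yi in zip(x, y))
    syy = sum(yi * yi for yi in y)
    s_hat = 1.0 / (1.0 / s0 + sxx)                                   # Eq. (15.54)
    w_hat = s_hat * (w0 / s0 + sxy)                                  # Eq. (15.53)
    a_hat = a0 + len(y) / 2.0                                        # Eq. (15.55)
    b_hat = b0 + 0.5 * (w0 * w0 / s0 + syy - w_hat * w_hat / s_hat)  # Eq. (15.56)
    return w_hat, s_hat, a_hat, b_hat

x = [0.5, 1.0, 1.5, 2.0, 2.5]
y = [1.1, 1.9, 3.2, 3.9, 5.1]  # hypothetical data, roughly y = 2x
w_hat, s_hat, a_hat, b_hat = nig_posterior_1d(x, y, w0=0.0, s0=10.0, a0=1.0, b0=1.0)
rss = sum((yi - w_hat * xi) ** 2 for xi, yi in zip(x, y))
# Equivalent form of Eq. (15.56): b0 + (RSS at w_hat + prior penalty) / 2.
print(b_hat, 1.0 + 0.5 * (rss + (w_hat - 0.0) ** 2 / 10.0))
print("E[sigma^2 | D] =", b_hat / (a_hat - 1.0))  # inverse Gamma mean, valid for a_hat > 1
```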
15.2.1.3 Posterior predictive distribution


In machine learning we usually care more about uncertainty (and accuracy) of our predictions, not 
our parameter estimates. Fortunately, one can derive the posterior predictive distribution in closed 
form. In particular, one can show that, given N’ new test inputs X, we have 


pGXD) = | | WX, wo? p(w, 0?|D)durdo? (15.59) 
= J J N(g|Xw, 07 Ly )NIG(w, 0°| ®, £, @, b)dwdo? (15.60) 
= TWX @, (Iv +X BR),24) (15.61) 


The posterior predictive mean is equivalent to “normal” linear regression, but where we plug in ŵ = E[w|D] instead of the MLE. The posterior predictive variance has two components: (b̂/â)I_{N′} due to the measurement noise, and (b̂/â)X̃Σ̂X̃ᵀ due to the uncertainty in w. This latter term varies depending on how close the test inputs are to the training data. The results are similar to using a Gaussian prior (with fixed σ²), except the predictive distribution is even wider, since we are taking into account uncertainty about σ².
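The sketch below (function name ours) turns Equation (15.61) into numbers: it returns the location, scale matrix and degrees of freedom of the Student predictive, keeping the noise term and the parameter-uncertainty term separate in the code so that each can be inspected. The posterior parameters in the usage lines are made-up values for illustration, not fitted to any dataset.

```python
import numpy as np


def nig_predictive(X_test, w_hat, Sigma_hat, a_hat, b_hat):
    """Student-t posterior predictive of Eq. (15.61).

    Returns (loc, scale_matrix, dof) with
    p(y_test | X_test, D) = T(y_test | loc, scale_matrix, dof).
    """
    n_test = X_test.shape[0]
    ratio = b_hat / a_hat
    loc = X_test @ w_hat                                   # plug in the posterior mean
    noise_part = ratio * np.eye(n_test)                    # measurement noise
    param_part = ratio * X_test @ Sigma_hat @ X_test.T     # uncertainty in w
    return loc, noise_part + param_part, 2.0 * a_hat


# Illustrative posterior parameters (not fitted to any data)
w_hat = np.array([1.0, -1.0])
Sigma_hat = np.array([[0.05, 0.01], [0.01, 0.04]])
a_hat, b_hat = 20.0, 10.0
X_test = np.array([[1.0, 2.0], [0.0, 0.0]])
loc, scale, dof = nig_predictive(X_test, w_hat, Sigma_hat, a_hat, b_hat)
```

For dof > 2, the predictive covariance equals `scale * dof / (dof - 2)`. At the second test input, x̃ = 0, the parameter-uncertainty term vanishes in this example and only the noise term b̂/â remains.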


15.2.2  Uninformative priors 


A common criticism of Bayesian inference is the need to use a prior. This is sometimes thought to “pollute” the inferences one makes from the data. We can minimize the effect of the prior by using an uninformative prior, as we discussed in Section 3.6. Below we discuss various uninformative priors for linear regression.


15.2.2.1 Jeffreys prior


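From Section 3.6.3.1, we know that the Jeffreys prior for the location parameter has the form p(w) ∝ 1, and from Section 3.6.3.2, we know that the Jeffreys prior for the scale factor has the form p(σ) ∝ 1/σ. We can emulate these priors using an improper NIG prior with w̆ = 0, Σ̆ = ∞I, ă = −D/2 and b̆ = 0. The corresponding posterior is given by

$$ p(w,\sigma^2|\mathcal{D}) = \mathrm{NIG}(w,\sigma^2|\hat{w},\hat{\Sigma},\hat{a},\hat{b}) \tag{15.62} $$

$$ \hat{w} = \hat{w}_{\text{mle}} = (X^\top X)^{-1}X^\top y \tag{15.63} $$

$$ \hat{\Sigma} = (X^\top X)^{-1} \triangleq C \tag{15.64} $$

$$ \hat{a} = \frac{\nu}{2} \tag{15.65} $$

$$ \hat{b} = \frac{\nu s^2}{2} \tag{15.66} $$

$$ s^2 \triangleq \frac{\|y - \hat{y}\|^2}{\nu} \tag{15.67} $$

$$ \nu = N - D \tag{15.68} $$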
Hence the posterior distribution of the weights is given by

$$ p(w|\mathcal{D}) = \mathcal{T}(w|\hat{w}, s^2 C, \nu) \tag{15.69} $$

where ŵ is the MLE. The marginals for each weight therefore have the form

$$ p(w_d|\mathcal{D}) = \mathcal{T}(w_d|\hat{w}_d, s^2 C_{dd}, \nu) \tag{15.70} $$
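Under this improper prior, everything reduces to ordinary least-squares quantities. The sketch below (function name ours, data synthetic) computes the parameters of the marginals in Equation (15.70); the square roots of `scale2` are the usual least-squares standard errors, which is the numerical side of the equivalence discussed in Section 15.2.2.2.

```python
import numpy as np


def flat_prior_marginals(X, y):
    """Student-t marginals p(w_d | D) = T(w_d | w_hat_d, s2 * C_dd, nu), Eqs. (15.63)-(15.70)."""
    N, D = X.shape
    C = np.linalg.inv(X.T @ X)           # Eq. (15.64)
    w_hat = C @ (X.T @ y)                # Eq. (15.63): the MLE
    nu = N - D                           # Eq. (15.68)
    resid = y - X @ w_hat
    s2 = (resid @ resid) / nu            # Eq. (15.67)
    scale2 = s2 * np.diag(C)             # squared scale of each marginal
    return w_hat, scale2, nu


rng = np.random.default_rng(3)
X = np.column_stack([np.ones(30), rng.normal(size=30)])   # intercept + one feature
y = X @ np.array([1.0, 2.0]) + rng.normal(size=30)
w_hat, scale2, nu = flat_prior_marginals(X, y)
std_errors = np.sqrt(scale2)
```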


15.2.2.2 Connection to frequentist statistics 


Interestingly, the posterior when using the Jeffreys prior is formally equivalent to the frequentist sampling distribution of the MLE, which has the form

$$ p(\hat{w}_d|\mathcal{D}^*) = \mathcal{T}(\hat{w}_d|w_d, s^2 C_{dd}, \nu) \tag{15.71} $$


where D* = (X, y*) is hypothetical data generated from the true model given the fixed inputs X. In 
books on frequentist statistics, this is more commonly written in the following equivalent way (see 
e.g., [Ric95, p542]): 


$$ \frac{\hat{w}_d - w_d}{s\sqrt{C_{dd}}} \sim t_{N-D} \tag{15.72} $$


The sampling distribution is numerically the same as the posterior distribution in Equation (15.70) because T(w|μ, σ², ν) = T(μ|w, σ², ν). However, it is semantically quite different, since the sampling
distribution does not condition on the observed data, but instead is based on hypothetical data 
drawn from the model. See [BT73, p117] for more discussion of the equivalences between Bayesian 
and frequentist analysis of simple linear models when using uninformative priors. 


15.2.2.3 Zellner’s g-prior 


It is often reasonable to assume an uninformative prior on σ², since that is just a scalar that does not have much influence on the results, but using an uninformative prior for w can be dangerous, since the strength of the prior controls how well regularized the model is, as we know from ridge regression.

A common compromise is to use an NIG prior with ă = −D/2, b̆ = 0 (to ensure p(σ²) ∝ 1) and w̆ = 0 and Σ̆ = g(XᵀX)⁻¹, where g > 0 plays a role analogous to 1/λ in ridge regression. This is called Zellner’s g-prior [Zel86].¹ We see that the prior covariance is proportional to (XᵀX)⁻¹ rather than I; this ensures that the posterior is invariant to scaling of the inputs, e.g., due to a change in the units of measurement [Min00a].


1. Note this prior is conditioned on the inputs X, but not the outputs y; this is totally valid in a conditional 
(discriminative) model, where all calculations are conditioned on X, which is treated like a fixed constant input. 
[Figure 15.1 panels: (a) b ~ Normal(0, 10); (b) b ~ LogNormal(0, 1); (c) posterior predictive. Each panel plots height against weight; reference lines are labeled “embryo” and “world’s tallest man”.]


Figure 15.1: Linear regression for predicting height given weight, y ~ N(α + βx, σ²). (a) Prior predictive samples using a Gaussian prior for β. (b) Prior predictive samples using a log-Gaussian prior for β. (c) Posterior predictive samples using the log-Gaussian prior. The inner shaded band is the 95% credible interval for μ, representing epistemic uncertainty. The outer shaded band is the 95% credible interval for the observations y, which also adds data uncertainty due to σ. Adapted from Figures 4.5 and 4.10 of [McE20]. Generated by linreg_height_weight.ipynb.


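With this prior, the posterior becomes

$$ p(w,\sigma^2|g,\mathcal{D}) = \mathrm{NIG}(w,\sigma^2|w_N, V_N, a_N, b_N) \tag{15.73} $$

$$ V_N = \frac{g}{g+1}(X^\top X)^{-1} \tag{15.74} $$

$$ w_N = \frac{g}{g+1}\hat{w}_{\text{mle}} \tag{15.75} $$

$$ a_N = N/2 \tag{15.76} $$

$$ b_N = \frac{1}{2}\|y - X\hat{w}_{\text{mle}}\|^2 + \frac{1}{2(g+1)}\hat{w}_{\text{mle}}^\top X^\top X\,\hat{w}_{\text{mle}} \tag{15.77} $$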


Various approaches have been proposed for setting g, including cross validation, empirical Bayes 
[Min00a; GF00], hierarchical Bayes [Lia+08], etc. 
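A minimal sketch of the g-prior posterior follows (function name and synthetic data are ours). It writes the first term of b_N as half the residual sum of squares of the MLE, which is what the general NIG update in Equation (15.56) gives for w̆ = 0, b̆ = 0 and Σ̆ = g(XᵀX)⁻¹; the shape a_N is taken directly from Equation (15.76) and is not computed.

```python
import numpy as np


def g_prior_posterior(X, y, g):
    """Zellner g-prior posterior, Eqs. (15.74), (15.75) and (15.77)."""
    XtX = X.T @ X
    w_mle = np.linalg.solve(XtX, X.T @ y)
    shrink = g / (g + 1.0)                        # g/(g+1) shrinks toward 0
    V_N = shrink * np.linalg.inv(XtX)             # Eq. (15.74)
    w_N = shrink * w_mle                          # Eq. (15.75)
    resid = y - X @ w_mle
    b_N = 0.5 * (resid @ resid) + (w_mle @ XtX @ w_mle) / (2.0 * (g + 1.0))  # Eq. (15.77)
    return w_N, V_N, b_N


rng = np.random.default_rng(4)
X = rng.normal(size=(25, 3))
y = X @ np.array([1.0, 0.0, -1.0]) + rng.normal(size=25)
results = {g: g_prior_posterior(X, y, g) for g in (1.0, 10.0, 100.0)}
```

Rescaling the columns of X (say, changing units) rescales w_N inversely but leaves X w_N and b_N unchanged, which is the invariance mentioned above; as g grows, w_N approaches the MLE.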


15.2.3 Informative priors


In many problems, it is possible to use domain knowledge to come up with plausible priors. As an 
example, we consider the problem of predicting the height of a person given their weight. We will 
use a dataset collected from Kalahari foragers by the anthropologist Nancy Howell (this example is 
from the book “Statistical Rethinking” [McE20, p93]). 

Let x_i be the weight (in kg) and y_i be the height (in cm) of the ith person, and let x̄ be the mean of the inputs. The observation model is given by

$$ y_i \sim \mathcal{N}(\mu_i, \sigma^2) \tag{15.78} $$

$$ \mu_i = \alpha + \beta(x_i - \bar{x}) \tag{15.79} $$

We see that the intercept α is the predicted output if x_i = x̄, and the slope β is the predicted change in height per unit change in weight above or below the average weight.
[Figure 15.2 panels (a), (b), (c): height against weight, with the weight axis running from about 6 to 65.]


Figure 15.2: Linear regression for predicting height given weight for the full dataset (including children) using polynomial regression. (a) Posterior fit for linear model with log-Gaussian prior for β₁. (b) Posterior fit for quadratic model with log-Gaussian prior for β₂. (c) Posterior fit for quadratic model with Gaussian prior for β₂. Adapted from Figure 4.11 of [McE20]. Generated by linreg_height_weight.ipynb.


The question is: what priors should we use? To be truly Bayesian, we should set these before looking at the data. A sensible prior for α is the height of a “typical person”, with some spread. We use α ~ N(178, 20), since the author of the “Statistical Rethinking” book from which this example is taken is 178 cm tall. By using a standard deviation of 20, the prior puts 95% probability on the broad range of 178 ± 40.

What about the prior for β? It is tempting to use a vague prior, or weak prior, such as β ~ N(0, 10), which is similar to a flat (uniform) prior, but more concentrated at 0 (a form of mild regularization). To see if this is reasonable, we can compute samples from the prior predictive distribution, i.e., we sample (α_s, β_s) ~ p(α)p(β), and then plot α_s + β_s(x − x̄) for a range of x values, for different samples s = 1 : S. The results are shown in Figure 15.1a. We see that this is not a very sensible prior. For example, it suggests that it is just as likely for the height to decrease with weight as increase with weight, which is not plausible. In addition, it predicts heights which are larger than the world’s tallest person (272 cm) and smaller than the world’s shortest person (an embryo, of size 0).

We can encode the monotonically increasing relationship between weight and height by restricting β to be positive. An easy way to do this is to use a log-normal or log-Gaussian prior. (If β̃ = log(β) is Gaussian, then β = e^β̃ must be positive.) Specifically, we will assume β ~ LN(0, 1). Samples from this prior are shown in Figure 15.1b. This is much more reasonable.
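The prior predictive check described above takes only a few lines. The sketch below is illustrative: the weight grid (30 to 65 kg) and its mean standing in for x̄ are our assumptions, and we read the second argument of N(0, 10) as a standard deviation, as the text does for α ~ N(178, 20). It counts how many prior lines slope downward and how many leave the range from an embryo (0 cm) to the world's tallest person (272 cm).

```python
import numpy as np

rng = np.random.default_rng(0)
S = 1000                                  # number of prior samples
weight = np.linspace(30.0, 65.0, 50)      # illustrative weight grid in kg (assumed)
x_bar = weight.mean()                     # stand-in for the data mean (assumed)

alpha = rng.normal(178.0, 20.0, size=S)                       # alpha ~ N(178, 20), sd = 20
beta_vague = rng.normal(0.0, 10.0, size=S)                    # beta ~ N(0, 10), sd = 10 (assumed)
beta_lognormal = rng.lognormal(mean=0.0, sigma=1.0, size=S)   # beta ~ LN(0, 1)


def prior_predictive_means(alpha, beta):
    """mu_s(x) = alpha_s + beta_s (x - x_bar); one row per prior sample s."""
    return alpha[:, None] + beta[:, None] * (weight[None, :] - x_bar)


def implausible_fraction(mu, lo=0.0, hi=272.0):
    """Fraction of prior lines that go below an embryo (0 cm) or above 272 cm."""
    return np.mean((mu.min(axis=1) < lo) | (mu.max(axis=1) > hi))


mu_vague = prior_predictive_means(alpha, beta_vague)
mu_lognormal = prior_predictive_means(alpha, beta_lognormal)
print("decreasing lines:", np.mean(beta_vague < 0), np.mean(beta_lognormal < 0))
print("implausible lines:", implausible_fraction(mu_vague), implausible_fraction(mu_lognormal))
```

Roughly half of the vague-prior lines slope downward, while none of the log-normal ones can, which is the qualitative difference between panels (a) and (b) of Figure 15.1.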

Finally we must choose a prior over σ. In [McE20] they use σ ~ Unif(0, 50). This ensures that σ is positive, and that the prior predictive distribution for the output is within 100 cm of the average height. However, it is usually easier to specify the expected value for σ than an upper bound. To do this, we can use σ ~ Expon(λ), where λ is the rate. We then set E[σ] = 1/λ to the value of the standard deviation that we expect. For example, we can use the empirical standard deviation of the data.

Since these priors are no longer conjugate, we cannot compute the posterior in closed form. However, 
we can use a variety of approximate inference methods. In this simple example, it suffices to use a 
quadratic (Laplace) approximation (see Section 7.4.3). The results are shown in Figure 15.1c, and 
look sensible. 

So far, we have only considered a subset of the data, corresponding to adults over the age of 18. If we include children, we find that the mapping from weight to height is nonlinear. This is illustrated in Figure 15.2a. We can fix this problem by using polynomial regression. For example, consider a quadratic expansion of the standardized features x_i:

$$ \mu_i = \alpha + \beta_1 x_i + \beta_2 x_i^2 \tag{15.80} $$


If we use a log-Gaussian prior for β₂, we find that the model is too constrained, and it underfits. This is illustrated in Figure 15.2b. The reason is that we need to use an inverted quadratic with a negative coefficient, but since this is disallowed by the prior, the model ends up not using this degree of freedom (we find E[β₂|D] ≈ 0.08). If we use a Gaussian prior on β₂, we avoid this problem, as illustrated in Figure 15.2c.
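The quadratic model in Equation (15.80) is still linear in its parameters, so it only needs a different design matrix. The sketch below (function names ours) builds the columns [1, x, x²] from standardized weights; the weight values are illustrative (they echo the axis of Figure 15.2) and the prior scales are assumptions. It also shows the sign problem: log-normal draws of β₂ are always positive, so they can only produce convex curves.

```python
import numpy as np


def quadratic_design(weight):
    """Columns [1, x, x^2] for Eq. (15.80), where x is the standardized weight."""
    x = (weight - weight.mean()) / weight.std()
    return np.column_stack([np.ones_like(x), x, x**2])


def quadratic_mean(Phi, alpha, beta1, beta2):
    """mu_i = alpha + beta1 * x_i + beta2 * x_i^2."""
    return Phi @ np.array([alpha, beta1, beta2])


weights = np.array([6.2, 20.9, 35.6, 50.3, 65.0])   # illustrative weights in kg
Phi = quadratic_design(weights)
rng = np.random.default_rng(0)
# A log-Gaussian prior only produces beta2 > 0 (convex curves), while a
# Gaussian prior also allows the concave shape the full dataset calls for.
beta2_lognormal = rng.lognormal(0.0, 1.0, size=1000)
beta2_gaussian = rng.normal(0.0, 1.0, size=1000)
mu_concave = quadratic_mean(Phi, alpha=150.0, beta1=20.0, beta2=-5.0)
```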

This example shows that it can be useful to think about the functional form of the mapping from 
inputs to outputs in order to specify sensible priors. 


15.2.4 Spike and slab prior 


It is often useful to be able to select a subset of the input features when performing prediction, either to reduce overfitting, or to improve interpretability of the model. This can be achieved if we ensure that the weight vector w is sparse (i.e., has many zero elements), since if w_d = 0, then x_d plays no role in the inner product wᵀx.

The canonical way to achieve sparsity when using Bayesian inference is to use a spike-and-slab (SS) prior [MB88], which has the form of a two-component mixture model, with one component being a “spike” at 0, and the other being a uniform “slab” between −a and a:

$$ p(w) = \prod_{d=1}^{D}(1-\pi)\,\delta(w_d) + \pi\,\mathrm{Unif}(w_d|-a, a) \tag{15.81} $$

where π is the prior probability that each coefficient is non-zero. The corresponding log prior on the coefficients is thus
$$ \log p(w) = \|w\|_0\log\pi + (D - \|w\|_0)\log(1-\pi) = -\lambda\|w\|_0 + \text{const} \tag{15.82} $$

where λ = log((1 − π)/π) controls the sparsity of the model, and ‖w‖₀ = Σ_{d=1}^{D} I(w_d ≠ 0) is the ℓ₀ norm of the weights. Thus MAP estimation with a spike and slab prior is equivalent to ℓ₀ regularization; this penalizes the number of non-zero coefficients. Interestingly, posterior samples will also be sparse.

By contrast, consider using a Laplace prior. The lasso estimator uses MAP estimation, which results in a sparse estimate. However, posterior samples are not sparse. Interestingly, [EY09] show theoretically (and [SPZ09] confirm experimentally) that using the posterior mean with a spike-and-slab prior also results in better prediction accuracy than using the posterior mode with a Laplace prior.
In practice, we often approximate the uniform slab with a broad Gaussian distribution,

$$ p(w) = \prod_{d}(1-\pi)\,\delta(w_d) + \pi\,\mathcal{N}(w_d|0,\sigma_w^2) \tag{15.83} $$

As σ_w² → ∞, the second term approaches a uniform distribution over [−∞, +∞]. We can implement the mixture model by associating a binary random variable, s_d ~ Ber(π), with each coefficient, to indicate if the coefficient is “on” or “off”.
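The indicator construction can be sketched directly (function name and parameter values are ours). Note that these are draws from the prior, not the posterior; they illustrate the mechanism that makes spike-and-slab draws exactly sparse while Laplace draws are never exactly zero. The last line computes the ℓ₀ penalty weight λ of Equation (15.82) for this π.

```python
import numpy as np


def sample_spike_and_slab(rng, D, pi, sigma_w, size):
    """Draw w from Eq. (15.83) using indicators s_d ~ Ber(pi)."""
    s = rng.random((size, D)) < pi                    # s_d = 1 means "on"
    slab = rng.normal(0.0, sigma_w, size=(size, D))   # slab draws from N(0, sigma_w^2)
    return np.where(s, slab, 0.0)                     # spike: exactly zero when "off"


rng = np.random.default_rng(0)
pi = 0.2
w_ss = sample_spike_and_slab(rng, D=10, pi=pi, sigma_w=3.0, size=5000)
w_laplace = rng.laplace(0.0, 1.0, size=(5000, 10))    # Laplace draws, for contrast

lam = np.log((1 - pi) / pi)                # l0 penalty weight of Eq. (15.82); log(4) here
print("fraction of exact zeros (spike and slab):", np.mean(w_ss == 0.0))
print("fraction of exact zeros (Laplace):", np.mean(w_laplace == 0.0))
```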
Figure 15.3: (a) Representing lasso using a Gaussian scale mixture prior. (b) Graphical model for group lasso with 2 groups, the first has size G₁ = 2, the second has size G₂ = 3.


Unfortunately, MAP estimation (not to mention full Bayesian inference) with such discrete mixture 
priors is computationally difficult. Various approximate inference methods have been proposed, 
including greedy search (see e.g., [SPZ09]) or MCMC (see e.g., [HS09]). 


15.2.5 Laplace prior (Bayesian lasso) 


A computationally cheap way to achieve sparsity is to perform MAP estimation with a Laplace prior by minimizing the penalized negative log likelihood:

$$ \mathrm{PNLL}(w) = -\log p(\mathcal{D}|w) - \log p(w|\lambda) = \|Xw - y\|_2^2 + \lambda\|w\|_1 \tag{15.84} $$

where ‖w‖₁ ≜ Σ_{d=1}^{D} |w_d| is the ℓ₁ norm of w. This method is called lasso, which stands for “least absolute shrinkage and selection operator” [Tib96]. See Section 11.4 of the prequel to this book, [Mur22], for details.
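One standard way to minimize Equation (15.84) is proximal gradient descent (ISTA), which alternates a gradient step on the squared error with elementwise soft thresholding. We use it here purely as an illustration; it is not necessarily the solver used in [Mur22] or the book's notebooks, and the function names and data are ours. The soft-thresholding step is what produces exact zeros in the MAP estimate.

```python
import numpy as np


def soft_threshold(z, t):
    """Proximal operator of t * ||.||_1 (elementwise shrinkage toward 0)."""
    return np.sign(z) * np.maximum(np.abs(z) - t, 0.0)


def lasso_ista(X, y, lam, n_iter=5000):
    """Minimize ||X w - y||_2^2 + lam * ||w||_1 (Eq. 15.84) by proximal gradient (ISTA)."""
    L = 2.0 * np.linalg.norm(X, ord=2) ** 2    # Lipschitz constant of the smooth part's gradient
    w = np.zeros(X.shape[1])
    for _ in range(n_iter):
        grad = 2.0 * X.T @ (X @ w - y)
        w = soft_threshold(w - grad / L, lam / L)
    return w


rng = np.random.default_rng(0)
X = rng.normal(size=(100, 8))
w_true = np.array([3.0, 0.0, 0.0, -2.0, 0.0, 0.0, 1.5, 0.0])
y = X @ w_true + 0.5 * rng.normal(size=100)
w_lasso = lasso_ista(X, y, lam=20.0)
```

With an orthonormal design (X = I), the minimizer has the closed form sign(y_d) max(|y_d| − λ/2, 0), which is a quick sanity check for any lasso solver.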

In this section, we discuss posterior inference with this prior; this is known as the Bayesian lasso [PC08]. In particular, we assume the following prior:

$$ p(w|\sigma^2) = \prod_{j}\frac{\lambda}{2\sqrt{\sigma^2}}e^{-\lambda|w_j|/\sqrt{\sigma^2}} \tag{15.85} $$

(Note that conditioning the prior on σ² is important to ensure that the full posterior is unimodal.)


To simplify inference, we will represent the Laplace prior as a Gaussian scale mixture, which we discussed in Section 28.2.3.2. In particular, one can show that the Laplace distribution is an infinite weighted sum of Gaussians, where the variance comes from a Gamma distribution:

$$ \mathrm{Laplace}(w|0,\lambda) = \int \mathcal{N}(w|0,\tau^2)\,\mathrm{Ga}\left(\tau^2\,\middle|\,1,\frac{\lambda^2}{2}\right)d\tau^2 \tag{15.86} $$

We can therefore represent the Bayesian lasso model as a hierarchical latent variable model, as shown in Figure 15.3a. The corresponding joint distribution has the following form:

$$ p(y, w, \tau, \sigma^2|X) = \mathcal{N}(y|Xw, \sigma^2 I_N)\left[\prod_j \mathcal{N}(w_j|0,\sigma^2\tau_j^2)\,\mathrm{Ga}(\tau_j^2|1,\lambda^2/2)\right]p(\sigma^2) \tag{15.87} $$
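Equation (15.86) can be checked by Monte Carlo: draw τ² from the Gamma mixing distribution, then w given τ² from a zero-mean Gaussian, and compare the resulting w with a Laplace distribution. The sketch assumes Laplace(w|0, λ) has density (λ/2)exp(−λ|w|), which is the parameterization this mixing distribution produces; λ and the sample size are arbitrary choices of ours.

```python
import numpy as np

rng = np.random.default_rng(0)
lam = 2.0
n = 400_000

# tau^2 ~ Ga(shape=1, rate=lam^2/2); NumPy's gamma takes a scale, i.e. 1/rate
tau2 = rng.gamma(shape=1.0, scale=2.0 / lam**2, size=n)
# w | tau^2 ~ N(0, tau^2)
w = rng.normal(0.0, np.sqrt(tau2))

# Under Eq. (15.86) the marginal of w has density (lam/2) exp(-lam |w|), so
# E|w| = 1/lam, Var(w) = 2/lam^2 and P(|w| < t) = 1 - exp(-lam t).
print(np.mean(np.abs(w)), 1 / lam)
print(np.var(w), 2 / lam**2)
print(np.mean(np.abs(w) < 0.25), 1 - np.exp(-lam * 0.25))
```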


We can also create a GSM to match the group lasso prior, which sets multiple coefficients to zero at the same time:

$$ w_g|\sigma^2,\tau_g^2 \sim \mathcal{N}(0, \sigma^2\tau_g^2 I_{d_g}) \tag{15.88} $$

$$ \tau_g^2 \sim \mathrm{Ga}\left(\frac{d_g+1}{2}, \frac{\lambda^2}{2}\right) \tag{15.89} $$

where d_g is the size of group g. So we see that there is one variance term per group, each of which comes from a Gamma prior, whose shape parameter depends on the group size, and whose rate parameter is controlled by λ.

Figure 15.3b gives an example, where we have 2 groups, one of size 2 and one of size 3. This picture makes it clearer why there should be a grouping effect. For example, suppose w_{1,1} is small; then τ₁² will be estimated to be small, which will force w_{1,2} to be small, due to shrinkage (c.f., Section 3.7). Conversely, suppose w_{1,1} is large; then τ₁² will be estimated to be large, which will allow w_{1,2} to become large as well.


Given these hierarchical models, we can easily derive a Gibbs sampling algorithm (Section 12.3) to sample from the posterior (see e.g., [PC08]). Unfortunately, these posterior samples are not sparse, even though the MAP estimate is sparse. This is because the prior puts infinitesimal probability on the event that each coefficient is zero.
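A minimal sketch of such a Gibbs sampler follows, using the full conditionals given in [PC08] for the hierarchy of Equation (15.87). Two choices are ours and not stated in the text: p(σ²) ∝ 1/σ², and no separate intercept. The function name and the synthetic data are hypothetical. Running it lets one confirm numerically that no posterior draw is exactly zero.

```python
import numpy as np


def bayesian_lasso_gibbs(X, y, lam, n_samples=2000, burn_in=500, seed=0):
    """Gibbs sampler for the hierarchy of Eq. (15.87).

    Assumptions (not from the text): p(sigma^2) proportional to 1/sigma^2, no intercept.
    Returns posterior draws of w with shape (n_samples, D).
    """
    rng = np.random.default_rng(seed)
    N, D = X.shape
    XtX, Xty = X.T @ X, X.T @ y
    sigma2, inv_tau2 = 1.0, np.ones(D)
    draws = np.empty((n_samples, D))
    for it in range(burn_in + n_samples):
        # w | rest ~ N(A^{-1} X^T y, sigma^2 A^{-1}), with A = X^T X + diag(1 / tau^2)
        A = XtX + np.diag(inv_tau2)
        chol = np.linalg.cholesky(A)
        mean = np.linalg.solve(A, Xty)
        w = mean + np.sqrt(sigma2) * np.linalg.solve(chol.T, rng.standard_normal(D))
        # sigma^2 | rest ~ IG((N + D)/2, ig_scale); 1/sigma^2 is Gamma with rate ig_scale
        resid = y - X @ w
        ig_scale = 0.5 * (resid @ resid + np.sum(w**2 * inv_tau2))
        sigma2 = 1.0 / rng.gamma(shape=0.5 * (N + D), scale=1.0 / ig_scale)
        # 1/tau_j^2 | rest ~ InverseGaussian(mean = sqrt(lam^2 sigma^2 / w_j^2), shape = lam^2)
        inv_tau2 = rng.wald(np.sqrt(lam**2 * sigma2 / w**2), lam**2)
        if it >= burn_in:
            draws[it - burn_in] = w
    return draws


rng = np.random.default_rng(1)
X = rng.normal(size=(100, 4))
w_true = np.array([3.0, 0.0, -2.0, 0.0])
y = X @ w_true + rng.normal(size=100)
draws = bayesian_lasso_gibbs(X, y, lam=1.0)
print("posterior mean:", draws.mean(axis=0))
print("fraction of exactly-zero draws:", np.mean(draws == 0.0))
```

The posterior mean for the irrelevant coordinates is close to zero, yet every individual draw is non-zero, which is the behavior described above.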


15.2.6 Horseshoe prior


The Laplace prior is not suitable for sparse Bayesian models, because posterior samples are not sparse. The spike and slab prior does not have this problem but is often too slow to use (although see [BRG20]). Fortunately, it is possible to devise continuous priors (without discrete latent variables) that are both sparse and computationally efficient. One popular prior of this type is the horseshoe prior [CPS10], so-named because of the shape of its density function.

In the horseshoe prior, instead of using a Laplace prior for each weight, we use the following Gaussian scale mixture:


$$ w_j \sim \mathcal{N}(0, \lambda_j^2\tau^2) \tag{15.90} $$

$$ \lambda_j \sim \mathcal{C}_+(0,1) \tag{15.91} $$

$$ \tau \sim \mathcal{C}_+(0,1) \tag{15.92} $$

where C₊(0, 1) is the half-Cauchy distribution (Section 2.2.2.4), λ_j is a local shrinkage factor, and τ² is a global shrinkage factor. The Cauchy distribution has very fat tails, so λ_j is likely to be either 0 or very far from 0, which emulates the spike and slab prior, but in a continuous way. For more details, see e.g., [Bha+19].
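Sampling from Equations (15.90) to (15.92) makes the "spike plus heavy tails" shape concrete. In the sketch below (function name and sizes ours), half-Cauchy draws are obtained as absolute values of standard Cauchy draws, and the horseshoe draws are compared with standard Gaussian draws. Unlike spike-and-slab draws, none are exactly zero; they are just very often tiny.

```python
import numpy as np


def sample_horseshoe(rng, D, size):
    """Draw w from Eqs. (15.90)-(15.92); C+(0, 1) draws are |Cauchy(0, 1)| draws."""
    tau = np.abs(rng.standard_cauchy(size=(size, 1)))     # global scale, shared across j
    lam = np.abs(rng.standard_cauchy(size=(size, D)))     # local scales, one per weight
    return rng.normal(size=(size, D)) * lam * tau         # w_j ~ N(0, lam_j^2 tau^2)


rng = np.random.default_rng(0)
w_hs = sample_horseshoe(rng, D=10, size=20000)
w_gauss = rng.normal(size=(20000, 10))                    # N(0, 1) for comparison

for name, w in [("horseshoe", w_hs), ("gaussian", w_gauss)]:
    near_zero = np.mean(np.abs(w) < 0.01)
    far_out = np.mean(np.abs(w) > 10.0)
    print(f"{name}: P(|w| < 0.01) = {near_zero:.3f}, P(|w| > 10) = {far_out:.3f}")
```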


15.2.7 Automatic relevancy determination 


An alternative to using posterior inference with a sparsity promoting prior is to use posterior inference with a Gaussian prior, w_j ~ N(0, 1/α_j), but where we use empirical Bayes to optimize the precisions α_j. That is, we first compute α̂ = argmax_α p(y|X, α), and then compute ŵ = argmax_w p(w|y, X, α̂). Perhaps surprisingly, we will see that this results in a sparse estimate, for reasons we explain in Section 15.2.7.2.

This technique is known as sparse Bayesian learning [Tip01] or automatic relevancy determination (ARD) [Mac95; Nea96]. It has also been called NUV estimation, which stands for “normal prior with unknown variance” [Loe+16]. It was originally developed for neural networks (where sparsity is applied to the first layer weights), but here we apply it to linear models.


15.2.7.1 ARD for linear models 


In this section, we explain ARD in more detail, by applying it to linear regression. The likelihood is p(y|x, w, β) = N(y|wᵀx, 1/β), where β = 1/σ². The prior is p(w) = N(w|0, A⁻¹), where A = diag(α). The marginal likelihood can be computed analytically (using Equation (2.90)) as follows:

$$ p(y|X,\alpha,\beta) = \int \mathcal{N}(y|Xw, (1/\beta)I_N)\,\mathcal{N}(w|0, A^{-1})\,dw \tag{15.93} $$

$$ = \mathcal{N}(y|0, \beta^{-1}I_N + XA^{-1}X^\top) \tag{15.94} $$

$$ = \mathcal{N}(y|0, C_\alpha) \tag{15.95} $$

where C_α ≜ β⁻¹I_N + XA⁻¹Xᵀ. This is very similar to the marginal likelihood under the spike-and-slab prior (Section 15.2.4), which is given by


$$ p(y|X, s, \sigma^2, \sigma_w^2) = \int \mathcal{N}(y|X_s w_s, \sigma^2 I)\,\mathcal{N}(w_s|0_s, \sigma_w^2 I)\,dw_s = \mathcal{N}(y|0, C_s) \tag{15.96} $$

where C_s = σ²I_N + σ_w²X_sX_sᵀ. (Here X_s refers to the design matrix where we select only the columns of X where s_d = 1.) The difference is that we have replaced the binary s_d ∈ {0, 1} variables with continuous α_d ∈ ℝ⁺, which makes the optimization problem easier.

The objective, which we minimize, is −2 times the log marginal likelihood:

$$ \ell(\alpha,\beta) = -2\log p(y|X,\alpha,\beta) = \log|C_\alpha| + y^\top C_\alpha^{-1}y + \text{const} \tag{15.97} $$


There are various algorithms for optimizing (α, β), some of which we discuss in Section 15.2.7.3.

ARD can be used as an alternative to ℓ₁ regularization. Although the ARD objective is not convex, it tends to give much sparser results [WW12]. In addition, it can be shown [WRN10] that the ARD objective has many fewer local optima than the ℓ₀-regularized objective, and hence is much easier to optimize.


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 


Figure 15.4: Illustration of why ARD results in sparsity. The vector of inputs x does not point towards the vector of outputs y, so the feature should be removed. (a) For finite α, the probability density is spread in directions away from y. (b) When α = ∞, the probability density at y is maximized. Adapted from Figure 8 of [Tip01].


15.2.7.2 Why does ARD result in a sparse solution? 


Once we have estimated α and β, we can compute the posterior over the parameters using Bayes rule for Gaussians, to get p(w|D, α̂, β̂) = N(w|ŵ, Σ̂), where Σ̂⁻¹ = β̂XᵀX + Â and ŵ = β̂Σ̂Xᵀy. If we have α̂_d ≈ ∞, then ŵ_d ≈ 0, so the solution vector will be sparse.

We now give an intuitive argument, based on [Tip01], about when such a sparse solution may be optimal. We shall assume β = 1/σ² is fixed for simplicity. Consider a 1d linear regression with 2 training examples, so X = x = (x₁, x₂), and y = (y₁, y₂). We can plot x and y as vectors in the plane, as shown in Figure 15.4. Suppose the feature is irrelevant for predicting the response, so x points in a nearly orthogonal direction to y. Let us see what happens to the marginal likelihood as we change α. The marginal likelihood is given by p(y|x, α, β) = N(y|0, C_α), where C_α = (1/β)I + (1/α)xxᵀ.


If α is finite, the density N(y|0, C_α) will be elongated along the direction of x, as in Figure 15.4(a). However, if α = ∞, we have C_α = (1/β)I, which is spherical, as in Figure 15.4(b). If |C_α| is held constant, the latter assigns higher probability density to the observed response vector y, so this is the preferred solution. In other words, the marginal likelihood “punishes” solutions where α_d is small but X_{:,d} is irrelevant, since these waste probability mass. It is more parsimonious (from the point of view of Bayesian Occam’s razor) to eliminate redundant dimensions.
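The two-point example can be reproduced numerically. The sketch below (names and numbers ours) evaluates log N(y|0, C_α) over a grid of α, once for an x nearly orthogonal to y and once for an x aligned with y. In the comment, s and q are shorthand we introduce: with s = βxᵀx and q = βxᵀy, a short calculation with the matrix determinant lemma and the Sherman-Morrison formula shows that this single-feature evidence increases monotonically in α when q² ≤ s, and otherwise peaks at α = s²/(q² − s).

```python
import numpy as np


def log_marglik(y, x, alpha, beta):
    """log N(y | 0, C_alpha) with C_alpha = (1/beta) I + (1/alpha) x x^T (one feature)."""
    n = y.shape[0]
    C = np.eye(n) / beta + np.outer(x, x) / alpha
    _, logdet = np.linalg.slogdet(C)
    return -0.5 * (n * np.log(2 * np.pi) + logdet + y @ np.linalg.solve(C, y))


beta = 1.0
alphas = np.logspace(-3, 6, 50)

# Irrelevant feature: x is nearly orthogonal to y
y_a, x_a = np.array([0.1, 1.0]), np.array([1.0, 0.05])
ll_irrelevant = np.array([log_marglik(y_a, x_a, a, beta) for a in alphas])

# Relevant feature: x points along y
y_b, x_b = np.array([2.0, 2.0]), np.array([1.0, 1.0])
ll_relevant = np.array([log_marglik(y_b, x_b, a, beta) for a in alphas])

# With s = beta x^T x and q = beta x^T y, the evidence peaks at alpha = s^2 / (q^2 - s)
# when q^2 > s, and increases all the way to alpha = infinity otherwise.
s_b, q_b = beta * x_b @ x_b, beta * x_b @ y_b
alpha_star = s_b**2 / (q_b**2 - s_b)
```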

Another way to understand the sparsity properties of ARD is as approximate inference in a hierarchical Bayesian model [BT00]. In particular, suppose we put a conjugate prior on each precision, α_d ~ Ga(a, b), and on the observation precision, β ~ Ga(c, d). Since exact inference with a Student prior is intractable, we can use variational Bayes (Section 10.2.3), with a factored posterior approximation of the form

$$ q(w,\alpha) = q(w)q(\alpha) \approx \mathcal{N}(w|\mu,\Sigma)\prod_d \mathrm{Ga}(\alpha_d|\hat{a}_d,\hat{b}_d) \tag{15.98} $$

ARD approximates q(α) by a point estimate. However, in VB, we integrate out α; the resulting posterior marginal q(w) on the weights is given by

$$ q(w) = \int \mathcal{N}(w|0, \mathrm{diag}(\alpha)^{-1})\prod_d \mathrm{Ga}(\alpha_d|\hat{a}_d,\hat{b}_d)\,d\alpha \tag{15.99} $$


This is a Gaussian scale mixture, and can be shown to be the same as a multivariate Student 
distribution (see Section 28.2.3.1), with non-diagonal covariance. Note that the Student has a large 
spike at 0, which intuitively explains why the posterior mean (which, for a Student distribution, is 
equal to the posterior mode) is sparse. 

Finally, we can also view ARD as a MAP estimation problem with a non-factorial prior [WN07]. Intuitively, the dependence between the w_d parameters arises, despite the use of a diagonal Gaussian prior, because the prior precision α_d is estimated after marginalizing out all of w, and hence depends on all the features. Interestingly, [WRN10] prove that MAP estimation with non-factorial priors is strictly better than MAP estimation with any possible factorial prior in the following sense: the non-factorial objective always has fewer local minima than factorial objectives, while still satisfying the property that the global optimum of the non-factorial objective corresponds to the global optimum of the ℓ₀ objective, a property that ℓ₁ regularization, which has no local minima, does not enjoy.


15.2.7.3 Algorithms for ARD 


There are various algorithms for optimizing (α, β). One approach is to use EM, in which we compute p(w|D, α) in the E step and then maximize α in the M step. In variational Bayes, we infer both w and α (see [Dru08] for details). In [WN10], they present a method based on iteratively reweighted ℓ₁ estimation.
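A minimal sketch of the EM approach follows, under a simplifying assumption of ours: the noise precision β is known and held fixed. The E step is the Gaussian posterior from Section 15.2.7.2; the M step maximizes the expected complete-data log likelihood, giving α_d = 1/E[w_d²]. Function name and data are hypothetical. In this sketch α_d for a pruned feature just keeps growing from one iteration to the next, so in practice one would threshold it rather than wait for it to reach infinity.

```python
import numpy as np


def ard_em(X, y, beta, n_iter=500):
    """EM for ARD with a known noise precision beta (a simplifying assumption).

    E step: p(w | D, alpha) = N(mu, Sigma), Sigma = (beta X^T X + diag(alpha))^{-1},
            mu = beta Sigma X^T y.
    M step: alpha_d = 1 / E[w_d^2] = 1 / (mu_d^2 + Sigma_dd).
    """
    D = X.shape[1]
    alpha = np.ones(D)
    XtX, Xty = X.T @ X, X.T @ y
    for _ in range(n_iter):
        Sigma = np.linalg.inv(beta * XtX + np.diag(alpha))
        mu = beta * Sigma @ Xty
        alpha = 1.0 / (mu**2 + np.diag(Sigma))
    return alpha, mu


rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5))
w_true = np.array([2.0, 0.0, -3.0, 0.0, 0.0])
y = X @ w_true + rng.normal(size=100)       # noise precision beta = 1
alpha, mu = ard_em(X, y, beta=1.0)
print("alpha:", alpha)                      # large alpha_d marks a pruned feature
print("mu:", mu)
```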

Recently, [HXW17] showed that the nested iterative computations performed by these methods can be emulated by a recurrent neural network (Section 16.3.4). Furthermore, by training this model, it is possible to achieve much faster convergence than manually designed optimization algorithms.


15.2.7.4 Relevance vector machines 


Suppose we create a linear regression model of the form p(y|x; θ) = N(y|wᵀφ(x), σ²), where φ(x) = [K(x, x₁), ..., K(x, x_N)], where K() is a kernel function (Section 18.2) and x₁, ..., x_N are the N training points. This is called kernel basis function expansion, and transforms the input from x ∈ 𝒳 to φ(x) ∈ ℝᴺ. Obviously this model has O(N) parameters, and hence is nonparametric. However, we can use ARD to select a small subset of the exemplars. This technique is called the relevance vector machine (RVM) [Tip01; TF03].
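The kernel basis function expansion itself is a one-liner once a kernel is chosen. The sketch below uses an RBF kernel with unit lengthscale purely as an illustrative choice of K (the text leaves K general); the function name and data are ours. The resulting N × N matrix plays the role of the design matrix in the ARD model above, with one column per training exemplar.

```python
import numpy as np


def rbf_basis(X_query, X_train, lengthscale=1.0):
    """phi(x) = [K(x, x_1), ..., K(x, x_N)] with an RBF kernel (our illustrative choice of K)."""
    sq_dists = np.sum((X_query[:, None, :] - X_train[None, :, :]) ** 2, axis=-1)
    return np.exp(-0.5 * sq_dists / lengthscale**2)


rng = np.random.default_rng(0)
X_train = rng.uniform(-3.0, 3.0, size=(40, 1))
Phi = rbf_basis(X_train, X_train)        # N x N design matrix: one column per exemplar
# ARD is then applied to a linear model with design matrix Phi; each alpha_n driven
# to infinity removes exemplar x_n, and the survivors are the "relevance vectors".
```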


15.2.8 Multivariate linear regression 

This section is written by Xinglong Li. 
In this section, we consider the multivariate linear regression model, which has the form 
$$ Y = WX + E \tag{15.100} $$


where W ∈ ℝ^{N_y×N_x} is the matrix of regression coefficients, X ∈ ℝ^{N_x×N_D} is the matrix of input features (with each row being an input variable and each column being an observation), Y ∈ ℝ^{N_y×N_D} is the matrix of responses (with each row being an output variable and each column being an observation), and E = [e₁, ..., e_{N_D}] is the matrix of residual errors, where e_i ~iid N(0, Σ). It can be seen from the definition that, given X, W and Σ, the columns of Y are independent random variables following multivariate normal distributions. So the likelihood of the observations is


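$$ p(Y|W,X,\Sigma) = \frac{1}{(2\pi)^{N_yN_D/2}|\Sigma|^{N_D/2}}\exp\left(\sum_{i=1}^{N_D} -\frac{1}{2}(y_i - Wx_i)^\top\Sigma^{-1}(y_i - Wx_i)\right) \tag{15.101} $$

$$ = \frac{1}{(2\pi)^{N_yN_D/2}|\Sigma|^{N_D/2}}\exp\left(-\frac{1}{2}\mathrm{trace}\left((Y - WX)^\top\Sigma^{-1}(Y - WX)\right)\right) \tag{15.102} $$

$$ = \mathcal{MN}(Y|WX, \Sigma, I_{N_D\times N_D}) \tag{15.103} $$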


The conjugate prior for this is the matrix normal inverse Wishart distribution,

$$ W, \Sigma \sim \mathrm{MNIW}(M_0, V_0, \nu_0, \Psi_0) \tag{15.104} $$

where the MNIW is defined by

$$ W|\Sigma \sim \mathcal{MN}(M_0, \Sigma, V_0) \tag{15.105} $$

$$ \Sigma \sim \mathrm{IW}(\nu_0, \Psi_0), \tag{15.106} $$

where V₀ ∈ ℝ^{N_x×N_x}, Ψ₀ ∈ ℝ^{N_y×N_y}, and ν₀ > N_y − 1 is the degrees of freedom of the inverse Wishart distribution.

The posterior distribution of {W, Σ} still follows a matrix normal inverse Wishart distribution. We follow the derivation in [Fox09, App. F]. First, the density of the joint distribution is

$$ p(Y, W, \Sigma) \propto |\Sigma|^{-(\nu_0 + N_y + 1 + N_x + N_D)/2} \times \exp\left\{-\frac{1}{2}\mathrm{trace}(\Omega)\right\} \tag{15.107} $$

$$ \Omega \triangleq \Psi_0\Sigma^{-1} + (Y - WX)^\top\Sigma^{-1}(Y - WX) + (W - M_0)^\top\Sigma^{-1}(W - M_0)V_0 \tag{15.108} $$

We first aggregate the terms involving W in the exponent so that it takes the form of a matrix normal distribution. This is similar to the “completing the square” technique that we used in deriving the conjugate posterior for multivariate normal distributions in Section 3.3.3.3. Specifically,

$$ \mathrm{trace}\left[(Y - WX)^\top\Sigma^{-1}(Y - WX) + (W - M_0)^\top\Sigma^{-1}(W - M_0)V_0\right] \tag{15.109} $$

$$ = \mathrm{trace}\left(\Sigma^{-1}\left[(Y - WX)(Y - WX)^\top + (W - M_0)V_0(W - M_0)^\top\right]\right) \tag{15.110} $$

$$ = \mathrm{trace}\left(\Sigma^{-1}\left[WS_{xx}W^\top - 2S_{yx}W^\top + S_{yy}\right]\right) \tag{15.111} $$

$$ = \mathrm{trace}\left(\Sigma^{-1}\left[(W - S_{yx}S_{xx}^{-1})S_{xx}(W - S_{yx}S_{xx}^{-1})^\top + S_{y|x}\right]\right) \tag{15.112} $$

where

$$ S_{xx} = XX^\top + V_0, \quad S_{yx} = YX^\top + M_0V_0, \tag{15.113} $$

$$ S_{yy} = YY^\top + M_0V_0M_0^\top, \quad S_{y|x} = S_{yy} - S_{yx}S_{xx}^{-1}S_{yx}^\top. \tag{15.114} $$

Therefore, it can be seen from Equation (15.112) that, given Σ, W follows a matrix normal distribution

$$ W|\Sigma, X, Y \sim \mathcal{MN}(S_{yx}S_{xx}^{-1}, \Sigma, S_{xx}). \tag{15.115} $$


Marginalizing out W (which corresponds to removing the terms including W in the exponent in Equation (15.108)), it can be shown that the posterior distribution of Σ is an inverse Wishart distribution. In fact, by substituting Equation (15.112) for the corresponding terms in Equation (15.108), it can be seen that the only terms left after integrating out W are Σ⁻¹Ψ₀ and Σ⁻¹S_{y|x}, which indicates that the scale matrix of the posterior inverse Wishart distribution is Ψ₀ + S_{y|x}.

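In conclusion, the joint posterior distribution of {W, Σ} given the observation is

$$ W, \Sigma|X, Y \sim \mathrm{MNIW}(M_1, V_1, \nu_1, \Psi_1) \tag{15.116} $$

$$ M_1 = S_{yx}S_{xx}^{-1} \tag{15.117} $$

$$ V_1 = S_{xx} \tag{15.118} $$

$$ \nu_1 = N_D + \nu_0 \tag{15.119} $$

$$ \Psi_1 = \Psi_0 + S_{y|x} \tag{15.120} $$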


The MAP estimates of W and Σ are given by the mode of the posterior matrix normal inverse Wishart distribution. To derive this, notice that W only appears in the matrix normal density function in the posterior, so the matrix W maximizing the posterior density of {W, Σ} is the matrix W that maximizes the matrix normal posterior of W. So the MAP estimate of W is Ŵ = M₁ = S_yx S_xx⁻¹, and this holds for any value of Σ. By plugging W = Ŵ into the joint posterior of {W, Σ}, and taking derivatives with respect to Σ, it can be seen that the matrix maximizing the density is (ν₁ + N_y + N_x + 1)⁻¹Ψ₁. Since Ψ₁ is positive definite, it is the MAP estimate of Σ.

In conclusion, the MAP estimates of {W, Σ} are

$$ \hat{W} = S_{yx}S_{xx}^{-1} \tag{15.121} $$

$$ \hat{\Sigma} = \frac{1}{\nu_1 + N_y + N_x + 1}(\Psi_0 + S_{y|x}) \tag{15.122} $$
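The whole update fits in a few lines of linear algebra. The sketch below (function name, prior values and synthetic data are ours) follows the book's conventions: observations are columns of X and Y, and V₀ enters the prior as in Equation (15.108). It computes S_yx S_xx⁻¹ with a linear solve instead of an explicit inverse.

```python
import numpy as np


def mniw_posterior(X, Y, M0, V0, nu0, Psi0):
    """Posterior MNIW parameters, Eqs. (15.113)-(15.120), and MAP estimates, Eqs. (15.121)-(15.122).

    Shapes: X is (N_x, N_D), Y is (N_y, N_D), M0 is (N_y, N_x),
    V0 is (N_x, N_x) and Psi0 is (N_y, N_y); columns are observations.
    """
    N_x, N_D = X.shape
    N_y = Y.shape[0]
    S_xx = X @ X.T + V0                              # Eq. (15.113)
    S_yx = Y @ X.T + M0 @ V0                         # Eq. (15.113)
    S_yy = Y @ Y.T + M0 @ V0 @ M0.T                  # Eq. (15.114)
    M1 = np.linalg.solve(S_xx.T, S_yx.T).T           # S_yx S_xx^{-1}, Eq. (15.117)
    S_y_given_x = S_yy - M1 @ S_yx.T                 # Eq. (15.114)
    V1 = S_xx                                        # Eq. (15.118)
    nu1 = N_D + nu0                                  # Eq. (15.119)
    Psi1 = Psi0 + S_y_given_x                        # Eq. (15.120)
    W_map = M1                                       # Eq. (15.121)
    Sigma_map = Psi1 / (nu1 + N_y + N_x + 1)         # Eq. (15.122)
    return M1, V1, nu1, Psi1, W_map, Sigma_map


rng = np.random.default_rng(0)
N_x, N_y, N_D = 3, 2, 200
X = rng.normal(size=(N_x, N_D))
W_true = np.array([[1.0, -1.0, 0.5], [0.0, 2.0, -0.5]])
Y = W_true @ X + 0.3 * rng.normal(size=(N_y, N_D))
M1, V1, nu1, Psi1, W_map, Sigma_map = mniw_posterior(
    X, Y, M0=np.zeros((N_y, N_x)), V0=np.eye(N_x), nu0=N_y + 2.0, Psi0=np.eye(N_y))
```

As V₀ → 0 (with M₀ = 0), Ŵ approaches the least-squares solution YXᵀ(XXᵀ)⁻¹ and S_{y|x} approaches the residual scatter matrix of that fit.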


15.3 Logistic regression 


Logistic regression is a very widely used discriminative classification model that maps input vectors x ∈ ℝᴰ to a distribution over class labels, y ∈ {1, ..., C}. If C = 2, this is known as binary logistic regression, and if C > 2, it is known as multinomial logistic regression, or alternatively, multiclass logistic regression.


15.3.1 Binary logistic regression 


In the binary case, where y ∈ {0, 1}, the model has the following form

$$ p(y|x;\theta) = \mathrm{Ber}(y|\sigma(w^\top x + b)) \tag{15.123} $$

where w are the weights, b is the bias (offset), and σ is the sigmoid or logistic function, defined by

$$ \sigma(a) \triangleq \frac{1}{1 + e^{-a}} \tag{15.124} $$
Let η_n = wᵀx_n + b be the logits for example n, and μ_n = σ(η_n) = p(y = 1|x_n) be the mean of the output. Then we can write the log likelihood as the negative cross entropy:

$$ \log p(\mathcal{D}|\theta) = \log\prod_{n=1}^{N}\mu_n^{y_n}(1-\mu_n)^{1-y_n} = \sum_{n=1}^{N} y_n\log\mu_n + (1-y_n)\log(1-\mu_n) \tag{15.125} $$

We can expand this equation into a more explicit form (that is commonly seen in implementations) by performing some simple algebra. First note that

$$ \mu_n = \frac{1}{1+e^{-\eta_n}} = \frac{e^{\eta_n}}{1+e^{\eta_n}}, \quad 1 - \mu_n = 1 - \frac{e^{\eta_n}}{1+e^{\eta_n}} = \frac{1}{1+e^{\eta_n}} \tag{15.126} $$

Hence

$$ \log p(\mathcal{D}|\theta) = \sum_{n=1}^{N} y_n\left[\log e^{\eta_n} - \log(1+e^{\eta_n})\right] + (1-y_n)\left[\log 1 - \log(1+e^{\eta_n})\right] \tag{15.127} $$

$$ = \sum_{n=1}^{N} y_n\left[\eta_n - \log(1+e^{\eta_n})\right] + (1-y_n)\left[-\log(1+e^{\eta_n})\right] \tag{15.128} $$

$$ = \sum_{n=1}^{N} y_n\eta_n - \sum_{n=1}^{N}\log(1+e^{\eta_n}) \tag{15.129} $$


Note that the log(1 + e^a) function is often implemented using np.log1p(np.exp(a)).
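The sketch below (function names and data ours) implements Equation (15.129) and checks it against the direct form of Equation (15.125). It computes log(1 + e^η) with `np.logaddexp(0.0, eta)`, a NumPy function that evaluates log(e⁰ + e^η) without forming e^η explicitly.

```python
import numpy as np


def binary_logreg_loglik(w, b, X, y):
    """Eq. (15.129): sum_n y_n eta_n - sum_n log(1 + exp(eta_n))."""
    eta = X @ w + b
    # np.logaddexp(0, eta) = log(exp(0) + exp(eta)) = log(1 + exp(eta)), without overflow
    return np.sum(y * eta) - np.sum(np.logaddexp(0.0, eta))


def binary_logreg_loglik_naive(w, b, X, y):
    """Eq. (15.125) written directly in terms of mu_n = sigma(eta_n)."""
    mu = 1.0 / (1.0 + np.exp(-(X @ w + b)))
    return np.sum(y * np.log(mu) + (1 - y) * np.log(1 - mu))


rng = np.random.default_rng(0)
X = rng.normal(size=(20, 3))
y = (rng.random(20) < 0.5).astype(float)
w, b = np.array([0.5, -1.0, 0.25]), 0.1
print(binary_logreg_loglik(w, b, X, y), binary_logreg_loglik_naive(w, b, X, y))
```

The two agree for moderate logits. For a large logit such as η = 1000, `np.log1p(np.exp(1000.0))` overflows to `inf` (with a RuntimeWarning), whereas `np.logaddexp(0.0, 1000.0)` returns `1000.0`.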


15.3.2 Multinomial logistic regression 


Multinomial logistic regression is a discriminative classification model of the following form:

$$ p(y|x;\theta) = \mathrm{Cat}(y|\mathrm{softmax}(Wx + b)) \tag{15.130} $$

where x ∈ ℝᴰ is the input vector, y ∈ {1, ..., C} is the class label, W is a C × D weight matrix, b is a C-dimensional bias vector, and softmax() is the softmax function, defined as

$$ \mathrm{softmax}(a) \triangleq \left[\frac{e^{a_1}}{\sum_{c'=1}^{C}e^{a_{c'}}}, \ldots, \frac{e^{a_C}}{\sum_{c'=1}^{C}e^{a_{c'}}}\right] \tag{15.131} $$

If we define the logits as η_n = Wx_n + b, the probabilities as μ_n = softmax(η_n), and let y_n be the one-hot encoding of the label y_n, then the log likelihood can be written as the negative cross entropy:

$$ \log p(\mathcal{D}|\theta) = \log\prod_{n=1}^{N}\prod_{c=1}^{C}\mu_{nc}^{y_{nc}} = \sum_{n=1}^{N}\sum_{c=1}^{C}y_{nc}\log\mu_{nc} \tag{15.132} $$
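As in the binary case, Equation (15.132) is usually evaluated through the log of the softmax rather than the softmax itself, subtracting the largest logit first so that no exponential can overflow. The sketch below (function names and data ours) uses 0-based labels {0, ..., C−1}, unlike the 1-based labels in the text.

```python
import numpy as np


def log_softmax(a):
    """log softmax(a) along the last axis, shifted by the max for numerical stability."""
    a_shift = a - np.max(a, axis=-1, keepdims=True)
    return a_shift - np.log(np.sum(np.exp(a_shift), axis=-1, keepdims=True))


def multinomial_logreg_loglik(W, b, X, labels):
    """Eq. (15.132); labels are integers in {0, ..., C-1} (0-based, unlike the text)."""
    eta = X @ W.T + b                       # logits, shape (N, C)
    Y = np.eye(W.shape[0])[labels]          # one-hot encoding, shape (N, C)
    return np.sum(Y * log_softmax(eta))


rng = np.random.default_rng(0)
N, D, C = 30, 4, 3
X = rng.normal(size=(N, D))
W = rng.normal(size=(C, D))
b = np.zeros(C)
labels = rng.integers(0, C, size=N)
ll = multinomial_logreg_loglik(W, b, X, labels)
```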
[Figure 15.5 panels: (a) implied prior on the output, labeled variance=10.00 and variance=1.50; (b) panels for 1, 5 and 15 input features, each labeled std=1.50.]


Figure 15.5: (a) Prior on logistic regression output when using N(0, w) prior for the offset term, for w = 10 or w = 1.5. Adapted from Figure 11.3 of [McE20]. Generated by logreg_prior_offset.ipynb. (b) Distribution over the fraction of 1s we expect to see when using binary logistic regression applied to random binary feature vectors of increasing dimensionality. We use a N(0, 1.5) prior on the regression coefficients. Adapted from Figure 3 of [Gel+20]. Generated by logreg_prior.ipynb.


15.3.3 Priors 


As with linear regression, it is standard to use Gaussian priors for the weights in a logistic regression 
model. It is natural to set the prior mean to 0, to reflect the fact that the output could either 
increase or decrease in probability depending on the input. But how do we set the prior variance? It 
is tempting to use a large value, to approximate a uniform distribution, but this is a bad idea. To 
see why, consider a binary logistic regression model with just an offset term and no features: 


$$ p(y|a) = \mathrm{Ber}(y|\sigma(a)) \tag{15.133} $$

$$ p(a) = \mathcal{N}(a|0, w) \tag{15.134} $$


If we set the prior to the large value of w = 10, the implied prior for y is an extreme distribution, 
with most of its density near 0 or 1, as shown in Figure 15.5a. By contrast, if we use the smaller 
value of w = 1.5, we get a flatter distribution, as shown. 
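The effect of the offset prior can be quantified by pushing prior samples of a through the sigmoid. The sketch below treats w in N(a|0, w) as a variance, consistent with the N(μ, σ²) notation used elsewhere and with the panel labels of Figure 15.5a; the 0.01 cutoff for "near 0 or 1" and the function names are our own choices.

```python
import numpy as np


def sigmoid(a):
    """Logistic function computed as exp(-log(1 + exp(-a))) to avoid overflow."""
    return np.exp(-np.logaddexp(0.0, -a))


def frac_extreme(p, eps=0.01):
    """Fraction of implied probabilities within eps of 0 or 1."""
    return np.mean((p < eps) | (p > 1 - eps))


rng = np.random.default_rng(0)
n = 100_000
extreme = {}
for w in (10.0, 1.5):                                   # prior variance of the offset a
    p = sigmoid(rng.normal(0.0, np.sqrt(w), size=n))    # implied prior on p(y = 1)
    extreme[w] = frac_extreme(p)
    print(f"w = {w}: fraction within 0.01 of 0 or 1 = {extreme[w]:.3f}")
```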

If we have input features, the problem gets a little trickier, since the magnitude of the logits will
now depend on the number and distribution of the input variables. For example, suppose we generate
$N$ random binary vectors $x_n$, each of dimension $D$, where $x_{nd} \sim \mathrm{Ber}(p)$, with $p = 0.8$. We then
compute $p(y_n = 1|x_n) = \sigma(\beta^\top x_n)$, where $\beta \sim \mathcal{N}(0, 1.5 I)$. We sample $S$ values of $\beta$, and for each
one, we sample a vector of labels, $y_{1:N,s}$, from the above distribution. We then compute the fraction
of positive labels, $f_s = \frac{1}{N}\sum_{n=1}^{N} \mathbb{I}(y_{n,s} = 1)$. We plot the distribution of $\{f_s\}$ as a function of $D$ in
Figure 15.5b. We see that the induced prior is initially flat, but eventually becomes skewed towards
the extreme values of 0 and 1. To avoid this, we should standardize the inputs, and scale the variance
of the prior by $1/\sqrt{D}$. We can also use a heavier tailed distribution, such as a Cauchy or Student t distribution
[Gel+08; GLM15].
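The simulation just described can be sketched in a few lines of NumPy. This is an illustration under our own choices ($N = 100$, $S = 1000$, a fixed seed, and an "extreme" cutoff of 0.1), not the code of logreg_prior.ipynb. Following the text, the binary inputs are generated once and then reused for each draw of $\beta$.

```python
import numpy as np


def sample_fraction_of_ones(D, N=100, S=1000, p=0.8, prior_var=1.5, seed=0):
    """Prior predictive distribution of the fraction of positive labels.

    Generates N binary feature vectors once (x_nd ~ Ber(p)), then for each of S
    draws beta_s ~ N(0, prior_var * I) samples y_{n,s} ~ Ber(sigmoid(beta_s^T x_n))
    and returns f_s = (1/N) sum_n I(y_{n,s} = 1) for s = 1..S.
    """
    rng = np.random.default_rng(seed)
    X = (rng.random((N, D)) < p).astype(float)                  # (N, D) binary inputs
    betas = rng.normal(0.0, np.sqrt(prior_var), size=(S, D))    # (S, D) prior draws
    probs = 1.0 / (1.0 + np.exp(-(X @ betas.T)))                 # (N, S) p(y_n = 1 | beta_s)
    Y = rng.random((N, S)) < probs                               # (N, S) sampled labels
    return Y.mean(axis=0)                                        # f_s, shape (S,)


def extreme_fraction(f, eps=0.1):
    """Fraction of simulated f_s values below eps or above 1 - eps."""
    return float(np.mean((f < eps) | (f > 1.0 - eps)))


for D in [1, 10, 100]:
    f = sample_fraction_of_ones(D)
    print(f"D = {D:3d}: P(f_s < 0.1 or f_s > 0.9) ~= {extreme_fraction(f):.2f}")
```

With $D = 1$ the fraction of extreme $f_s$ values should be close to zero, while by $D = 100$ a large share of the prior mass sits near 0 or 1, mirroring Figure 15.5b.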
15.3.4 Posteriors


Unfortunately, there is no tractable prior that is conjugate to the logistic likelihood. Hence we cannot 
compute the posterior analytically, unlike with linear regression, even if we use a Gaussian prior. 
(This mirrors the case with MLE, where we have a closed form solution for linear regression, but not 
for logistic regression.) Fortunately, there is a range of approximate inference methods we can use,
as we discuss in the sections below.


15.3.5 Laplace approximation 


As we discuss in Section 7.4.3, the Laplace approximation approximates the posterior using a Gaussian.
The mean of the Gaussian is equal to the MAP estimate $\hat{w}$, and the covariance is equal to the inverse
Hessian $H^{-1}$, where $H$ is the Hessian of the negative log joint computed at the MAP estimate, i.e., $p(w|D) \approx \mathcal{N}(w|\hat{w}, H^{-1})$. We can find the mode using
a standard optimization method, and we can then compute the Hessian at the mode analytically or
using automatic differentiation.
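The following NumPy sketch shows one way to carry this out for binary logistic regression with a $\mathcal{N}(0, \sigma^2 I)$ prior (with $\sigma^2 = 100$, as used below). It is an illustration under stated assumptions, not the book's logreg_laplace_demo.ipynb: the mode is found with Newton's method plus simple step halving, the Hessian is computed analytically rather than by autodiff, and the data are simulated.

```python
import numpy as np


def neg_log_joint(w, X, y, prior_var):
    """-log p(y, w | X) up to a constant, for logistic regression with prior N(0, prior_var I)."""
    a = X @ w
    return np.sum(np.logaddexp(0.0, a) - y * a) + 0.5 * (w @ w) / prior_var


def grad_hess(w, X, y, prior_var):
    """Gradient and Hessian of the negative log joint."""
    mu = 1.0 / (1.0 + np.exp(-(X @ w)))
    grad = X.T @ (mu - y) + w / prior_var
    H = (X * (mu * (1.0 - mu))[:, None]).T @ X + np.eye(len(w)) / prior_var
    return grad, H


def laplace_approx(X, y, prior_var=100.0, max_iters=100, tol=1e-10):
    """Return (w_map, Sigma) such that p(w|D) ~= N(w | w_map, Sigma), Sigma = H^{-1}."""
    w = np.zeros(X.shape[1])
    for _ in range(max_iters):
        grad, H = grad_hess(w, X, y, prior_var)
        step = np.linalg.solve(H, grad)                 # Newton direction
        t, f0 = 1.0, neg_log_joint(w, X, y, prior_var)
        while neg_log_joint(w - t * step, X, y, prior_var) > f0 and t > 1e-8:
            t *= 0.5                                    # backtrack until the objective decreases
        w = w - t * step
        if np.max(np.abs(t * step)) < tol:
            break
    _, H = grad_hess(w, X, y, prior_var)
    return w, np.linalg.inv(H)


# Hypothetical 2d data with no bias term, linearly separable as in Figure 15.6
rng = np.random.default_rng(0)
X = rng.normal(size=(40, 2))
y = (X @ np.array([3.0, 1.0]) > 0).astype(float)
w_map, Sigma = laplace_approx(X, y, prior_var=100.0)
print("w_map =", w_map)
print("Sigma =", Sigma)
```

The prior is what keeps the mode finite here: with separable data and no prior, the Newton iterates would drift off towards the MLE at infinity.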

As an example, consider the binary data illustrated in Figure 15.6(a). There are many parameter
settings that correspond to lines that perfectly separate the training data; we show 4 example lines.
For each decision boundary in Figure 15.6(a), we plot the corresponding parameter vector as a point
in the log likelihood surface in Figure 15.6(b). These parameter values are $w_1 = (3, 1)$, $w_2 = (4, 2)$,
$w_3 = (5, 3)$, and $w_4 = (7, 3)$. These points all approximately satisfy $w_i(1)/w_i(2) \approx \hat{w}_{\text{mle}}(1)/\hat{w}_{\text{mle}}(2)$,
and hence are close to the orientation of the maximum likelihood decision boundary. The points
are ordered by increasing weight norm (3.16, 4.47, 5.83, and 7.62). The unconstrained MLE has
$||w|| = \infty$, so is infinitely far to the top right.

To ensure a unique solution, we use a (spherical) Gaussian prior centered at the origin, $\mathcal{N}(w|0, \sigma^2 I)$.
The value of $\sigma^2$ controls the strength of the prior. If we set $\sigma^2 = 0$, we force the MAP estimate
to be $w = 0$; this will result in maximally uncertain predictions, since all points $x$ will produce a
predictive distribution of the form $p(y = 1|x) = 0.5$. If we set $\sigma^2 = \infty$, the MAP estimate becomes
the MLE, resulting in minimally uncertain predictions. (In particular, all positively labeled points
will have $p(y = 1|x) = 1.0$, and all negatively labeled points will have $p(y = 1|x) = 0.0$, since the
data is separable.) As a compromise (to make a nice illustration), we pick the value $\sigma^2 = 100$.


Multiplying this prior by the likelihood results in the unnormalized posterior shown in Figure 15.6(c). 


The MAP estimate is shown by the blue dot. The Laplace approximation to this posterior is shown
in Figure 15.6(d). We see that it gets the mode correct (by construction), but the shape of the
posterior is somewhat distorted. (The southwest-northeast orientation captures uncertainty about the
magnitude of $w$, and the southeast-northwest orientation captures uncertainty about the orientation
of the decision boundary.)
Next we need to convert the posterior over the parameters into a posterior over predictions, as
follows:

$$p(y|x, D) = \int p(y|x, w) p(w|D) dw \tag{15.135}$$

The simplest way to evaluate this integral is to use a Monte Carlo approximation:

$$p(y = 1|x, D) \approx \frac{1}{S}\sum_{s=1}^{S} \sigma(w_s^\top x) \tag{15.136}$$
[Figure 15.6 panels: (a) Data; (b) Log likelihood; (c) Log-unnormalised posterior; (d) Laplace approximation to posterior]


Figure 15.6: (a) Illustration of the data and some decision boundaries, where colors correspond to the points in 
the (b) panel. (b) Log-likelihood for a logistic regression model. The line is drawn from the origin in the direction 
of the MLE (which is at infinity). The numbers correspond to 4 points in parameter space, corresponding to 
the lines in (a). (c) Unnormalized log posterior (assuming vague spherical prior). (d) Laplace approximation 
to posterior. Adapted from a figure by Mark Girolami. Generated by logreg_laplace_demo.ipynb.


where $w_s \sim p(w|D)$.
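To make (15.136) concrete, here is a small NumPy sketch. It assumes the posterior is (approximately) Gaussian, for example the Laplace approximation described above, so the $w_s$ can be drawn with rng.multivariate_normal. The posterior mean, covariance and test point are made-up numbers.

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def mc_predictive(x, mu, Sigma, num_samples=100_000, seed=0):
    """Monte Carlo estimate (15.136): p(y = 1 | x, D) ~= (1/S) sum_s sigmoid(w_s^T x).

    Here w_s ~ N(mu, Sigma) is a Gaussian stand-in for p(w | D), such as the
    Laplace approximation."""
    rng = np.random.default_rng(seed)
    W = rng.multivariate_normal(mu, Sigma, size=num_samples)   # (S, D) posterior samples
    return float(np.mean(sigmoid(W @ x)))


# Hypothetical posterior and test point
mu = np.array([1.0, 0.5])
Sigma = np.array([[4.0, 0.0], [0.0, 1.0]])
x = np.array([1.5, 1.0])
print("plug-in  :", sigmoid(mu @ x))
print("MC approx:", mc_predictive(x, mu, Sigma))
```

When the posterior is broad, averaging over samples pulls the prediction towards 0.5 relative to the plug-in value $\sigma(\mu^\top x)$; as the covariance shrinks, the two agree.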


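Alternatively, we can use the deterministic probit approximation first suggested in [SL90]. If we
define $a = x^\top w$, and $p(w|D) = \mathcal{N}(w|\mu, \Sigma)$, then we can write this approximation as

$$p(y = 1|x, D) \approx \sigma(\kappa(v) m) \tag{15.137}$$
$$\kappa(v) \triangleq (1 + \pi v/8)^{-\frac{1}{2}} \tag{15.138}$$
$$v = \mathbb{V}[a] = \mathbb{V}[x^\top w] = x^\top \Sigma x \tag{15.139}$$
$$m = \mathbb{E}[a] = x^\top \mu \tag{15.140}$$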
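A direct transcription of (15.137) to (15.140) into NumPy, checked against a Monte Carlo estimate of (15.136) at a few test points, is sketched below. The posterior mean and covariance are hypothetical values, and the 0.025 tolerance in the check is our own loose bound, not a guarantee from [SL90].

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def probit_approx_predictive(x, mu, Sigma):
    """Deterministic approximation (15.137)-(15.140) to p(y = 1 | x, D)
    when p(w | D) = N(w | mu, Sigma)."""
    m = x @ mu                                   # E[a], a = x^T w
    v = x @ Sigma @ x                            # V[a]
    kappa = 1.0 / np.sqrt(1.0 + np.pi * v / 8.0)
    return float(sigmoid(kappa * m))


def mc_predictive(x, mu, Sigma, num_samples=200_000, seed=0):
    """Monte Carlo reference value, as in (15.136)."""
    rng = np.random.default_rng(seed)
    W = rng.multivariate_normal(mu, Sigma, size=num_samples)
    return float(np.mean(sigmoid(W @ x)))


mu = np.array([0.5, 0.3])                        # hypothetical posterior mean
Sigma = np.array([[1.0, 0.2], [0.2, 0.5]])       # hypothetical posterior covariance
test_points = [np.array([1.0, 2.0]), np.array([-3.0, 1.0]), np.array([4.0, 4.0])]
for x in test_points:
    print(x, probit_approx_predictive(x, mu, Sigma), mc_predictive(x, mu, Sigma))
```

The approximation costs one quadratic form per test point instead of $S$ sigmoid evaluations, which is why it is attractive when predictions are needed on a dense grid, as in Figure 15.7(d).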


In Figure 15.7, we show contours of the posterior predictive distribution. Figure 15.7(a) shows the
plugin approximation using the MAP estimate. We see that there is no uncertainty about the decision
boundary, even though we are generating probabilistic predictions over the labels. Figure 15.7(b)
shows what happens when we plug in samples from the Gaussian posterior. Now we see that there is
considerable uncertainty about the orientation of the "best" decision boundary. Figure 15.7(c) shows
the average of these samples. By averaging over multiple predictions, we see that the uncertainty in
the decision boundary "splays out" as we move further from the training data. Figure 15.7(d) shows
that the probit approximation gives very similar results to the Monte Carlo approximation.

[Figure 15.7 panels: (a) $p(y = 1|x, \hat{w}_{\text{map}})$; (b) Decision boundary for sampled $w$; (c) MC approx of $p(y = 1|x)$; (d) Deterministic approx of $p(y = 1|x)$]

Figure 15.7: Posterior predictive distribution for a logistic regression model in 2d. (a): contours of $p(y = 1|x, \hat{w}_{\text{map}})$. (b): samples from the posterior predictive distribution. (c): Averaging over these samples.
(d): moderated output (probit approximation). Adapted from a figure by Mark Girolami. Generated by
logreg_laplace_demo.ipynb.


15.3.6 MCMC inference 


Markov chain Monte Carlo, or MCMC, is often considered the “gold standard” for approximate 
inference, since it makes no explicit assumptions about the form of the posterior. Instead, it just 
approximates it non-parametrically using a set of S samples: 


$$q(\theta|D) \approx \frac{1}{S}\sum_{s=1}^{S} \delta(\theta - \theta^s) \tag{15.141}$$
Figure 15.8: Illustration of the posterior over the decision boundary for classifying iris flowers (setosa
vs versicolor) using 2 input features. (a) 25 examples per class. Adapted from Figure 4.5 of [Mar18].
(b) 5 examples of class 0, 45 examples of class 1. Adapted from Figure 4.8 of [Mar18]. Generated by
logreg_iris_bayes_2d.ipynb.


where $\theta^s \sim p(\theta|D)$ are samples from the posterior.

To efficiently compute these samples, we can use the method of Hamiltonian Monte Carlo or HMC, 
which we describe in Section 12.5. This relies on our ability to compute the gradient of the log joint, 
$\nabla_\theta \log p(D, \theta)$, which we can compute using automatic differentiation.
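A bare-bones HMC sampler for logistic regression is sketched below in NumPy. In JAX the gradient could come from jax.grad; here it is written out by hand. The sketch uses an identity mass matrix, a fixed step size and number of leapfrog steps (0.05 and 20, chosen arbitrarily, with no tuning or adaptation), a $\mathcal{N}(0, 100 I)$ prior, and simulated data. It is not the implementation used for Figure 15.8.

```python
import numpy as np


def log_joint_and_grad(theta, X, y, prior_var):
    """log p(D, theta) up to a constant, and its gradient, for logistic regression
    with a N(0, prior_var * I) prior."""
    a = X @ theta
    logp = np.sum(y * a - np.logaddexp(0.0, a)) - 0.5 * (theta @ theta) / prior_var
    mu = 1.0 / (1.0 + np.exp(-a))
    grad = X.T @ (y - mu) - theta / prior_var
    return logp, grad


def hmc(X, y, num_samples=1000, step_size=0.05, num_leapfrog=20, prior_var=100.0, seed=0):
    """Basic HMC with identity mass matrix and a fixed number of leapfrog steps."""
    rng = np.random.default_rng(seed)
    theta = np.zeros(X.shape[1])
    logp, grad = log_joint_and_grad(theta, X, y, prior_var)
    samples = np.empty((num_samples, theta.size))
    num_accepted = 0
    for s in range(num_samples):
        p0 = rng.normal(size=theta.size)                 # momentum ~ N(0, I)
        theta_new, grad_new = theta.copy(), grad.copy()
        p = p0 + 0.5 * step_size * grad_new              # initial half step for momentum
        for step in range(num_leapfrog):
            theta_new = theta_new + step_size * p        # full step for position
            logp_new, grad_new = log_joint_and_grad(theta_new, X, y, prior_var)
            if step < num_leapfrog - 1:
                p = p + step_size * grad_new             # full step for momentum
        p = p + 0.5 * step_size * grad_new               # final half step for momentum
        # Metropolis accept/reject; the Hamiltonian is -log p(theta) + |p|^2 / 2
        log_accept = (logp_new - 0.5 * p @ p) - (logp - 0.5 * p0 @ p0)
        if np.log(rng.random()) < log_accept:
            theta, logp, grad = theta_new, logp_new, grad_new
            num_accepted += 1
        samples[s] = theta
    return samples, num_accepted / num_samples


# Hypothetical data: a bias column plus one feature, labels from a known model
rng = np.random.default_rng(1)
x = rng.normal(size=50)
X = np.column_stack([np.ones_like(x), x])
y = (rng.random(50) < 1.0 / (1.0 + np.exp(-(0.5 + 2.0 * x)))).astype(float)
samples, accept_rate = hmc(X, y)
print("acceptance rate:", accept_rate)
print("posterior mean (bias, slope):", samples[200:].mean(axis=0))
```

The first 200 samples are discarded as burn-in in the printout, since the chain starts at $\theta = 0$ rather than in the typical set.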

Let us apply HMC to a 2-dimensional, 2-class version of the iris classification problem, where we 
just use two input features, sepal length and sepal width, and two classes, Virginica and Non-Virginica. 
The decision boundary is the set of points $(x_1^*, x_2^*)$ such that $\sigma(b + w_1 x_1^* + w_2 x_2^*) = 0.5$. Such points
must lie on the following line:

$$x_2^* = -\frac{b}{w_2} + \left(-\frac{w_1}{w_2}\right) x_1^* \tag{15.142}$$

We can therefore compute an MC approximation to the posterior over decision boundaries by sampling
the parameters from the posterior, $(w_1, w_2, b) \sim p(\theta|D)$, and plugging them into the above equation,
to get $p(x_2^*|x_1^*, D)$. The results of this method (using a vague Gaussian prior for the parameters)
are shown in Figure 15.8a. The solid line is the posterior mean, and the shaded interval is a 95% 
credible interval. As before, we see that the uncertainty about the location of the boundary is higher 
as we move away from the training data. 
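Given posterior samples (for example from HMC), the boundary posterior in (15.142) reduces to a few array operations. In the NumPy sketch below, the samples are drawn from a made-up Gaussian rather than from a real posterior, with $w_2$ kept well away from zero so the division is safe. The interval widens away from $x_1^* = 0$ only because these hypothetical samples make the slope uncertain; with real samples, where the boundary is pinned down is determined by the data.

```python
import numpy as np


def boundary_posterior(samples, x1_grid, level=0.95):
    """Posterior over the decision boundary (15.142).

    samples: array of shape (S, 3) with rows (w1, w2, b) drawn from p(theta | D).
    Returns the posterior mean of x2* and a central credible interval, per x1* in x1_grid."""
    w1, w2, b = samples[:, 0:1], samples[:, 1:2], samples[:, 2:3]   # each (S, 1)
    x2 = -b / w2 + (-w1 / w2) * x1_grid[None, :]                    # (S, G) boundary draws
    lower, upper = np.quantile(x2, [(1 - level) / 2, (1 + level) / 2], axis=0)
    return x2.mean(axis=0), lower, upper


# Hypothetical posterior samples; in practice these would come from MCMC
rng = np.random.default_rng(0)
S = 4000
samples = np.column_stack([
    rng.normal(1.0, 0.1, S),    # w1
    rng.normal(1.0, 0.05, S),   # w2
    rng.normal(0.0, 0.1, S),    # b
])
x1_grid = np.linspace(-10.0, 10.0, 21)
mean, lower, upper = boundary_posterior(samples, x1_grid)
print(np.round(upper - lower, 2))   # width of the 95% interval at each x1*
```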

In Figure 15.8b, we show what happens to the decision boundary when we have unbalanced classes. 
We notice two things. First, the posterior uncertainty increases, because we have less data from the 
blue class. Second, we see that the posterior mean of the decision boundary shifts towards the class 
with less data. This follows from linear discriminant analysis, where one can show that changing 
the class prior changes the location of the decision boundary, so that more of the input space gets 
mapped to the class which is higher a priori. (See [Mur22, Sec 9.2] for details.) 
Figure 15.9: The logistic (sigmoid) function $\sigma(x)$ in solid red, with the Gaussian cdf function $\Phi(\lambda x)$ in dotted
blue superimposed. Here $\lambda = \sqrt{\pi/8}$, which was chosen so that the derivatives of the two curves match at
$x = 0$. Adapted from Figure 4.9 of [Bis06]. Generated by probit_plot.ipynb.


15.3.7 Variational inference 


As we discuss in Section 10.1, variational inference converts approximate inference into an optimization
problem. It does this by choosing an approximate distribution $q(w; \psi)$ and optimizing the variational
parameters $\psi$ to maximize the evidence lower bound (ELBO). This has the effect of making
$q(w; \psi) \approx p(w|D)$ in the sense that the KL divergence is small. There are several ways to tackle
this: use a stochastic estimate of the ELBO (see Section 10.3.3), use the conditionally conjugate VI
method of Supplementary Section 10.3.1.2, or use a "local" VI method that creates a quadratic lower
bound to the logistic function (see Supplementary Section 15.1).
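The first option, a stochastic estimate of the ELBO, can be sketched as follows for a mean-field Gaussian $q(w; \psi) = \mathcal{N}(w|m, \mathrm{diag}(s^2))$ with $\psi = (m, \log s)$ and a $\mathcal{N}(0, 100 I)$ prior. These are our assumptions, not necessarily the choices made in Section 10.3.3. The expected log likelihood uses the reparameterization $w = m + s \odot \epsilon$, and the KL term is computed in closed form. Because the seed, and hence $\epsilon$, is fixed, the estimate is a deterministic, differentiable function of $\psi$, so in practice one would maximize it with a stochastic-gradient optimizer (for example via jax.grad); that loop is omitted here.

```python
import numpy as np


def elbo_estimate(m, log_s, X, y, prior_var=100.0, num_samples=1000, seed=0):
    """Monte Carlo ELBO for logistic regression with prior N(0, prior_var I) and
    q(w; psi) = N(w | m, diag(s^2)), psi = (m, log_s).

    ELBO = E_q[log p(D | w)] - KL(q || p), using w = m + s * eps, eps ~ N(0, I)."""
    rng = np.random.default_rng(seed)
    s = np.exp(log_s)
    eps = rng.normal(size=(num_samples, m.size))
    W = m + s * eps                                          # (S, D) samples from q
    A = W @ X.T                                              # (S, N) logits
    log_lik = np.sum(y * A - np.logaddexp(0.0, A), axis=1)   # log p(D | w_s)
    kl = 0.5 * np.sum((s**2 + m**2) / prior_var - 1.0 - 2.0 * log_s + np.log(prior_var))
    return float(log_lik.mean() - kl)


# Hypothetical data generated with w = (2, -1)
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 2))
y = (rng.random(100) < 1.0 / (1.0 + np.exp(-(X @ np.array([2.0, -1.0]))))).astype(float)
good = elbo_estimate(np.array([2.0, -1.0]), np.log(0.2) * np.ones(2), X, y)
bad = elbo_estimate(np.array([-2.0, 1.0]), np.log(0.2) * np.ones(2), X, y)
print("ELBO near the data-generating weights:", good)
print("ELBO with the signs flipped          :", bad)
```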


15.3.8 Assumed density filtering 


In Section 8.9.3, we discuss how to use assumed density filtering (ADF) to recursively approximate
the posterior $p(w|D_{1:t})$ for a logistic regression model in an online fashion.


15.4 Probit regression


In this section, we discuss probit regression, which is similar to binary logistic regression except
it uses $\mu_n = \Phi(a_n)$ instead of $\mu_n = \sigma(a_n)$ as the mean function, where $\Phi$ is the cdf of the standard
normal, and $a_n = w^\top x_n$. The corresponding link function is therefore $a_n = \ell(\mu_n) = \Phi^{-1}(\mu_n)$; the
inverse of the Gaussian cdf is known as the probit function.

The Gaussian cdf $\Phi$ is very similar to the logistic function, as shown in Figure 15.9. Thus probit
regression and "regular" logistic regression behave very similarly. However, probit regression has some
advantages. In particular, it has a simple interpretation as a latent variable model (see Section 15.4.1),
which arises from the field of choice theory as studied in economics (see e.g., [Koo03]). This also
simplifies the task of Bayesian parameter inference.
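How close the two curves in Figure 15.9 are can be checked numerically. The sketch below uses Python's statistics.NormalDist for $\Phi$ (our choice; any accurate normal cdf would do). It confirms that $\lambda = \sqrt{\pi/8}$ makes the slopes agree at 0, since $\lambda \phi(0) = \sqrt{\pi/8}/\sqrt{2\pi} = 1/4 = \sigma'(0)$, and reports the largest gap between $\sigma(x)$ and $\Phi(\lambda x)$ on a grid.

```python
import numpy as np
from statistics import NormalDist

std_normal = NormalDist()                               # mean 0, standard deviation 1
Phi = np.vectorize(std_normal.cdf, otypes=[float])      # elementwise standard normal cdf


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


lam = np.sqrt(np.pi / 8.0)                              # lambda from Figure 15.9
xs = np.linspace(-8.0, 8.0, 1601)
max_abs_diff = float(np.max(np.abs(sigmoid(xs) - Phi(lam * xs))))

# Slopes at 0: sigma'(0) = 1/4, and d/dx Phi(lam x) at x = 0 equals lam * phi(0) = 1/4
slope_logistic = 0.25
slope_probit = lam * std_normal.pdf(0.0)
print(f"max |sigma(x) - Phi(lam x)| on the grid = {max_abs_diff:.4f}")
print("slopes at 0:", slope_logistic, slope_probit)
```

On this grid the two functions differ by less than 0.02 everywhere, which is why probit and logistic regression usually give very similar fits.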


15.4.1 Latent variable interpretation


We can interpret $w^\top x_n$ as a factor that is proportional to how likely a person is to respond
positively. However, there may also be other unobserved factors that
influence someone's response. Let us model these hidden factors by Gaussian noise, $\epsilon_n \sim \mathcal{N}(0, 1)$. Let
the combined preference for positive outcomes be represented by the latent variable $z_n = w^\top x_n + \epsilon_n$.
We assume that the person will pick the positive label iff this latent factor is positive rather than
negative, i.e.,


$$y_n = \mathbb{I}(z_n \geq 0) \tag{15.143}$$


When we marginalize out zn, we recover the probit model: 


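$$p(y_n = 1|x_n, w) = \int \mathbb{I}(z_n \geq 0)\, \mathcal{N}(z_n|w^\top x_n, 1)\, dz_n \tag{15.144}$$
$$= p(w^\top x_n + \epsilon_n \geq 0) = p(\epsilon_n \geq -w^\top x_n) \tag{15.145}$$
$$= 1 - \Phi(-w^\top x_n) = \Phi(w^\top x_n) \tag{15.146}$$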


Thus we can think of probit regression as a threshold function applied to noisy input. 
We can interpret logistic regression in the same way. However, in that case the noise term $\epsilon_n$
comes from a logistic distribution, defined as follows:


$$f(y|\mu, s) \triangleq \frac{e^{-\frac{y - \mu}{s}}}{s\left(1 + e^{-\frac{y - \mu}{s}}\right)^2} \tag{15.147}$$


The cdf of this distribution is given by 


$$F(y|\mu, s) = \frac{1}{1 + e^{-\frac{y - \mu}{s}}} \tag{15.148}$$


It is clear that if we use logistic noise with $\mu = 0$ and $s = 1$ we recover logistic regression. However,
it is computationally easier to deal with Gaussian noise, as we show below. 
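The threshold interpretation can be checked by simulation: draw the noise, threshold $a + \epsilon_n$ at zero, and compare the empirical frequency of $y_n = 1$ with $\Phi(a)$ (Gaussian noise) or $\sigma(a)$ (standard logistic noise, $\mu = 0$, $s = 1$). In the NumPy sketch below, the value $a = 0.7$ standing in for $w^\top x_n$ and the helper name are hypothetical.

```python
import numpy as np
from statistics import NormalDist


def threshold_model_frequency(a, noise, num_samples=200_000, seed=0):
    """Empirical p(y = 1) for y = I(a + eps >= 0), where a plays the role of w^T x_n.

    noise = 'gaussian' gives the probit model; noise = 'logistic' (mu = 0, s = 1)
    gives logistic regression."""
    rng = np.random.default_rng(seed)
    if noise == "gaussian":
        eps = rng.normal(size=num_samples)
    elif noise == "logistic":
        eps = rng.logistic(loc=0.0, scale=1.0, size=num_samples)
    else:
        raise ValueError(f"unknown noise model: {noise}")
    return float(np.mean(a + eps >= 0.0))


a = 0.7   # hypothetical value of w^T x_n
freq_probit = threshold_model_frequency(a, "gaussian")
freq_logit = threshold_model_frequency(a, "logistic")
print("probit  :", freq_probit, "vs Phi(a) =", NormalDist().cdf(a))
print("logistic:", freq_logit, "vs sigma(a) =", 1.0 / (1.0 + np.exp(-a)))
```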
15.4.2 Maximum likelihood estimation 


In this section, we discuss some methods for fitting probit regression using MLE. 


15.4.2.1 MLE using SGD


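We can find the MLE for probit regression using standard gradient methods. Let $\mu_n = w^\top x_n$, and
let $\tilde{y}_n \in \{-1, +1\}$. Then the gradient of the log-likelihood for a single example $n$ is given by

$$g_n \triangleq \frac{d}{dw} \log p(\tilde{y}_n|w^\top x_n) = \frac{d\mu_n}{dw} \frac{d}{d\mu_n} \log p(\tilde{y}_n|w^\top x_n) = x_n \frac{\tilde{y}_n \phi(\mu_n)}{\Phi(\tilde{y}_n \mu_n)} \tag{15.149}$$

where $\phi$ is the standard normal pdf, and $\Phi$ is its cdf. Similarly, the Hessian for a single case is given
by

$$H_n = \frac{d^2}{dw^2} \log p(\tilde{y}_n|w^\top x_n) = -x_n \left( \frac{\phi(\mu_n)^2}{\Phi(\tilde{y}_n \mu_n)^2} + \frac{\tilde{y}_n \mu_n \phi(\mu_n)}{\Phi(\tilde{y}_n \mu_n)} \right) x_n^\top \tag{15.150}$$

This can be passed to any gradient-based optimizer.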
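Per-example versions of (15.149) and (15.150) are easy to get subtly wrong, so the NumPy sketch below also checks them against central finite differences. The function names are hypothetical, and the weights and input in the check are arbitrary. In a full fit, one would sum $g_n$ and $H_n$ over $n$ (adding any regularizer) and hand the result to, e.g., a Newton or quasi-Newton routine.

```python
import numpy as np
from statistics import NormalDist

_std_normal = NormalDist()


def Phi(t):
    return _std_normal.cdf(float(t))      # standard normal cdf (scalar)


def phi(t):
    return _std_normal.pdf(float(t))      # standard normal pdf (scalar)


def probit_log_lik(w, x, y_tilde):
    """log p(y_tilde | w^T x) = log Phi(y_tilde * w^T x), with y_tilde in {-1, +1}."""
    return np.log(Phi(y_tilde * (w @ x)))


def probit_grad_hess(w, x, y_tilde):
    """Per-example gradient (15.149) and Hessian (15.150)."""
    mu = w @ x
    r = phi(mu) / Phi(y_tilde * mu)                    # phi(mu_n) / Phi(y_n mu_n)
    g = y_tilde * r * x
    H = -(r**2 + y_tilde * mu * r) * np.outer(x, x)
    return g, H


def finite_difference_errors(w, x, y_tilde, h=1e-5):
    """Max abs errors of (15.149) and (15.150) versus central differences."""
    g, H = probit_grad_hess(w, x, y_tilde)
    E = np.eye(len(w))
    g_fd = np.array([(probit_log_lik(w + h * e, x, y_tilde)
                      - probit_log_lik(w - h * e, x, y_tilde)) / (2 * h) for e in E])
    H_fd = np.array([(probit_grad_hess(w + h * e, x, y_tilde)[0]
                      - probit_grad_hess(w - h * e, x, y_tilde)[0]) / (2 * h) for e in E])
    return np.max(np.abs(g - g_fd)), np.max(np.abs(H - H_fd))


w = np.array([0.3, -0.8, 0.5])
x = np.array([1.0, 2.0, -0.5])
for y_tilde in [-1, 1]:
    print(y_tilde, finite_difference_errors(w, x, y_tilde))
```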
[Figure 15.10 plot: "probit regression with L2 regularizer of 0.010"; y-axis: negative log-likelihood; curves: EM, BFGS]

Figure 15.10: Fitting a probit regression model in 2d using a quasi-Newton method or EM. Generated by
probit_reg_demo.ipynb.


15.4.2.2 MLE using EM 


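We can use the latent variable interpretation of probit regression to derive an elegant EM algorithm
for fitting the model. The complete data log likelihood has the following form, assuming a $\mathcal{N}(0, V_0)$
prior on $w$:

$$\ell(z, w|V_0) = \log p(y|z) + \log \mathcal{N}(z|Xw, I) + \log \mathcal{N}(w|0, V_0) \tag{15.151}$$
$$= \sum_n \log p(y_n|z_n) - \frac{1}{2}(z - Xw)^\top (z - Xw) - \frac{1}{2} w^\top V_0^{-1} w + \text{const} \tag{15.152}$$

The posterior in the E step is a truncated Gaussian:

$$p(z_n|y_n, x_n, w) \propto \begin{cases} \mathcal{N}(z_n|w^\top x_n, 1)\, \mathbb{I}(z_n > 0) & \text{if } y_n = 1 \\ \mathcal{N}(z_n|w^\top x_n, 1)\, \mathbb{I}(z_n < 0) & \text{if } y_n = 0 \end{cases} \tag{15.153}$$

In Equation (15.152), we see that $w$ only depends linearly on $z$, so in the E step we just need to compute
$\mathbb{E}[z_n|y_n, x_n, w]$, i.e., the posterior mean. One can show that this is given by

$$\mathbb{E}[z_n|y_n, x_n, w] = \begin{cases} \mu_n + \frac{\phi(\mu_n)}{1 - \Phi(-\mu_n)} = \mu_n + \frac{\phi(\mu_n)}{\Phi(\mu_n)} & \text{if } y_n = 1 \\ \mu_n - \frac{\phi(\mu_n)}{\Phi(-\mu_n)} = \mu_n - \frac{\phi(\mu_n)}{1 - \Phi(\mu_n)} & \text{if } y_n = 0 \end{cases} \tag{15.154}$$

where $\mu_n = w^\top x_n$.

In the M step, we estimate $w$ using ridge regression, where $\mu = \mathbb{E}[z]$ is the output we are trying
to predict. Specifically, we have

$$\hat{w} = (V_0^{-1} + X^\top X)^{-1} X^\top \mu \tag{15.155}$$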


The EM algorithm is simple, but can be much slower than direct gradient methods, as illustrated
in Figure 15.10. This is because the posterior entropy in the E step is quite high, since we only
observe that $z$ is positive or negative, but are given no information from the likelihood about its
magnitude. Using a stronger regularizer can help speed convergence, because it constrains the range
of plausible $z$ values. In addition, one can use various speedup tricks, such as data augmentation
[DM01].
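The EM updates (15.154) and (15.155) fit in a short NumPy sketch. We assume a $\mathcal{N}(0, V_0)$ prior with $V_0 = 100 I$ (a hypothetical choice) and simulated data, and use Python's statistics.NormalDist for $\Phi$. Since this is EM for a MAP estimate, each iteration should not increase the negative log posterior, and the recorded history lets you check that.

```python
import numpy as np
from statistics import NormalDist

_std_normal = NormalDist()
Phi = np.vectorize(_std_normal.cdf, otypes=[float])     # elementwise standard normal cdf


def phi(t):
    return np.exp(-0.5 * t**2) / np.sqrt(2.0 * np.pi)   # standard normal pdf


def neg_log_posterior(w, X, y, V0):
    """-log p(y | X, w) - log N(w | 0, V0), up to a constant."""
    y_tilde = 2.0 * y - 1.0
    return float(-np.sum(np.log(Phi(y_tilde * (X @ w)))) + 0.5 * w @ np.linalg.solve(V0, w))


def probit_em(X, y, V0, num_iters=50):
    """MAP estimation for probit regression with EM, (15.151)-(15.155)."""
    w = np.zeros(X.shape[1])
    A = np.linalg.inv(V0) + X.T @ X            # V0^{-1} + X^T X, fixed across iterations
    pos = y == 1
    history = []
    for _ in range(num_iters):
        mu = X @ w
        # E step: posterior mean of each truncated Gaussian z_n, (15.154)
        Ez = mu.copy()
        Ez[pos] += phi(mu[pos]) / Phi(mu[pos])
        Ez[~pos] -= phi(mu[~pos]) / Phi(-mu[~pos])
        # M step: ridge regression with E[z] as the target, (15.155)
        w = np.linalg.solve(A, X.T @ Ez)
        history.append(neg_log_posterior(w, X, y, V0))
    return w, np.array(history)


rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
y = (X @ np.array([1.0, -2.0]) + rng.normal(size=200) >= 0).astype(float)   # simulated probit data
V0 = 100.0 * np.eye(2)                       # hypothetical prior covariance
w_em, history = probit_em(X, y, V0)
print("w_em =", w_em)
print("objective after first / last iteration:", history[0], history[-1])
```

Plotting history against a gradient-based fit (e.g. BFGS) on the same objective would show the slow convergence discussed above.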
15.4. PROBIT REGRESSION


15.4.3 Bayesian inference 


It is possible to use the latent variable formulation of probit regression in Section 15.4.2.2 to derive a
simple Gibbs sampling algorithm for approximating the posterior $p(w|D)$ (see e.g., [AC93; HH06]).
The key idea is to use an auxiliary latent variable, which, when conditioned on, makes the whole
model a conjugate linear-Gaussian model. The full conditional for the latent variables is given by

$$p(z_i|y_i, x_i, w) \propto \begin{cases} \mathcal{N}(z_i|w^\top x_i, 1)\, \mathbb{I}(z_i > 0) & \text{if } y_i = 1 \\ \mathcal{N}(z_i|w^\top x_i, 1)\, \mathbb{I}(z_i < 0) & \text{if } y_i = 0 \end{cases} \tag{15.156}$$

Thus the posterior is a truncated Gaussian. We can sample from a truncated Gaussian, $\mathcal{N}(z|\mu, \sigma^2)\, \mathbb{I}(a \leq z \leq b)$,
in two steps: first sample $u \sim U(\Phi((a - \mu)/\sigma), \Phi((b - \mu)/\sigma))$, then set $z = \mu + \sigma \Phi^{-1}(u)$ [Rob95a].
The full conditional for the parameters is given by

$$p(w|D, z) = \mathcal{N}(w|w_N, V_N) \tag{15.157}$$
$$V_N = (V_0^{-1} + X^\top X)^{-1} \tag{15.158}$$
$$w_N = V_N (V_0^{-1} w_0 + X^\top z) \tag{15.159}$$

where we assume a $\mathcal{N}(w|w_0, V_0)$ prior.


For further details, see e.g., [AC93; FSF10]. It is also possible to use variational Bayes, which 
tends to be much faster (see e.g., [GR06a; FDZ19]). 
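A minimal NumPy sketch of this Gibbs sampler, in the style of [AC93] but not taken from any of the cited sources, alternates the two full conditionals. Assumptions: a $\mathcal{N}(w|w_0, V_0)$ prior, the inverse-cdf truncated-Gaussian sampler described above (built on Python's statistics.NormalDist, with a small clip so that $\Phi^{-1}$ stays finite), and simulated data. The function names are hypothetical.

```python
import numpy as np
from statistics import NormalDist

_std_normal = NormalDist()
_cdf = np.vectorize(_std_normal.cdf, otypes=[float])
_inv_cdf = np.vectorize(_std_normal.inv_cdf, otypes=[float])


def sample_truncated_normal(rng, mu, sigma, a, b):
    """Sample z ~ N(mu, sigma^2) restricted to [a, b] by the inverse-cdf method:
    u ~ U(Phi((a - mu)/sigma), Phi((b - mu)/sigma)), z = mu + sigma * Phi^{-1}(u)."""
    lo = _cdf((a - mu) / sigma)
    hi = _cdf((b - mu) / sigma)
    u = np.clip(rng.uniform(lo, hi), 1e-12, 1.0 - 1e-12)   # keep Phi^{-1} finite
    return mu + sigma * _inv_cdf(u)


def probit_gibbs(X, y, V0, w0, num_samples=500, seed=0):
    """Gibbs sampler for Bayesian probit regression with prior N(w | w0, V0)."""
    rng = np.random.default_rng(seed)
    D = X.shape[1]
    V0_inv = np.linalg.inv(V0)
    VN = np.linalg.inv(V0_inv + X.T @ X)        # (15.158); does not depend on z
    LN = np.linalg.cholesky(VN)
    a = np.where(y == 1, 0.0, -np.inf)          # z_n > 0 if y_n = 1, z_n < 0 if y_n = 0
    b = np.where(y == 1, np.inf, 0.0)
    w = np.zeros(D)
    samples = np.empty((num_samples, D))
    for s in range(num_samples):
        z = sample_truncated_normal(rng, X @ w, 1.0, a, b)    # z | w, y, (15.156)
        wN = VN @ (V0_inv @ w0 + X.T @ z)                     # (15.159)
        w = wN + LN @ rng.normal(size=D)                      # w | z ~ N(wN, VN), (15.157)
        samples[s] = w
    return samples


rng = np.random.default_rng(1)
X = rng.normal(size=(200, 2))
y = (X @ np.array([0.5, -1.0]) + rng.normal(size=200) >= 0).astype(float)
samples = probit_gibbs(X, y, V0=10.0 * np.eye(2), w0=np.zeros(2))
print("posterior mean (after burn-in):", samples[100:].mean(axis=0))
```

Because $V_N$ does not depend on $z$, its Cholesky factor is computed once outside the loop; only the mean $w_N$ changes from sweep to sweep.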


15.4.4 Ordinal probit regression 


One advantage of the latent variable interpretation of probit regression is that it is easy to extend to 
the case where the response variable is ordered in some way, such as the outputs low, medium and 
high. This is called ordinal regression. The basic idea is as follows. If there are C output values, 
we introduce $C + 1$ thresholds $\gamma_j$ and set

$$y_n = j \quad \text{if} \quad \gamma_{j-1} < z_n \leq \gamma_j \tag{15.160}$$

where $\gamma_0 \leq \cdots \leq \gamma_C$. For identifiability reasons, we set $\gamma_0 = -\infty$, $\gamma_1 = 0$ and $\gamma_C = \infty$. For example,
if $C = 2$, this reduces to the standard binary probit model, whereby $z_n < 0$ produces $y_n = 0$ and
$z_n > 0$ produces $y_n = 1$. If $C = 3$, we partition the real line into 3 intervals: $(-\infty, 0]$, $(0, \gamma_2]$, $(\gamma_2, \infty)$.
We can vary the parameter y2 to ensure the right relative amount of probability mass falls in each 
interval, so as to match the empirical frequencies of each class label. See e.g., [AC93] for further 
details. 
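Since $z_n = w^\top x_n + \epsilon_n$ with $\epsilon_n \sim \mathcal{N}(0, 1)$, the class probabilities implied by (15.160) are $p(y_n = j) = \Phi(\gamma_j - w^\top x_n) - \Phi(\gamma_{j-1} - w^\top x_n)$. The NumPy sketch below computes them. The linear predictors and the value $\gamma_2 = 1.5$ are hypothetical, and column $j - 1$ of the output holds class $j$.

```python
import numpy as np
from statistics import NormalDist

Phi = np.vectorize(NormalDist().cdf, otypes=[float])   # elementwise standard normal cdf


def ordinal_probit_probs(eta, gamma):
    """Class probabilities for ordinal probit regression.

    eta: linear predictors w^T x_n, shape (N,).
    gamma: thresholds gamma_0 = -inf, gamma_1 = 0, ..., gamma_C = +inf, shape (C + 1,).
    Returns an (N, C) array with entries Phi(gamma_j - eta_n) - Phi(gamma_{j-1} - eta_n)."""
    cdf = Phi(gamma[None, :] - eta[:, None])           # (N, C + 1)
    return np.diff(cdf, axis=1)                        # (N, C)


eta = np.array([-1.0, 0.0, 2.0])                       # hypothetical values of w^T x_n
gamma = np.array([-np.inf, 0.0, 1.5, np.inf])          # C = 3, gamma_2 = 1.5 chosen arbitrarily
P = ordinal_probit_probs(eta, gamma)
print(np.round(P, 3))
```

With $C = 2$ and thresholds $(-\infty, 0, \infty)$, the second column reduces to $\Phi(w^\top x_n)$, the binary probit model.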

Finding the MLEs for this model is a bit trickier than for binary probit regression, since we need
to optimize for $w$ and $\gamma$, and the latter must obey an ordering constraint. See e.g., [KL09] for an
approach based on EM. It is also possible to derive a simple Gibbs sampling algorithm for this model 
(see e.g., [Hof09, p216]). 
Figure 15.11: Hierarchical Bayesian discriminative models with $J$ groups. (a) Nested formulation. (b)
Non-nested formulation, with group indicator $g_n \in \{1, \ldots, J\}$.


15.4.5 Multinomial probit models 


Now consider the case where the response variable can take on $C$ unordered categorical values,
$y_n \in \{1, \ldots, C\}$. The multinomial probit model is defined as follows:

$$z_{nc} = w^\top x_{nc} + \epsilon_{nc} \tag{15.161}$$
$$\epsilon \sim \mathcal{N}(0, R) \tag{15.162}$$
$$y_n = \arg\max_c z_{nc} \tag{15.163}$$

See e.g., [DE04; GR06b; Sco09; FSF10] for more details on the model and its connection to multinomial
logistic regression.

If instead of setting $y_n = \arg\max_c z_{nc}$ we use $y_{nc} = \mathbb{I}(z_{nc} > 0)$, we get a model known as
multivariate probit, which is one way to model $C$ correlated binary outcomes (see e.g., [TMD12]).
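Both models are easy to simulate, which helps build intuition. In the NumPy sketch below, the utilities $w^\top x_{nc}$ and the noise covariance $R$ are hypothetical numbers, and classes are indexed $0, \ldots, C - 1$. It draws $y_n = \arg\max_c z_{nc}$ for the multinomial probit, and $y_{nc} = \mathbb{I}(z_{nc} > 0)$ for the multivariate probit, where correlated noise produces correlated binary outcomes.

```python
import numpy as np


def sample_multinomial_probit(utilities, R, num_samples, rng):
    """y_n = argmax_c z_nc with z_nc = w^T x_nc + eps_nc and eps_n ~ N(0, R).

    utilities: the C values w^T x_nc for one example n; classes are indexed 0..C-1."""
    eps = rng.multivariate_normal(np.zeros(len(utilities)), R, size=num_samples)  # (S, C)
    return np.argmax(utilities + eps, axis=1)


def sample_multivariate_probit(utilities, R, num_samples, rng):
    """y_nc = I(z_nc > 0): C correlated binary outcomes per sample, shape (S, C)."""
    eps = rng.multivariate_normal(np.zeros(len(utilities)), R, size=num_samples)
    return (utilities + eps > 0).astype(int)


rng = np.random.default_rng(0)
y = sample_multinomial_probit(np.array([0.0, 0.5, 1.0]), np.eye(3), 100_000, rng)
class_freqs = np.bincount(y, minlength=3) / len(y)

R = np.array([[1.0, 0.8], [0.8, 1.0]])          # hypothetical noise correlation
Y = sample_multivariate_probit(np.zeros(2), R, 100_000, rng)
outcome_corr = np.corrcoef(Y[:, 0], Y[:, 1])[0, 1]
print("multinomial probit class frequencies:", class_freqs)
print("correlation of the two binary outcomes:", outcome_corr)
```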


15.5 Multi-level (hierarchical) GLMs


Suppose we have a set of $J$ related datasets, each of which contains a series of $N_j$ datapoints
$D_j = \{(x_n^j, y_n^j) : n = 1:N_j\}$. There are 3 main ways to fit models in such a setting: we could fit $J$
separate models, $p(y|x; D_j)$, which might result in overfitting if some $D_j$ are small; we could pool all
the data to get $D = \cup_{j=1}^{J} D_j$ and fit a single model, $p(y|x; D)$, which might result in underfitting; or
we can use a hierarchical Bayesian model, also called a multilevel model or partially pooled
model, in which we assume each group has its own parameters, $\theta^j$, but that these have something
in common, as modeled by a shared global prior $p(\theta^0)$. (Note that each group could be a single
individual.) The overall model has the form

$$p(\theta^{0:J}, D) = p(\theta^0) \prod_{j=1}^{J} \left[ p(\theta^j|\theta^0) \prod_{n=1}^{N_j} p(y_n^j|x_n^j, \theta^j) \right] \tag{15.164}$$
15.5. MULTI-LEVEL (HIERARCHICAL) GLMS


See Figure 15.11a, which represents the model using nested plate notation. 

It is often more convenient to represent the model as in Figure 15.11b, which eliminates the 
nested plates (and hence the double indexing of variables) by associating a group indicator variable 
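$g_n \in \{1, \ldots, J\}$, which specifies which set of parameters to use for each data point. Thus the model
now has the form

$$p(\theta^{0:J}, D) = p(\theta^0) \left[\prod_{j=1}^{J} p(\theta^j|\theta^0)\right] \prod_{n=1}^{N} p(y_n|x_n, g_n, \theta) \tag{15.165}$$

where

$$p(y_n|x_n, g_n, \theta) = \prod_{j=1}^{J} p(y_n|x_n, \theta^j)^{\mathbb{I}(g_n = j)} \tag{15.166}$$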


If the likelihood function is a GLM, this hierarchical model is called a hierarchical GLM [LN96]. 
This class of models is very widely used in applied statistics. For much more detail, see e.g., [GH07;
GHV20; Gel+22].


15.5.1 Generalized linear mixed models (GLMMs) 


Suppose that the prior on the per-group parameters is Gaussian, so $p(\theta^j|\theta^0) = \mathcal{N}(\theta^j|\theta^0, \Sigma)$. If we
have a GLM likelihood, the model becomes

$$p(y_n|x_n, g_n = j, \theta) = p(y_n|\ell^{-1}(\eta_n)) \tag{15.167}$$
$$\eta_n = x_n^\top \theta^j = x_n^\top (\theta^0 + \epsilon^j) = x_n^\top \theta^0 + x_n^\top \epsilon^j \tag{15.168}$$

where $\ell$ is the link function, and $\epsilon^j \sim \mathcal{N}(0, \Sigma)$. This is known as a generalized linear mixed
model (GLMM) or mixed effects model. The shared (common) parameters $\theta^0$ are called fixed
effects, and the group-specific offsets $\epsilon^j$ are called random effects.² We can see that the random
effects model group-specific deviations or idiosyncrasies away from the shared fixed parameters.
Furthermore, since all observations in group $j$ share the same random effect $\epsilon^j$, their linear predictors are correlated, which allows us to model dependencies
between the observations that would not be captured by a standard GLM.
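The within-group dependence induced by a shared random effect can be seen in a quick simulation. The NumPy sketch below assumes a logistic link, an intercept-only design, $\theta^0 = 0$, $\Sigma = 4$, and two observations per group. All of these are illustrative choices, not the radon model below. It compares the correlation between labels in the same group with the correlation between labels in neighbouring groups.

```python
import numpy as np


def simulate_logistic_glmm(X, g, theta0, Sigma, num_groups, rng):
    """Simulate binary labels from a GLMM with a logistic (logit) link.

    X: (N, D) features; g: (N,) group indices in {0, ..., J-1};
    eta_n = x_n^T theta0 + x_n^T eps^{g_n}, with random effects eps^j ~ N(0, Sigma)."""
    eps = rng.multivariate_normal(np.zeros(X.shape[1]), Sigma, size=num_groups)   # (J, D)
    eta = X @ theta0 + np.sum(X * eps[g], axis=1)
    mu = 1.0 / (1.0 + np.exp(-eta))                # mean via the inverse link
    return (rng.random(len(eta)) < mu).astype(int)


rng = np.random.default_rng(0)
J = 20_000
g = np.repeat(np.arange(J), 2)                     # two observations per group
X = np.ones((2 * J, 1))                            # intercept-only design
y = simulate_logistic_glmm(X, g, theta0=np.zeros(1), Sigma=np.array([[4.0]]),
                           num_groups=J, rng=rng)
pairs = y.reshape(J, 2)
within = np.corrcoef(pairs[:, 0], pairs[:, 1])[0, 1]       # same group
across = np.corrcoef(pairs[:-1, 0], pairs[1:, 1])[0, 1]    # neighbouring groups
print(f"within-group corr = {within:.2f}, across-group corr = {across:.2f}")
```

A standard GLM with the same fixed effect would make all labels independent, so both correlations would be near zero.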


15.5.2 Model fitting 


We can fit GLMMs, and hierarchical models more generally, using standard Bayesian inference 
methods. We can use a variety of algorithms, such as HMC (see e.g., [BG13]), variational Bayes (see 
e.g., [HOW11; TN13]), expectation propagation (see e.g., [KW18]), etc. In this section, we use HMC,
since it is simple and efficient.³


2. Note that there are multiple definitions of the terms "fixed effects" and "random effects", as explained in this blog
post by Andrew Gelman: https://statmodeling.stat.columbia.edu/2005/01/25/why_i_dont_use/.

3. There are many standard software packages for HMC analysis of hierarchical GLMs, such as Bambi
(https://github.com/bambinos/bambi), which is a Python wrapper on top of PyMC, Blackjax and numpyro samplers;
RStanARM (https://cran.r-project.org/web/packages/rstanarm/index.html), which is an R wrapper on top of
Stan; and BRMS (https://cran.r-project.org/web/packages/brms/index.html), which is another R wrapper on
top of Stan, but which also needs a C++ compiler.
Figure 15.12: A hierarchical Bayesian linear regression model for the radon problem.


15.5.3 Example: radon regression 


In this section, we give an example of a hierarchical Bayesian linear regression model. We apply it to 
a simplified version of the radon example from [Gel+14a, Sec 9.4]. 

Radon is known to be the leading cause of lung cancer in non-smokers, so reducing it where possible
is desirable. To help with this, we fit a regression model that predicts the (log) radon level as a
function of the location of the house, as represented by a categorical feature indicating its county, and
a binary feature representing whether the house has a basement or not. We use a dataset consisting
of $J = 85$ counties in Minnesota; each county has between 2 and 80 measurements.


We assume the following likelihood: 


$$p(y_n|x_n, g_n = j, \theta) = \mathcal{N}(y_n|\alpha_j + \beta_j x_n, \sigma_y^2) \tag{15.169}$$

where $g_n \in \{1, \ldots, J\}$ is the county for house $n$, and $x_n \in \{0, 1\}$ indicates if the floor is at level 0
(i.e., in the basement) or level 1 (i.e., above ground). Intuitively we expect the radon levels to be
lower in houses without basements, since they are more insulated from the earth which is the source
of the radon.


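Since some counties have very few data points, we use a hierarchical prior in which we assume
$\alpha_j \sim \mathcal{N}(\mu_\alpha, \sigma_\alpha^2)$ and $\beta_j \sim \mathcal{N}(\mu_\beta, \sigma_\beta^2)$. We use weak priors for the parameters: $\mu_\alpha \sim \mathcal{N}(0, 1)$,
$\mu_\beta \sim \mathcal{N}(0, 1)$, $\sigma_\alpha \sim \mathcal{C}_+(1)$, $\sigma_\beta \sim \mathcal{C}_+(1)$, $\sigma_y \sim \mathcal{C}_+(1)$, where $\mathcal{C}_+$ denotes the half-Cauchy distribution. See Figure 15.12 for the graphical model.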


15.5.3.1 Posterior inference

Figure 15.13 shows the posterior marginals for $\mu_\alpha$, $\mu_\beta$, $\alpha_j$ and $\beta_j$. We see that $\mu_\beta$ is close to -0.6
with high probability, which confirms our suspicion that having $x = 1$ (i.e., no basement) decreases
the amount of radon in the house. We also see that the distribution of the $\alpha_j$ parameters is quite
variable, due to different base rates across the counties.


[Figure 15.13 panels: alpha; beta]


Figure 15.13: Posterior marginals for $\alpha_j$ and $\beta_j$ for each county in the radon model. Generated by linreg_hierarchical_non_centered.ipynb.
Figure 15.14: Predictions from the radon model for 3 different counties in Minnesota. Black dots are observed
datapoints. Red represents results of hierarchical (shared) prior, blue represents results of non-hierarchical 
prior. Thick lines are the result of using the posterior mean, thin lines are the result of using posterior 
samples. Generated by linreg_hierarchical_non_centered.ipynb. 


Figure 15.14 shows predictions from the hierarchical and non-hierarchical model for 3 different 
counties. We see that the predictions from the hierarchical model are more consistent across counties, 
and work well even if there are no examples of certain feature combinations for a given county (e.g., 
there are no houses without basements in the sample from Cass county). If we sample data from the 
posterior predictive distribution, and compare it to the real data, we find that the RMSE is 0.13 for 
the non-hierarchical model and 0.08 for the hierarchical model, indicating that the latter fits better. 


15.5.3.2 Non-centered parameterization 


One problem that frequently arises in hierarchical models is that the parameters can be very correlated.
This can cause computational problems when performing inference.
Figure 15.15: (a) Bivariate posterior $p(\beta_j, \sigma_\beta|D)$ for the hierarchical radon model for county $j = 75$ using
centered parameterization. (b) Similar to (a) except we plot $p(\tilde{\beta}_j, \sigma_\beta|D)$ for the non-centered parameterization.
Generated by linreg_hierarchical_non_centered.ipynb.


Figure 15.15a gives an example where we plot $p(\beta_j, \sigma_\beta|D)$ for some specific county $j$. If we believe
that $\sigma_\beta$ is large, then $\beta_j$ is "allowed" to vary a lot, and we get the broad distribution at the top of
the figure. However, if we believe that $\sigma_\beta$ is small, then $\beta_j$ is constrained to be close to the global
prior mean of $\mu_\beta$, so we get the narrow distribution at the bottom of the figure. This is often called
Neal's funnel, after a paper by Radford Neal [Nea03]. It is difficult for many algorithms (especially
sampling algorithms) to explore parts of parameter space at the bottom of the funnel. This is evident
from the marginal posterior for $\sigma_\beta$ shown (as a histogram) on the right hand side of the plot: we see
that it excludes the interval [0, 0.1], thus ruling out models in which we shrink $\beta_j$ all the way to 0.
In cases where a covariate has no useful predictive role, we would like to be able to induce sparsity,
so we need to overcome this problem.

A simple solution to this is to use a non-centered parameterization [PR03]. That is, we replace
$\beta_j \sim \mathcal{N}(\mu_\beta, \sigma_\beta^2)$ with $\beta_j = \mu_\beta + \tilde{\beta}_j \sigma_\beta$, where $\tilde{\beta}_j \sim \mathcal{N}(0, 1)$ represents the offset from the global mean,
$\mu_\beta$. The correlation between $\tilde{\beta}_j$ and $\sigma_\beta$ is much less, as shown in Figure 15.15b. See Section 12.6.5
for more details.
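The geometric effect of the reparameterization can already be seen under the prior. The NumPy sketch below draws $\sigma_\beta \sim \mathcal{C}_+(1)$ and compares the centered draw $\beta_j \sim \mathcal{N}(\mu_\beta, \sigma_\beta^2)$ with the non-centered draw $\beta_j = \mu_\beta + \sigma_\beta \tilde{\beta}_j$. It illustrates prior geometry only, with a hypothetical $\mu_\beta = -0.6$; it is not the posterior computation of linreg_hierarchical_non_centered.ipynb. In an actual HMC run, the sampler would move in $(\sigma_\beta, \tilde{\beta}_j)$ space and compute $\beta_j$ deterministically.

```python
import numpy as np

rng = np.random.default_rng(0)
S = 200_000
mu_beta = -0.6                                         # hypothetical global mean
sigma_beta = np.abs(rng.standard_cauchy(S))            # sigma_beta ~ C+(1), half-Cauchy

# Centered: beta_j ~ N(mu_beta, sigma_beta^2); its scale is tied to sigma_beta (the funnel)
beta_centered = rng.normal(mu_beta, sigma_beta)

# Non-centered: beta_tilde_j ~ N(0, 1) independently of sigma_beta,
# and beta_j = mu_beta + sigma_beta * beta_tilde_j is a deterministic transform
beta_tilde = rng.normal(0.0, 1.0, S)
beta_noncentered = mu_beta + sigma_beta * beta_tilde

log_sigma = np.log(sigma_beta)
corr_centered = np.corrcoef(log_sigma, np.log(np.abs(beta_centered - mu_beta)))[0, 1]
corr_noncentered = np.corrcoef(log_sigma, np.log(np.abs(beta_tilde)))[0, 1]
print("corr(log sigma_beta, log|beta_j - mu_beta|):", corr_centered)
print("corr(log sigma_beta, log|beta_tilde_j|)    :", corr_noncentered)
```

Both parameterizations define the same distribution for $\beta_j$; only the coordinates the sampler works in change.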
16 Deep neural networks


16.1 Introduction 


The term "deep neural network" or DNN, in its modern usage, refers to any kind of differentiable
function that can be expressed as a computation graph, where the nodes are primitive operations
(like matrix multiplication), and edges represent numeric data in the form of vectors, matrices, or
tensors. In its simplest form, this graph can be constructed as a linear series of nodes or "layers".
The term "deep" refers to models with many such layers.

In Section 16.2 we discuss some of the basic building blocks (node types) that are used in the field. 
In Section 16.3 we give examples of common architectures which are constructed from these building 
blocks. In Section 6.2 we show how we can efficiently compute the gradient of functions defined on 
such graphs. If the function computes the scalar loss of the model’s predictions given a training set, 
we can pass this gradient to an optimization routine, such as those discussed in Chapter 6, in order 
to fit the model. Fitting such models to data is called “deep learning”. 

We can combine DNNs with probabilistic models in two different ways. The first is to use them
to define nonlinear functions which are used inside conditional distributions. For example, we may
construct a classifier using $p(y|x, \theta) = \mathrm{Cat}(y|\mathrm{softmax}(f(x; \theta)))$, where $f(x; \theta)$ is a neural network
that maps inputs $x$ and parameters $\theta$ to output logits. Or we may construct a joint probability
distribution over multiple variables using a directed graphical model (Chapter 4) where each CPD
$p(x_i|\mathrm{pa}(x_i))$ is a DNN. This lets us construct expressive probability models.

The other way we can combine DNNs and probabilistic models is to use DNNs to approximate
the posterior distribution, i.e., we learn a function $f$ to compute $q(z|f(D; \phi))$, where $z$ are the
hidden variables (latents and/or parameters), $D$ are the observed variables (data), $f$ is an inference
network, and $\phi$ are its parameters; for details, see Section 10.3.6. Note that in this latter setting, the
joint model $p(z, D)$ may be a "traditional" model without any "neural" components. For example, it
could be a complex simulator. Thus the DNN is just used for computational purposes, not statistical
modeling purposes.

More details on DNNs can be found in such books as [Zha+20a; Cho21; Gér19; GBC16], as well as
a multitude of online courses. For a more theoretical treatment, see e.g., [Ber+21; Cal20; Aro+21;
RY21].


16.2 Building blocks of differentiable circuits 


In this section we discuss some common building blocks used in constructing neural networks. We
denote the input to a block as $x$ and the output as $y$.
Figure 16.1: An artificial "neuron", the most basic building block of a DNN. (a) The output $y$ is a weighted
combination of the inputs $x$, where the weights vector is denoted by $w$. (b) Alternative depiction of the
neuron's behavior. The bias term $b$ can be emulated by defining $w_N = b$ and $x_N = 1$.


16.2.1 Linear layers 


The most basic building block of a DNN is a single “neuron”, which corresponds to a real-valued 
signal y computed by multiplying a vector-valued input signal x by a weight vector w, and then 
adding a bias term b. That is, 


$$y = f(x; \theta) = w^\top x + b \tag{16.1}$$

where $\theta = (w, b)$ are the parameters for the function $f$. This is depicted in Figure 16.1. (The bias
term is omitted for clarity.)
It is common to group a set of neurons together into a layer. We can then represent the activations
of a layer with $D$ units as a vector $z \in \mathbb{R}^D$. We can transform an input vector of activations $x$ into
an output vector $y$ by multiplying by a weight matrix $W$, and adding an offset vector or bias term $b$
to get

$$y = f(x; \theta) = Wx + b \tag{16.2}$$

where $\theta = (W, b)$ are the parameters for the function $f$. This is called a linear layer, or fully
connected layer.

It is common to append the bias vector to the weight matrix as an extra (last) column, and to append
a 1 to the vector $x$, so that we can write this more compactly as $y = \tilde{W}\tilde{x}$, where $\tilde{W} = [W, b]$ and
$\tilde{x} = [x; 1]$. This allows us to ignore the bias term from our notation if we want to.
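The bias trick can be checked in a few lines of NumPy. The layer sizes and random weights below are arbitrary illustrative values.

```python
import numpy as np

rng = np.random.default_rng(0)
D_in, D_out = 3, 2
W = rng.normal(size=(D_out, D_in))        # weight matrix
b = rng.normal(size=D_out)                # bias vector
x = rng.normal(size=D_in)                 # input activations

y = W @ x + b                             # linear layer, (16.2)

W_tilde = np.hstack([W, b[:, None]])      # [W, b], shape (D_out, D_in + 1)
x_tilde = np.append(x, 1.0)               # [x; 1], shape (D_in + 1,)
y_compact = W_tilde @ x_tilde             # same output with no explicit bias term
print(y, y_compact)
```

Libraries typically keep $W$ and $b$ as separate arrays in practice, so the augmented form is mainly a notational convenience.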


16.2.2 Non-linearities

A stack of linear layers is equivalent to a single linear layer where we multiply together all the
weight matrices. To get more expressive power we can transform each layer by passing it elementwise
(pointwise) through a nonlinear function called an activation function. This is denoted by

$$y = \varphi(z) = [\varphi(z_1), \ldots, \varphi(z_D)] \tag{16.3}$$

See Table 16.1 for a list of some common activation functions, and Figure 16.2 for a visualization.
For more details, see e.g., [Mur22, Sec 13.2.3].
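The collapse of a purely linear stack is easy to verify numerically. Two linear layers $W_2(W_1 x + b_1) + b_2$ equal the single layer with weights $W_2 W_1$ and bias $W_2 b_1 + b_2$. Inserting an elementwise ReLU between them generally breaks this equivalence. The NumPy sketch below uses arbitrary random weights.

```python
import numpy as np

rng = np.random.default_rng(0)
W1, b1 = rng.normal(size=(4, 3)), rng.normal(size=4)     # layer 1: 3 -> 4 units
W2, b2 = rng.normal(size=(2, 4)), rng.normal(size=2)     # layer 2: 4 -> 2 units
X = rng.normal(size=(10, 3))                             # 10 inputs, one per row


def relu(a):
    return np.maximum(a, 0.0)                            # elementwise, as in (16.3)


two_linear = (X @ W1.T + b1) @ W2.T + b2                 # two stacked linear layers
W_eq, b_eq = W2 @ W1, W2 @ b1 + b2                       # equivalent single linear layer
one_linear = X @ W_eq.T + b_eq
with_relu = relu(X @ W1.T + b1) @ W2.T + b2              # nonlinearity between the layers

print("stack equals single layer:", np.allclose(two_linear, one_linear))
print("ReLU network equals single layer:", np.allclose(with_relu, one_linear))
```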
Name                      Definition                                       Range                   Reference
Sigmoid                   $\sigma(a) = \frac{1}{1 + e^{-a}}$               $[0, 1]$
Hyperbolic tangent        $\tanh(a) = 2\sigma(2a) - 1$                     $[-1, 1]$
Softplus                  $\sigma_+(a) = \log(1 + e^a)$                    $[0, \infty]$           [GBB11]
Rectified linear unit     $\mathrm{ReLU}(a) = \max(a, 0)$                  $[0, \infty]$           [GBB11; KSH12a]
Leaky ReLU                $\max(a, 0) + \alpha \min(a, 0)$                 $[-\infty, \infty]$     [MHN13]
Exponential linear unit   $\max(a, 0) + \min(\alpha(e^a - 1), 0)$          $[-\infty, \infty]$     [CUH16]
Swish                     $a\sigma(a)$                                     $[-\infty, \infty]$     [RZL17]
GELU                      $a\Phi(a)$                                       $[-\infty, \infty]$     [HG16]

Table 16.1: List of some popular activation functions for neural networks.
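The definitions in Table 16.1 translate directly into NumPy. In the sketch below, the values $\alpha = 0.01$ (leaky ReLU) and $\alpha = 1$ (ELU) are illustrative defaults that the table does not specify. GELU uses $\Phi(a) = \frac{1}{2}(1 + \mathrm{erf}(a/\sqrt{2}))$ via Python's math.erf.

```python
import math
import numpy as np

_erf = np.vectorize(math.erf, otypes=[float])


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def tanh_via_sigmoid(a):
    return 2.0 * sigmoid(2.0 * a) - 1.0


def softplus(a):
    return np.logaddexp(0.0, a)                 # log(1 + e^a), computed stably


def relu(a):
    return np.maximum(a, 0.0)


def leaky_relu(a, alpha=0.01):                  # alpha = 0.01 is an illustrative choice
    return np.maximum(a, 0.0) + alpha * np.minimum(a, 0.0)


def elu(a, alpha=1.0):                          # alpha = 1.0 is an illustrative choice
    return np.maximum(a, 0.0) + np.minimum(alpha * (np.exp(a) - 1.0), 0.0)


def swish(a):
    return a * sigmoid(a)


def gelu(a):
    Phi = 0.5 * (1.0 + _erf(a / np.sqrt(2.0)))  # standard normal cdf
    return a * Phi


a = np.linspace(-3.0, 3.0, 7)
for f in [sigmoid, tanh_via_sigmoid, softplus, relu, leaky_relu, elu, swish, gelu]:
    print(f"{f.__name__:17s}", np.round(f(a), 3))
```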


[Figure 16.2 panels: "Activation function" and "Gradient of activation function"; legend entries include sigmoid and leaky-relu.]

Figure 16.2: (a) Some popular activation functions. "ReLU" stands for "rectified linear unit". "GELU" stands for "Gaussian error linear unit". (b) Plot of their gradients. Generated by activation_fun_deriv.ipynb.


16.2.3 Convolutional layers 


When dealing with image data, we can apply the same weight matrix to each local patch of the 
image, in order to reduce the number of parameters. If we “slide” this weight matrix over the image 
and add up the results, we get a technique known as convolution; in this case the weight matrix is 
often called a “kernel” or “filter”. 

More precisely, let $X \in \mathbb{R}^{H \times W}$ be the input image, and $W \in \mathbb{R}^{h \times w}$ be the kernel. The output is denoted by $Z = X \circledast W$, where (ignoring boundary conditions) we have the following:¹

$$z_{ij} = \sum_{u=0}^{h-1} \sum_{v=0}^{w-1} x_{i+u,\, j+v}\, w_{u,v} \qquad (16.4)$$


Essentially we compare a local patch of $x$ of size $h \times w$, whose top-left corner is at $(i, j)$, to the filter $w$; the output just measures how similar the input patch is to the filter. We can define convolution in 1d or 3d in an analogous manner. Note that the spatial size of the outputs may be smaller than that of the inputs, due to boundary effects, although this can be solved by using padding. See [Mur22, Sec 14.2.1] for more details.

1. Note that, technically speaking, we are using cross-correlation rather than convolution. However, these terms are used interchangeably in deep learning.

Figure 16.3: A 2d convolutional layer with 3 input channels and 2 output channels. The kernel has size 3 × 3 and we use stride 1 with 0 padding, so the 6 × 6 input gets mapped to the 4 × 4 output.
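Equation (16.4) can be implemented with two explicit loops. This is a minimal, deliberately unoptimized NumPy sketch of "valid" cross-correlation; the 6 × 6 test image and the 3 × 3 filter are arbitrary examples, and zero padding with `np.pad` is shown as one way to keep the output size equal to the input size:

```python
import numpy as np

def corr2d(X, W):
    """Valid 2d cross-correlation, Equation (16.4): Z[i, j] = sum_{u,v} X[i+u, j+v] W[u, v]."""
    H, Wd = X.shape
    h, w = W.shape
    Z = np.zeros((H - h + 1, Wd - w + 1))
    for i in range(Z.shape[0]):
        for j in range(Z.shape[1]):
            # compare the h x w patch whose top-left corner is (i, j) with the filter
            Z[i, j] = np.sum(X[i:i + h, j:j + w] * W)
    return Z

X = np.arange(36.0).reshape(6, 6)        # X[i, j] = 6 i + j
K = np.array([[1.0, 0.0, -1.0]] * 3)     # a simple 3x3 "vertical edge" filter
Z = corr2d(X, K)                         # 6x6 input -> 4x4 output (stride 1, no padding)
Z_same = corr2d(np.pad(X, 1), K)         # zero padding of 1 keeps the 6x6 size
```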

We can repeat this process for multiple layers of inputs, and by using multiple filters, we can generate multiple layers of output. In general, if we have $C$ input channels, and we want to map them to $D$ output (feature) channels, then we define $D$ kernels, each of size $h \times w \times C$, where $h, w$ are the height and width of the kernel. The $d$'th output feature map is obtained by convolving all $C$ input feature maps with the $d$'th kernel, and then adding up the results elementwise:

$$z_{i,j,d} = \sum_{u=0}^{h-1} \sum_{v=0}^{w-1} \sum_{c=0}^{C-1} x_{i+u,\, j+v,\, c}\, w_{u,v,c,d} \qquad (16.5)$$

This is called a convolutional layer, and is illustrated in Figure 16.3.
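Extending the single-channel sketch to Equation (16.5) only requires summing over the channel index as well. Below is a minimal NumPy sketch using `np.einsum` for the inner sum; the array layout (height, width, channels) and the random inputs are illustrative assumptions matching the sizes of Figure 16.3:

```python
import numpy as np

def conv_layer(X, W):
    """X: (H, W_in, C) input; W: (h, w, C, D) kernels. Returns (H-h+1, W_in-w+1, D), Eq. (16.5)."""
    H, Win, C = X.shape
    h, w, C2, D = W.shape
    assert C == C2, "kernel depth must match the number of input channels"
    Z = np.zeros((H - h + 1, Win - w + 1, D))
    for i in range(Z.shape[0]):
        for j in range(Z.shape[1]):
            patch = X[i:i + h, j:j + w, :]                  # (h, w, C)
            Z[i, j, :] = np.einsum('uvc,uvcd->d', patch, W)  # sum over u, v and c
    return Z

rng = np.random.default_rng(0)
X = rng.normal(size=(6, 6, 3))     # 6x6 input with 3 channels, as in Figure 16.3
W = rng.normal(size=(3, 3, 3, 2))  # D = 2 kernels, each of size 3x3x3
Z = conv_layer(X, W)
```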


The advantage of a convolutional layer compared to using a linear layer is that the weights of the kernel are shared across locations in the input. Thus if a pattern in the input shifts locations, the corresponding output activation will also shift. This is called shift equivariance. In some cases, we want the output to be the same, no matter where the input pattern occurs; this is called shift invariance, and can be obtained by using a pooling layer, which computes the maximum or average value in each local patch of the input. (Pooling over local patches gives approximate invariance to small shifts; pooling over the entire input gives invariance to where the pattern occurs.) Note that pooling layers have no free (learnable) parameters. Other forms of invariance can also be captured by neural networks (see e.g., [CW16; FWW21]).
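Shift equivariance and pooling-based invariance are easy to see in 1d. In this minimal NumPy sketch (the signal, filter and shift amount are arbitrary examples, chosen so that the pattern stays away from the boundaries), shifting the input shifts the convolution output by the same amount, while a global max pool gives the same value in both cases:

```python
import numpy as np

def corr1d(x, w):
    # valid 1d cross-correlation: z[i] = sum_u x[i+u] w[u]
    n, k = len(x), len(w)
    return np.array([np.dot(x[i:i + k], w) for i in range(n - k + 1)])

w = np.array([1.0, -2.0, 1.0])
x = np.zeros(12)
x[3:5] = [1.0, 3.0]          # a small "pattern" at positions 3-4
x_shift = np.roll(x, 4)      # the same pattern moved 4 positions to the right

z = corr1d(x, w)
z_shift = corr1d(x_shift, w)

# Shift equivariance: the output pattern moves by the same amount as the input.
equivariant = np.allclose(np.roll(z, 4), z_shift)
# Global max pooling is invariant to where the pattern occurs.
invariant = np.isclose(z.max(), z_shift.max())
```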


16.2.4 Residual (skip) connections


If we stack a large number of nonlinear layers together, the signal may get squashed to zero or may blow up to infinity, depending on the magnitude of the weights and the nature of the nonlinearities. Similar problems can plague gradients that are passed backwards through the network (see Section 6.2). To reduce this effect we can add skip connections, also called residual connections, which allow the signal to skip one or more layers, so that it is not modified by them. For example, Figure 16.4 illustrates a network that computes

$$y = f(x; W) = \varphi(\mathrm{conv}(x; W)) + x \qquad (16.6)$$

Figure 16.4: A residual connection around a convolutional layer.


Now the convolutional layer only needs to learn an offset or residual to add to (or subtract from) the input to match the desired output, rather than predicting the output directly. Such residuals are often small in size, and hence are easier to learn using neurons with weights that are bounded (e.g., close to 1).
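One way to see why residual blocks are easy to train is that they default to the identity. In this minimal NumPy sketch, a dense layer with ReLU stands in for the convolution of Equation (16.6) (an illustrative simplification); with all-zero weights the block returns its input unchanged, whereas the same layer without the skip connection maps every input to zero:

```python
import numpy as np

def relu(a):
    return np.maximum(a, 0.0)

def residual_block(x, W, b):
    # y = phi(layer(x)) + x, cf. Equation (16.6); a dense layer stands in for conv
    return relu(W @ x + b) + x

rng = np.random.default_rng(0)
x = rng.normal(size=8)
W0, b0 = np.zeros((8, 8)), np.zeros(8)

y_res = residual_block(x, W0, b0)   # identity map: the residual branch outputs 0
y_plain = relu(W0 @ x + b0)         # same layer without the skip: the signal is lost
```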


16.2.5 Normalization layers 


To learn an input-output mapping, it is often best if the inputs are standardized, meaning that 
they have zero mean and unit standard deviation. This ensures that the required magnitude of the 
weights is small, and comparable across dimensions. To ensure that the internal activations have this 
property, it is common to add normalization layers. 

The most common approach is to use batch normalization (BN) [IS15]. However, this relies on having access to a batch of $B > 1$ input examples. Various alternatives have been proposed to remove the need for an input batch, such as layer normalization [BKH16], instance normalization [UVL16], group normalization [WH18], filter response normalization [SK20], etc. More details can be found in [Mur22, Sec 14.2.4].
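The difference between batch and layer normalization is just which axis the statistics are computed over. A minimal NumPy sketch (the learnable scale and shift parameters that usually follow normalization are omitted for brevity, and `eps` is an illustrative stabilizer) also shows why BN degenerates when $B = 1$:

```python
import numpy as np

def batch_norm(X, eps=1e-5):
    # normalize each feature (column) using statistics over the batch (rows)
    mu = X.mean(axis=0, keepdims=True)
    var = X.var(axis=0, keepdims=True)
    return (X - mu) / np.sqrt(var + eps)

def layer_norm(X, eps=1e-5):
    # normalize each example (row) using statistics over its own features
    mu = X.mean(axis=1, keepdims=True)
    var = X.var(axis=1, keepdims=True)
    return (X - mu) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
X = rng.normal(loc=3.0, scale=2.0, size=(16, 4))   # batch of B = 16 examples, 4 features
Xb, Xl = batch_norm(X), layer_norm(X)

x_single = X[:1]   # a "batch" containing a single example
# With B = 1, every feature equals its batch mean, so BN outputs all zeros;
# layer norm does not depend on the other examples and is unaffected.
```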


16.2.6 Dropout layers 


Neural networks often have millions of parameters, and thus can sometimes overfit, especially when trained on small datasets. There are many ways to ameliorate this effect, such as applying regularizers to the weights, or adopting a fully Bayesian approach (see Chapter 17). Another common heuristic is known as dropout [Sri+14a], in which edges are randomly omitted each time the network is used, as illustrated in Figure 16.5. More precisely, if $w_{lij}$ is the weight of the edge from node $i$ in layer $l-1$ to node $j$ in layer $l$, then we replace it with $\theta_{lij} = w_{lij}\, \epsilon_{li}$, where $\epsilon_{li} \sim \mathrm{Ber}(1-p)$, $p$ is the drop probability, and $1-p$ is the keep probability. Thus if we sample $\epsilon_{li} = 0$, then all of the weights going out of unit $i$ in layer $l-1$ into any $j$ in layer $l$ will be set to 0.

During training, the gradients will be zero for the weights connected to a neuron which has been switched "off". However, since we resample $\epsilon_{li}$ every time the network is used, different combinations of weights will be updated on each step. The result can be viewed as an implicit ensemble of networks, each with a slightly different sparse graph structure, that share their weights.


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 
Figure 16.5: Illustration of dropout. (a) A standard neural net with 2 hidden layers. (b) An example of a thinned net produced by applying dropout with p = 0.5. Units that have been dropped out are marked with an x. From Figure 1 of [Sri+14a]. Used with kind permission of Geoff Hinton.

Figure 16.6: Attention layer. (a) Mapping a single query q to a single output, given a set of keys and values. From Figure 10.3.1 of [Zha+20a]. Used with kind permission of Aston Zhang.


At test time, we usually turn the dropout noise off, so the model acts deterministically. To ensure the weights have the same expectation at test time as they did during training (so the input activation to the neurons is the same, on average), at test time we should use $\mathbb{E}[\theta_{lij}] = w_{lij}\, \mathbb{E}[\epsilon_{li}]$. For Bernoulli noise, we have $\mathbb{E}[\epsilon] = 1 - p$, so we should multiply the weights by the keep probability, $1 - p$, before making predictions. We can, however, use dropout at test time if we wish. This is called Monte Carlo dropout (see Section 17.3.1).
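The test-time rescaling can be checked by simulation. This minimal NumPy sketch (layer sizes, $p$ and the number of Monte Carlo samples are arbitrary choices) draws one Bernoulli variable per input unit, exactly as in the text, and compares the average training-time pre-activation with the deterministic rule of multiplying the weights by $1-p$:

```python
import numpy as np

rng = np.random.default_rng(0)
p = 0.5                                   # drop probability
n_in, n_out = 4, 3
w = rng.normal(size=(n_in, n_out))        # w[i, j]: weight from unit i in layer l-1 to unit j in layer l
x = rng.normal(size=n_in)                 # activations of layer l-1

def dropout_preactivation(x, w, p, rng, n_samples):
    # one Bernoulli(1 - p) variable eps_i per input unit and per sample; eps_i = 0
    # zeros every weight leaving unit i, i.e. theta_ij = w_ij * eps_i
    eps = rng.binomial(1, 1 - p, size=(n_samples, w.shape[0]))
    return (eps * x) @ w                  # row s: sum_i x_i * w_ij * eps_i

train_samples = dropout_preactivation(x, w, p, rng, n_samples=200_000)
mc_mean = train_samples.mean(axis=0)      # average training-time pre-activation
test_time = x @ ((1 - p) * w)             # test-time rule: scale the weights by 1 - p
```

Many implementations instead use "inverted" dropout, dividing by $1-p$ during training so that no rescaling is needed at test time; both conventions give the same expected pre-activation.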

16.2.7 Attention layers
In non-parametric kernel based prediction methods, such as Gaussian processes (Chapter 18), we compare the input $x \in \mathbb{R}^D$ to each of the training examples $X = (x_1, \ldots, x_n)$ using a kernel to get a vector of similarity scores, $\alpha = [K(x, x_i)]_{i=1}^n$. We then use this to retrieve a weighted combination of the corresponding $n$ target values $y_i$ as follows:

$$\hat{y} = \sum_{i=1}^n \alpha_i y_i \qquad (16.7)$$

See Section 18.3.7 for details.

Figure 16.7: (a) Scaled dot-product attention in matrix form. (b) Multi-head attention. From Figure 2 of [Vas+17b]. Used with kind permission of Ashish Vaswani.

We can make a differentiable and parametric version of this as follows (see [Tsa+19] for details). First we replace the stored examples matrix $X$ with a learned embedding, to create a set of stored keys, $K = W^K X \in \mathbb{R}^{n \times d_k}$. Similarly we replace the stored output matrix $Y$ with a learned embedding, to create a set of stored values, $V = W^V Y \in \mathbb{R}^{n \times d_v}$. Finally we embed the input to create a query, $q = W^Q x \in \mathbb{R}^{d_k}$. The parameters to be learned are the three embedding matrices.

To ensure the output is a differentiable function of the input, we replace the fixed kernel function with a soft attention layer. More precisely, we define

$$\mathrm{Attn}(q, (k_1, v_1), \ldots, (k_n, v_n)) = \mathrm{Attn}(q, (k_{1:n}, v_{1:n})) = \sum_{i=1}^n \alpha_i(q, k_{1:n})\, v_i \qquad (16.8)$$

where $\alpha_i(q, k_{1:n})$ is the $i$th attention weight; these weights satisfy $0 \le \alpha_i(q, k_{1:n}) \le 1$ for each $i$ and $\sum_i \alpha_i(q, k_{1:n}) = 1$.

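The attention weights can be computed from an attention score function $a(q, k_i) \in \mathbb{R}$ that computes the similarity of query $q$ to key $k_i$. For example, we can use (scaled) dot product attention, which has the form

$$a(q, k) = q^\top k / \sqrt{d_k} \qquad (16.9)$$

(The scaling by $\sqrt{d_k}$ is to reduce the dependence of the output on the dimensionality of the vectors: if the entries of $q$ and $k$ are independent with zero mean and unit variance, then $q^\top k$ has variance $d_k$, so the scaled score has unit variance.) Given the scores, we can compute the attention weights using the softmax function:

$$\alpha_i(q, k_{1:n}) = \mathrm{softmax}_i([a(q, k_1), \ldots, a(q, k_n)]) = \frac{\exp(a(q, k_i))}{\sum_{j=1}^n \exp(a(q, k_j))} \qquad (16.10)$$

See Figure 16.6 for an illustration.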
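Equations (16.8) to (16.10) for a single query fit in a few lines. This is a minimal NumPy sketch with arbitrary sizes ($n = 5$ keys, $d_k = 4$, $d_v = 3$); subtracting the maximum score before exponentiating is a standard numerical-stability step that does not change the softmax:

```python
import numpy as np

def softmax(s):
    s = s - np.max(s)              # subtract the max for numerical stability
    e = np.exp(s)
    return e / e.sum()

def attention(q, K, V):
    """q: (d_k,), K: (n, d_k), V: (n, d_v). Implements Equations (16.8)-(16.10)."""
    d_k = q.shape[0]
    scores = K @ q / np.sqrt(d_k)  # a(q, k_i) = q^T k_i / sqrt(d_k) for every i
    alpha = softmax(scores)        # attention weights in [0, 1], summing to 1
    return alpha @ V, alpha        # weighted combination of the values

rng = np.random.default_rng(0)
q = rng.normal(size=4)
K = rng.normal(size=(5, 4))
V = rng.normal(size=(5, 3))
z, alpha = attention(q, K, V)
```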

In some cases, we want to restrict attention to a subset of the dictionary, corresponding to valid entries. For example, we might want to pad sequences to a fixed length (for efficient minibatching), in which case we should "mask out" the padded locations. This is called masked attention. We can implement this efficiently by setting the attention score for the masked entries to a large negative number, such as $-10^6$, so that the corresponding softmax weights will be (numerically) 0.

In practice, we usually deal with minibatches of $n$ vectors at a time. Let the corresponding matrices of queries, keys and values be denoted by $Q \in \mathbb{R}^{n \times d_k}$, $K \in \mathbb{R}^{m \times d_k}$, $V \in \mathbb{R}^{m \times d_v}$. Let

$$z_j = \sum_{i=1}^m \alpha_i(q_j, K)\, v_i \qquad (16.11)$$

be the $j$'th output corresponding to the $j$'th query. We can compute all outputs $Z \in \mathbb{R}^{n \times d_v}$ in parallel using

$$Z = \mathrm{Attn}(Q, K, V) = \mathrm{softmax}\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V \qquad (16.12)$$

where the softmax function is applied row-wise. See Figure 16.7(left) for an illustration.
To increase the flexibility of the model, we often use a multi-head attention layer, as illustrated in Figure 16.7(right). Let the $i$th head be

$$h_i = \mathrm{Attn}(Q W_i^Q, K W_i^K, V W_i^V) \qquad (16.13)$$


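where $W_i^Q$, $W_i^K$ and $W_i^V$ are head-specific linear projection matrices (the query and key projections must map to a common dimension so that their inner products are defined). We define the output of the MHA layer to be

$$Z = \mathrm{MHA}(Q, K, V) = \mathrm{Concat}(h_1, \ldots, h_h)\, W^O \qquad (16.14)$$

where $h$ is the number of heads, and $W^O$ is an output projection matrix whose number of rows equals the total width of the concatenated heads. Having multiple heads can increase performance of the layer, in the event that some of the weight matrices are poorly initialized; after training, we can often remove all but one of the heads [MLN19].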
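The matrix form (16.12) and the multi-head layer (16.13) to (16.14) can be sketched in NumPy as follows. The per-head width `p`, the number of heads and all other sizes are hypothetical choices; the optional boolean `mask` implements masked attention by replacing disallowed scores with $-10^6$:

```python
import numpy as np

def softmax_rows(S):
    S = S - S.max(axis=-1, keepdims=True)
    E = np.exp(S)
    return E / E.sum(axis=-1, keepdims=True)

def attn(Q, K, V, mask=None):
    """Q: (n, d_k), K: (m, d_k), V: (m, d_v) -> (n, d_v); Equation (16.12)."""
    S = Q @ K.T / np.sqrt(Q.shape[-1])
    if mask is not None:                 # mask[j, i] = False: query j may not attend to key i
        S = np.where(mask, S, -1e6)
    return softmax_rows(S) @ V           # softmax applied row-wise

def mha(Q, K, V, WQ, WK, WV, WO):
    """WQ, WK, WV: lists of per-head projections; WO: output projection. Eqs. (16.13)-(16.14)."""
    heads = [attn(Q @ wq, K @ wk, V @ wv) for wq, wk, wv in zip(WQ, WK, WV)]
    return np.concatenate(heads, axis=-1) @ WO

rng = np.random.default_rng(0)
n, m, d, p, h = 3, 5, 8, 4, 2            # hypothetical sizes: p = per-head width, h = heads
Q = rng.normal(size=(n, d))
K = rng.normal(size=(m, d))
V = rng.normal(size=(m, d))
WQ = [rng.normal(size=(d, p)) for _ in range(h)]
WK = [rng.normal(size=(d, p)) for _ in range(h)]
WV = [rng.normal(size=(d, p)) for _ in range(h)]
WO = rng.normal(size=(h * p, d))
Z = mha(Q, K, V, WQ, WK, WV, WO)
```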


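When the queries, keys and values are all computed from the same set of input vectors, the method is called self-attention. Stacking self-attention layers, so that the output of one layer is the input to the next, is the basis of the transformer model, which we discuss in Section 16.3.5.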


16.2.8 Recurrent layers


We can make the model stateful by augmenting the input $x$ with the current state $s_t$, and then computing the output and the new state using some kind of function:

$$(y, s_{t+1}) = f(x, s_t) \qquad (16.15)$$

This is called a recurrent layer, as shown in Figure 16.8. This forms the basis of recurrent neural networks, discussed in Section 16.3.4. In a vanilla RNN, the function $f$ is a simple MLP, but it may also use attention (Section 16.2.7).

Figure 16.8: Recurrent layer.
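Here is a minimal NumPy sketch of a vanilla recurrent layer in the sense of Equation (16.15), applied over a sequence by threading the state through time. The particular parameterization (a tanh state update followed by a linear read-out of the new state) and all sizes are illustrative assumptions:

```python
import numpy as np

def recurrent_layer(x, s, params):
    """One step (y, s_next) = f(x, s), cf. Equation (16.15), with a one-layer MLP for f."""
    Wxs, Wss, bs, Wsy, by = params
    s_next = np.tanh(Wxs @ x + Wss @ s + bs)   # new state depends on input and old state
    y = Wsy @ s_next + by                      # output read out from the new state
    return y, s_next

def run(xs, s0, params):
    # the same layer (same parameters) is applied at every time step
    s, ys = s0, []
    for x in xs:
        y, s = recurrent_layer(x, s, params)
        ys.append(y)
    return np.stack(ys), s

rng = np.random.default_rng(0)
D_x, D_s, D_y, T = 3, 4, 2, 6
params = (rng.normal(size=(D_s, D_x)), rng.normal(size=(D_s, D_s)), np.zeros(D_s),
          rng.normal(size=(D_y, D_s)), np.zeros(D_y))
xs = rng.normal(size=(T, D_x))
ys, s_T = run(xs, np.zeros(D_s), params)
```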


16.2.9 Multiplicative layers 


In this section, we discuss multiplicative layers, which are useful for combining different information 
sources. Our presentation follows [Jay+20]. 

Suppose we have inputs $x \in \mathbb{R}^n$ and $z \in \mathbb{R}^m$. In a linear layer (and, by extension, convolutional layers), it is common to concatenate the inputs to get $f(x, z) = W[x; z] + b$, where $W \in \mathbb{R}^{k \times (m+n)}$ and $b \in \mathbb{R}^k$. We can increase the expressive power of the model by using multiplicative interactions, such as the following bilinear form:

$$f(x, z) = z^\top W x + U z + V x + b \qquad (16.16)$$

where $W \in \mathbb{R}^{m \times n \times k}$ is a weight tensor, defined such that

$$(z^\top W x)_k = \sum_{ij} z_i W_{ijk} x_j \qquad (16.17)$$

That is, the $k$'th entry of the output is the weighted inner product of $z$ and $x$, where the weight matrix is the $k$'th "slice" of $W$. The other parameters have size $U \in \mathbb{R}^{k \times m}$, $V \in \mathbb{R}^{k \times n}$, and $b \in \mathbb{R}^k$.

This formulation includes many interesting special cases. In particular, a hypernetwork [HDL17] can be viewed in this way. A hypernetwork is a neural network that generates parameters for another neural network. In particular, we replace $f(x; \theta)$ with $f(x; g(z; \phi))$. If $f$ and $g$ are affine, this is equivalent to a multiplicative layer. To see this, let $W' = z^\top W + V$ and $b' = U z + b$. If we define $g(z; \phi) = [W', b']$, and $f(x; \theta) = W' x + b'$, we recover Equation (16.16).
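The equivalence between the bilinear layer (16.16) and an affine hypernetwork can be verified numerically. In this minimal NumPy sketch (sizes and random parameters are arbitrary), `np.einsum` implements the tensor contraction of Equation (16.17), and the hypothetical helper `hypernetwork` produces $W' = z^\top W + V$ and $b' = Uz + b$:

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, k = 3, 2, 4                       # x in R^n, z in R^m, output in R^k
W = rng.normal(size=(m, n, k))          # weight tensor W_{ijk}
U = rng.normal(size=(k, m))
V = rng.normal(size=(k, n))
b = rng.normal(size=k)
x, z = rng.normal(size=n), rng.normal(size=m)

def bilinear(x, z):
    # Equations (16.16)-(16.17): (z^T W x)_k = sum_{ij} z_i W_{ijk} x_j
    return np.einsum('i,ijk,j->k', z, W, x) + U @ z + V @ x + b

def hypernetwork(z):
    # g(z): affine in z, produces the weights and bias of an affine layer acting on x
    W_prime = np.einsum('i,ijk->kj', z, W) + V   # shape (k, n)
    b_prime = U @ z + b                          # shape (k,)
    return W_prime, b_prime

W_prime, b_prime = hypernetwork(z)
y_hyper = W_prime @ x + b_prime                  # f(x; g(z)) = W'x + b'
```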

We can also view the gating layers used in RNNs (Section 16.3.4) as a form of multiplicative interaction. In particular, if the hypernetwork computes the diagonal matrix $W' = \sigma(z^\top W + V) = \mathrm{diag}(a_1, \ldots, a_n)$, then we can define $f(x, z; \theta) = a(z) \odot x$, which is the standard gating mechanism. Attention mechanisms (Section 16.2.7) are also a form of multiplicative interaction, although they involve three-way interactions, between query, key and value.

Another variant arises if the hypernetwork just computes a scalar weight for each channel of a convolutional layer, plus a bias term:

$$f(x, z) = a(z) \odot x + b(z) \qquad (16.18)$$

This is called FiLM, which stands for "Feature-wise Linear Modulation" [Per+18]. For a detailed tutorial on the FiLM layer and its many applications, see https://distill.pub/2018/feature-wise-transformations.

Figure 16.9: Explicit vs implicit layers.
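A FiLM layer applies one scale and one shift per channel, computed from the conditioning input $z$ and broadcast over all spatial positions. In this minimal NumPy sketch the "hypernetworks" $a(z)$ and $b(z)$ are hypothetical linear maps (initialized so that $z = 0$ gives the identity); in practice they are usually small learned networks:

```python
import numpy as np

rng = np.random.default_rng(0)
H, Wd, C, m = 5, 5, 3, 2                   # feature map of size H x W with C channels; z in R^m
X = rng.normal(size=(H, Wd, C))
z = rng.normal(size=m)
A, a0 = rng.normal(size=(C, m)), np.ones(C)    # hypothetical linear hypernetwork for a(z)
B, b0 = rng.normal(size=(C, m)), np.zeros(C)   # hypothetical linear hypernetwork for b(z)

def film(X, z):
    # Equation (16.18): per-channel scale and shift, broadcast over spatial positions
    gamma = A @ z + a0      # a(z), shape (C,)
    beta = B @ z + b0       # b(z), shape (C,)
    return gamma * X + beta
```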


16.2.10 Implicit layers 


So far we have focused on explicit layers, which specify how to transform the input to the output 
using y = f(x). We can also define implicit layers, which specify the output indirectly, in terms of 
a constraint function: 


$$y \in \arg\min_{y} f(x, y) \;\; \text{such that} \;\; g(x, y) = 0 \qquad (16.19)$$


The details on how to find a solution to this constrained optimization problem can vary depending 
on the problem. For example, we may need to run an inner optimization routine, or call a differential 
equation solver. The main advantage of this approach is that the inner computations do not need to 
be stored explicitly, which saves a lot of memory. Furthermore, once the solution has been found, we 
can propagate gradients through the whole layer, by leveraging the implicit function theorem. This 
lets us use higher level primitives inside an end-to-end framework. For more details, see [GHC21] 
and http://implicit-layers-tutorial.org/.
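As a concrete special case of (16.19) with only a constraint, consider an equilibrium layer whose output is defined by $g(x, y) = y - \tanh(Wy + x) = 0$. The minimal NumPy sketch below (the small random $W$, chosen so that fixed-point iteration converges, and the iteration count are illustrative assumptions) finds $y$ without storing intermediate iterates, then computes $dy/dx$ from the implicit function theorem and checks it against finite differences:

```python
import numpy as np

rng = np.random.default_rng(0)
D = 4
W = 0.3 * rng.normal(size=(D, D)) / np.sqrt(D)   # small weights: the map below is a contraction

def solve(x, iters=200):
    # find y with g(x, y) = y - tanh(W y + x) = 0 by fixed-point iteration;
    # only the current iterate is kept, nothing is stored for backprop
    y = np.zeros(D)
    for _ in range(iters):
        y = np.tanh(W @ y + x)
    return y

def jacobian(x, y):
    # implicit function theorem: dy/dx = -(dg/dy)^{-1} dg/dx, evaluated at the solution
    s = 1.0 - np.tanh(W @ y + x) ** 2            # tanh'(W y + x)
    dg_dy = np.eye(D) - s[:, None] * W           # I - diag(s) W
    dg_dx = -np.diag(s)
    return -np.linalg.solve(dg_dy, dg_dx)

x = rng.normal(size=D)
y = solve(x)
J = jacobian(x, y)

# central finite differences through the solver, one input coordinate at a time
h = 1e-6
J_fd = np.stack([(solve(x + h * e) - solve(x - h * e)) / (2 * h) for e in np.eye(D)], axis=1)
```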


16.3 Canonical examples of neural networks 


In this section, we give several “canonical” examples of neural network architectures that are widely 
used for different tasks. 


16.3.1 Multi-layer perceptrons (MLP)


A multi-layer perceptron (MLP), also called a feedforward neural network (FFNN), is one of the simplest kinds of neural networks. It consists of a series of $L$ linear layers, combined with elementwise nonlinearities (the bias terms are omitted, as discussed above):

$$f(x; \theta) = \varphi_L(W_L\, \varphi_{L-1}(W_{L-1} \cdots \varphi_1(W_1 x) \cdots)) \qquad (16.20)$$


Draft of “Probabilistic Machine Learning: Advanced Topics”. August 15, 2022 




Figure 16.10: A feedforward neural network with $D$ inputs, $K_1$ hidden units in layer 1, $K_2$ hidden units in layer 2, and $C$ outputs. $w_{jk}^{(l)}$ is the weight of the connection from node $j$ in layer $l-1$ to node $k$ in layer $l$.


For example, Figure 16.10 shows an MLP with 1 input layer of $D$ units, 2 hidden layers of $K_1$ and $K_2$ units, and 1 output layer with $C$ units. The $k$'th hidden unit in layer $l$ is given by

$$h_k^{(l)} = \varphi_l\left(\sum_{j=1}^{K_{l-1}} w_{jk}^{(l)} h_j^{(l-1)}\right) \qquad (16.21)$$

where $\varphi_l$ is the nonlinear activation function at layer $l$.
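The following minimal NumPy sketch implements the forward pass (16.20) for an MLP shaped like Figure 16.10, and checks that the unit-by-unit formula (16.21) agrees with the matrix form for the first hidden layer. The layer sizes, the ReLU hidden activations and the linear (logit) output layer are illustrative choices; note that $w_{jk}^{(l)}$ is stored as entry `[k, j]` of the weight matrix:

```python
import numpy as np

def relu(a):
    return np.maximum(a, 0.0)

def identity(a):
    return a                             # linear output layer: outputs are logits

def mlp(x, weights, activations):
    # Equation (16.20) without biases: h^(l) = phi_l(W_l h^(l-1)), with h^(0) = x
    h = x
    for W, phi in zip(weights, activations):
        h = phi(W @ h)
    return h

rng = np.random.default_rng(0)
D, K1, K2, C = 4, 5, 3, 2                # layer sizes as in Figure 16.10 (values hypothetical)
Ws = [rng.normal(size=(K1, D)), rng.normal(size=(K2, K1)), rng.normal(size=(C, K2))]
x = rng.normal(size=D)
logits = mlp(x, Ws, [relu, relu, identity])

# Equation (16.21) unit by unit for layer 1: h_k = phi(sum_j w_jk x_j), w_jk stored as Ws[0][k, j]
h1_units = np.array([relu(sum(Ws[0][k, j] * x[j] for j in range(D))) for k in range(K1)])
```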

For a classification problem, the final nonlinearity is usually the softmax function. However, it is also common for the final layer to have linear activations, in which case the outputs are interpreted as logits; the loss function used during training then converts them to (log) probabilities internally.

We can also use MLPs for regression. Figure 16.11 shows how we can make a model for heteroskedastic nonlinear regression. (The term "heteroskedastic" just means that the predicted output variance is input-dependent, rather than a constant.) This function has two outputs which compute $f_\mu(x) = \mathbb{E}[y|x, \theta]$ and $f_\sigma(x) = \sqrt{\mathbb{V}[y|x, \theta]}$. We can share most of the layers (and hence parameters) between these two functions by using a common "backbone" and two output "heads", as shown in Figure 16.11. For the $\mu$ head, we use a linear activation, $\varphi(a) = a$. For the $\sigma$ head, we use a softplus activation, $\varphi(a) = \sigma_+(a) = \log(1 + e^a)$, which ensures that the predicted standard deviation is positive. If we use linear heads and a nonlinear backbone, the overall model is given by

$$p(y|x, \theta) = \mathcal{N}\left(y \mid w_\mu^\top f(x; w_{\text{shared}}),\ \sigma_+(w_\sigma^\top f(x; w_{\text{shared}}))^2\right) \qquad (16.22)$$
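Training such a model typically minimizes the negative log of (16.22). This minimal NumPy sketch takes the backbone output $f(x; w_{\text{shared}})$ as a given feature matrix (random placeholder values here, standing in for a real backbone) and evaluates the per-example Gaussian negative log likelihood with a linear mean head and a softplus standard-deviation head:

```python
import numpy as np

def softplus(a):
    return np.logaddexp(0.0, a)                # log(1 + e^a), numerically stable

def gaussian_nll(y, features, w_mu, w_sigma):
    """Per-example negative log of Equation (16.22), given backbone features f(x; w_shared)."""
    mu = features @ w_mu                       # linear mean head
    sigma = softplus(features @ w_sigma)       # softplus head: positive standard deviation
    return 0.5 * np.log(2 * np.pi * sigma ** 2) + 0.5 * ((y - mu) / sigma) ** 2

rng = np.random.default_rng(0)
feats = rng.normal(size=(10, 6))               # placeholder backbone outputs for 10 inputs
w_mu, w_sigma = rng.normal(size=6), rng.normal(size=6)
y = rng.normal(size=10)
nll = gaussian_nll(y, feats, w_mu, w_sigma)
```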


16.3.2 Convolutional neural networks (CNN) 


A vanilla convolutional neural network or CNN consists of a series of convolutional layers, 
pooling layers, linear layers, and nonlinearities. See Figure 16.12 for an example. More sophisticated 
architectures, such as the ResNet model [He+16a; He+16b], add skip (residual) connections, 
normalization layers, etc. The ConvNeXt model of [Liu+22] is considered the current (as of 
February 2022) state of the art CNN architecture for a wide variety of vision tasks. See e.g., [Mur22, 
Ch.14] for more details on CNNs. 
Figure 16.11: Illustration of an MLP with a shared "backbone" and two output "heads", one for predicting the mean and one for predicting the variance. From https://brendanhasz.github.io/2019/07/23/bayesian-density-net.html. Used with kind permission of Brendan Hasz.
Figure 16.12: One of the first CNNs ever created, for classifying MNIST images. From Figure 3 of [LeC+89]. For a "modern" implementation, see lecun1989.ipynb.


16.3.3 Autoencoders


An autoencoder is a neural network that maps inputs $x$ to a low-dimensional latent space using an encoder, $z = f_e(x)$, and then attempts to reconstruct the inputs using a decoder, $\hat{x} = f_d(z)$. The model is trained to minimize

$$\mathcal{L}(\theta) = \|r(x) - x\|_2^2 \qquad (16.23)$$

where $r(x) = f_d(f_e(x))$. (We can also replace squared error with more general negative conditional log likelihoods.) See Figure 16.13 for an illustration of an AE with 3 hidden layers.
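The loss (16.23) is easy to evaluate for any encoder/decoder pair. This minimal NumPy sketch uses the well-known special case of a linear autoencoder, where an optimal encoder/decoder projects onto the top-$k$ principal subspace (computed here with an SVD of synthetic centered data); the reconstruction error then equals the sum of the discarded squared singular values:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5)) @ rng.normal(size=(5, 5))   # 200 synthetic correlated 5-d inputs
X = X - X.mean(axis=0)                                      # center the data
k = 2                                                       # latent dimension

# For a linear AE, an optimal encoder/decoder pair spans the top-k principal subspace.
_, S, Vt = np.linalg.svd(X, full_matrices=False)
Vk = Vt[:k].T                              # (5, k), orthonormal columns

def encoder(x):                            # f_e: R^5 -> R^k
    return x @ Vk

def decoder(z):                            # f_d: R^k -> R^5
    return z @ Vk.T

R = decoder(encoder(X))                    # r(x) = f_d(f_e(x)), applied to every row
loss = np.sum((R - X) ** 2)                # Equation (16.23), summed over the dataset
```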
Figure 16.13: Illustration of an autoencoder with 3 hidden layers.


Figure 16.14: (a) Some MNIST digits. (b) Reconstruction of these images using a convolutional autoencoder. 
(c) t-SNE visualization of the 20-d embeddings. The colors correspond to class labels, which were not used 
during training. Generated by ae_mnist_convu_jax.ipynb. 


For image data, we can make the encoder be a convolutional network, and the decoder be a 
transpose convolutional network. We can use this to compute low dimensional embeddings of image 
data. For example, suppose we fit such a model to some MNIST digits. We show the reconstruction 
abilities of such a model in Figure 16.14b. In Figure 16.14c, we show a 2d visualization of the 
20-dimensional embedding space computed using t-SNE. The colors correspond to class labels, which 
were not used during training. We see fairly good separation, showing that images which are visually 
similar are placed close to each other in the embedding space, as desired. (See also Section 21.2.6, 
where we compare AEs with variational AEs.) 
Figure 16.15: Illustration of a recurrent neural network (RNN). (a) With self-loop. (b) Unrolled in time.


16.3.4 Recurrent neural networks (RNN) 


A recurrent neural network (RNN) is a network with a recurrent layer, as in Equation (16.15). This is illustrated in Figure 16.15. Formally this defines the following probability distribution over sequences:

$$p(y_{1:T}) = \sum_{h_{1:T}} p(y_{1:T}, h_{1:T}) = \sum_{h_{1:T}} \mathbb{I}(h_1 = h_1^*)\, p(y_1|h_1) \prod_{t=2}^T p(y_t|h_t)\, \mathbb{I}(h_t = f(h_{t-1}, y_{t-1})) \qquad (16.24)$$

where $h_1^*$ is the initial hidden state and $h_t$ is the deterministic hidden state, computed from the last hidden state and last output using $f(h_{t-1}, y_{t-1})$. Because the hidden states are deterministic, only one term in the sum over $h_{1:T}$ is nonzero. (At training time, $y_{t-1}$ is observed, but at prediction time, it is generated.)
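The distinction between training (observed $y_{t-1}$) and prediction (generated $y_{t-1}$) can be made concrete with a small categorical sequence model. In this minimal NumPy sketch, the vocabulary size, hidden size, tanh transition and softmax output layer are all hypothetical choices; `log_prob` evaluates (16.24) with observed outputs, and `sample` feeds each generated token back in:

```python
import numpy as np

rng = np.random.default_rng(0)
V_size, D_h = 5, 8                        # hypothetical vocabulary size and hidden size
Wh = 0.5 * rng.normal(size=(D_h, D_h))
Wy = rng.normal(size=(D_h, V_size))       # column y embeds the previous output token y
Wo = rng.normal(size=(V_size, D_h))
h1 = np.zeros(D_h)                        # fixed initial hidden state h_1^*

def step(h_prev, y_prev):
    # deterministic transition h_t = f(h_{t-1}, y_{t-1})
    return np.tanh(Wh @ h_prev + Wy[:, y_prev])

def output_probs(h):
    # p(y_t | h_t): softmax over the vocabulary
    logits = Wo @ h
    e = np.exp(logits - logits.max())
    return e / e.sum()

def log_prob(ys):
    # training: y_{t-1} is observed, so the sum over h_{1:T} has a single nonzero term
    h, lp = h1, np.log(output_probs(h1)[ys[0]])
    for t in range(1, len(ys)):
        h = step(h, ys[t - 1])
        lp += np.log(output_probs(h)[ys[t]])
    return lp

def sample(T, rng):
    # prediction: each y_{t-1} is generated by the model and fed back in
    h, ys = h1, [rng.choice(V_size, p=output_probs(h1))]
    for _ in range(1, T):
        h = step(h, ys[-1])
        ys.append(rng.choice(V_size, p=output_probs(h)))
    return ys
```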


In a vanilla RNN, the function $f$ is a simple MLP. However, we can also use multiplicative gating (Section 16.2.9) to selectively update parts of the state vector based on the input and the previous state, as in the GRU (gated recurrent unit) model, and the LSTM (long short term memory) model. We can also make the model into a conditional sequence model, by feeding in extra inputs to the $f$ function. See e.g., [Mur22, Ch. 15] for more details on RNNs.


16.3.5 Transformers


Consider the problem of classifying each word in a sentence, for example with its part of speech tag (noun, verb, etc). That is, we want to learn a mapping $f: \mathcal{X} \to \mathcal{Y}$, where $\mathcal{X} = \mathcal{V}^T$ is the set of input sequences defined over (word) vocabulary $\mathcal{V}$, $T$ is the length of the sentence, and $\mathcal{Y} = \mathcal{T}^T$ is the set of output sequences, defined over (tag) vocabulary $\mathcal{T}$. To do well at this task, we need to learn a contextual embedding of each word. RNNs process one token at a time, so the embedding of the word at location $t$, $z_t$, depends on the hidden state of the network, $s_t$, which may be a lossy summary of all the previously seen words. We can create bidirectional RNNs so that future words can also affect the embedding $z_t$, but this dependence is still mediated via the hidden state. An alternative approach is to compute $z_t$ as a direct function of all the other words in the sentence, by using the attention operator discussed in Section 16.2.7 rather than using hidden state. This is called an (encoder-only) transformer, and is used by models such as BERT [Dev+19]. This idea is sketched in Figure 16.16.
Figure 16.16: Visualizing the difference between an RNN and a transformer. From [Jos20]. Used with kind permission of Chaitanya Joshi.


It is also possible to create a decoder-only transformer, in which each output $y_t$ only attends to all the previously generated outputs, $y_{1:t-1}$. This can be implemented using masked attention, and is useful for generative language models, such as GPT (see Section 22.4.1). We can combine the encoder and decoder to create a conditional sequence-to-sequence model, $p(y_{1:T_y} | x_{1:T_x})$, as proposed in the original transformer paper [Vas+17c]. See Supplementary Section 16.1.1 and [PH22] for more details.
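The masked attention used in decoder-only transformers can be sketched with a lower-triangular (causal) mask, so that position $t$ only attends to positions up to and including $t$. In this minimal NumPy sketch the sizes and projection matrices are arbitrary; the check confirms that perturbing the last ("future") token leaves all earlier outputs unchanged:

```python
import numpy as np

def causal_self_attention(X, WQ, WK, WV):
    """Masked self-attention: row t of the output depends only on rows 0..t of X."""
    T = X.shape[0]
    Q, K, V = X @ WQ, X @ WK, X @ WV          # queries, keys and values from the same inputs
    S = Q @ K.T / np.sqrt(Q.shape[-1])
    causal = np.tril(np.ones((T, T), dtype=bool))   # causal[t, s] is True iff s <= t
    S = np.where(causal, S, -1e6)             # large negative score => softmax weight 0
    A = np.exp(S - S.max(axis=-1, keepdims=True))
    A = A / A.sum(axis=-1, keepdims=True)
    return A @ V

rng = np.random.default_rng(0)
T, d = 6, 4
WQ, WK, WV = (rng.normal(size=(d, d)) for _ in range(3))
X = rng.normal(size=(T, d))
Z = causal_self_attention(X, WQ, WK, WV)

# Changing the last ("future") token must not change earlier outputs.
X2 = X.copy()
X2[-1] += 10.0
Z2 = causal_self_attention(X2, WQ, WK, WV)
```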

It has been found that large transformers are very flexible sequence-to-sequence function approximators, if trained on enough data (see e.g., [Lin+21a] for a review in the context of NLP, and
[Kha+21; Han+20; Zan21] for reviews in the context of computer vision). The reasons why they 
work so well are still not very clear. However, some initial insights can be found in e.g., [Rag+21; 
WGY21; Nel21; BP21]. See also Supplementary Section 16.1.2.5 where we discuss the connection 
with graph neural networks. 


16.3.6 Graph neural networks (GNNs) 


It is possible to define neural networks for working with graph-structured data. These are called 
graph neural networks or GNNs. See Supplementary Section 16.1.2 for details. 
17 Bayesian neural networks


This chapter is coauthored with Andrew Wilson. 


17.1 Introduction 


Deep neural networks (DNNs) are usually trained using a (penalized) maximum likelihood objective to find a single setting of parameters. However, large flexible models like neural networks can represent many functions, corresponding to different parameter settings, which fit the training data well, yet generalize in different ways, a phenomenon known as underspecification (see e.g., [D'A+20], and Figure 17.10 for an illustration). Considering all of these different models together can lead to improved accuracy and uncertainty representation. This can be done by computing the posterior predictive distribution using Bayesian model averaging:

$$p(y|x, \mathcal{D}) = \int p(y|x, \theta)\, p(\theta|\mathcal{D})\, d\theta \qquad (17.1)$$

where $p(\theta|\mathcal{D}) \propto p(\theta)\, p(\mathcal{D}|\theta)$.
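In practice the integral in (17.1) is usually approximated by averaging $p(y|x, \theta_s)$ over posterior samples $\theta_s$. The following NumPy sketch uses a hypothetical one-parameter model, $y = \theta x + \text{noise}$ with a Gaussian posterior $\mathcal{N}(\theta | m, v)$ (all numbers are made up), because then the exact predictive density $\mathcal{N}(y | m x, \sigma^2 + v x^2)$ is available for comparison:

```python
import numpy as np

def normal_pdf(y, mean, var):
    return np.exp(-0.5 * (y - mean) ** 2 / var) / np.sqrt(2 * np.pi * var)

# Hypothetical model: y = theta * x + noise, noise ~ N(0, sigma2), posterior p(theta | D) = N(m, v).
m, v, sigma2 = 1.0, 0.25, 0.25
x, y = 2.0, 1.5

# Monte Carlo Bayesian model averaging, Equation (17.1): average the likelihood over posterior samples.
rng = np.random.default_rng(0)
thetas = rng.normal(m, np.sqrt(v), size=200_000)
p_mc = normal_pdf(y, thetas * x, sigma2).mean()

# For this conjugate model the integral has a closed form, which lets us check the estimate.
p_exact = normal_pdf(y, m * x, sigma2 + v * x ** 2)
# A "plug-in" prediction using only the posterior mean ignores parameter uncertainty.
p_plugin = normal_pdf(y, m * x, sigma2)
```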

The main challenges in applying Bayesian inference to DNNs are specifying suitable priors, and efficiently computing the posterior, which is challenging due to the large number of parameters and the large datasets. The application of Bayesian inference to DNNs is sometimes called Bayesian deep learning or BDL. By contrast, the term deep Bayesian learning or DBL refers to the use of deep models to help speed up Bayesian inference of "classical" models, usually by training amortized inference networks that can be used as part of a variational inference or importance sampling algorithm, as discussed in Section 10.3.6. For more details on the topic of BDL, see e.g., [PS17; Wil20; WI20; Jos+22; Kha20].


17.2 Priors for BNNs 


To perform Bayesian inference for the parameters of a DNN, we need to specify a prior $p(\theta)$. [Nal18; WI20; For21] discuss the issue of prior selection at length. Here we just give a brief summary of common approaches.

17.2.1 Gaussian priors


Consider an MLP with one hidden layer with activation function $\varphi$ and a linear output:

$$f(x; \theta) = W_2\, \varphi(W_1 x + b_1) + b_2 \qquad (17.2)$$

(If the output is nonlinear, such as a softmax transform, we can fold it into the loss function during training.) If we have two hidden layers this becomes

$$f(x; \theta) = W_3\, \varphi(W_2\, \varphi(W_1 x + b_1) + b_2) + b_3 \qquad (17.3)$$

In general, with $L-1$ hidden layers and a linear output, we have

$$f(x; \theta) = W_L(\cdots \varphi(W_1 x + b_1)) + b_L \qquad (17.4)$$


We need to specify the priors for $W_l$ and $b_l$ for $l = 1:L$. The most common choice is to use a factored Gaussian prior:

$$W_l \sim \mathcal{N}(0, \alpha_l^2 I), \quad b_l \sim \mathcal{N}(0, \beta_l^2 I) \qquad (17.5)$$

The Xavier initialization or Glorot initialization, named after the first author of [GB10], is to set

$$\alpha_l^2 = \frac{2}{n_{\text{in}} + n_{\text{out}}} \qquad (17.6)$$

where $n_{\text{in}}$ is the fan-in of a node in level $l$ (number of weights coming into a neuron), and $n_{\text{out}}$ is the fan-out (number of weights going out of a neuron). LeCun initialization, named after Yann LeCun, corresponds to using

$$\alpha_l^2 = \frac{1}{n_{\text{in}}} \qquad (17.7)$$
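The two variance rules (17.6) and (17.7) can be sketched in NumPy as follows (layer sizes are arbitrary). The last lines illustrate the motivation for LeCun scaling: with unit-variance inputs, the pre-activations $W h$ keep roughly unit variance regardless of the fan-in:

```python
import numpy as np

def glorot_var(n_in, n_out):
    return 2.0 / (n_in + n_out)      # Equation (17.6)

def lecun_var(n_in):
    return 1.0 / n_in                # Equation (17.7)

def sample_weights(n_in, n_out, var, rng):
    # W_l ~ N(0, alpha_l^2 I) with alpha_l^2 = var; shape (n_out, n_in) so that h_out = W h_in
    return rng.normal(0.0, np.sqrt(var), size=(n_out, n_in))

rng = np.random.default_rng(0)
n_in, n_out = 300, 100
W = sample_weights(n_in, n_out, glorot_var(n_in, n_out), rng)

# With LeCun scaling, pre-activations keep roughly the variance of unit-variance inputs.
W_lecun = sample_weights(n_in, n_out, lecun_var(n_in), rng)
h = rng.normal(size=(n_in, 5000))
pre_act_var = (W_lecun @ h).var()
```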


We can get a better understanding of these priors by considering the effect they have on the corresponding distribution over functions that they define. To help understand this correspondence, let us reparameterize the model as follows:

$$W_l = \alpha_l \eta_l,\ \eta_l \sim \mathcal{N}(0, I), \quad b_l = \beta_l \epsilon_l,\ \epsilon_l \sim \mathcal{N}(0, I) \qquad (17.8)$$

Hence every setting of the prior hyperparameters specifies the following random function:

$$f(x; \alpha, \beta) = \alpha_L \eta_L(\cdots \varphi(\alpha_1 \eta_1 x + \beta_1 \epsilon_1)) + \beta_L \epsilon_L \qquad (17.9)$$


To get a feeling for the effect of these hyperparameters, we can sample MLP parameters from this prior and plot the resulting random functions. We use a sigmoid nonlinearity, so $\varphi(a) = \sigma(a)$. We consider $L = 2$ layers, so $W_1$ are the input-to-hidden weights, and $W_2$ are the hidden-to-output weights. We assume the input and output are scalars, so we are generating random nonlinear 1d mappings $f: \mathbb{R} \to \mathbb{R}$.

Figure 17.1(a) shows some sampled functions where $\alpha_1 = 5$, $\beta_1 = 1$, $\alpha_2 = 1$, $\beta_2 = 1$. In Figure 17.1(b) we increase $\alpha_1$; this allows the first layer weights to get bigger, making the sigmoid-like shape of the functions steeper. In Figure 17.1(c), we increase $\beta_1$; this allows the first layer biases to get bigger, which allows the center of the sigmoid to shift left and right more, away from the origin. In Figure 17.1(d), we increase $\alpha_2$; this allows the second layer linear weights to get bigger, making the functions more "wiggly" (greater sensitivity to change in the input, and hence larger dynamic range).

Figure 17.1: The effects of changing the hyperparameters on an MLP with one hidden layer. (a) Random functions sampled from a Gaussian prior with hyperparameters $\alpha_1 = 5$, $\beta_1 = 1$, $\alpha_2 = 1$, $\beta_2 = 1$. (b) Increasing $\alpha_1$ by a factor of 5. (c) Increasing $\beta_1$ by a factor of 5. (d) Increasing $\alpha_2$ by a factor of 5. Generated by mlp_priors_demo.ipynb.
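The experiment can be sketched as follows. This is not the book's mlp_priors_demo.ipynb, but an independent NumPy sketch of Equation (17.9) for a one-hidden-layer sigmoid MLP with scalar input and output; the hidden width of 10 and the input grid are hypothetical choices. Because the same seed yields the same $\eta, \epsilon$ draws, the check confirms that increasing $\alpha_2$ by a factor of 5 scales the variation of every sampled function by 5:

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def sample_prior_functions(xs, alpha1, beta1, alpha2, beta2, n_hidden=10, n_samples=5, seed=0):
    """Draw random functions f(x) = alpha2 eta2^T sigma(alpha1 eta1 x + beta1 eps1) + beta2 eps2."""
    rng = np.random.default_rng(seed)
    fs = []
    for _ in range(n_samples):
        eta1, eps1 = rng.normal(size=n_hidden), rng.normal(size=n_hidden)
        eta2, eps2 = rng.normal(size=n_hidden), rng.normal()
        hidden = sigmoid(alpha1 * np.outer(xs, eta1) + beta1 * eps1)   # (len(xs), n_hidden)
        fs.append(alpha2 * hidden @ eta2 + beta2 * eps2)
    return np.stack(fs)                                                # (n_samples, len(xs))

xs = np.linspace(-3, 3, 101)
base = sample_prior_functions(xs, 5.0, 1.0, 1.0, 1.0)     # setting of Figure 17.1(a)
wiggly = sample_prior_functions(xs, 5.0, 1.0, 5.0, 1.0)   # alpha2 increased by 5, cf. Figure 17.1(d)
```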

The above results are specific to the case of sigmoidal activation functions. ReLU units can behave differently. For example, [WI20, App. E] show that for MLPs with ReLU units, if we set $\beta_l = 0$, so the bias terms are all zero, the effect of changing $\alpha_l$ is just to rescale the output. To see this, note that Equation (17.9) simplifies to

$$f(x; \alpha, \beta = 0) = \alpha_L \eta_L \varphi(\cdots \varphi(\alpha_1 \eta_1 x)) = \alpha_L \cdots \alpha_1\, \eta_L \varphi(\cdots \varphi(\eta_1 x)) \qquad (17.10)$$

$$= \alpha_L \cdots \alpha_1\, f(x; (\alpha = 1, \beta = 0)) \qquad (17.11)$$

where we used the fact that ReLU is positively homogeneous: $\varphi(a z) = a\, \varphi(z)$ for any $a > 0$ and any pre-activation $z$ (and the scales $\alpha_l$, being standard deviations, are positive). In general, it is the ratio of $\alpha$ and $\beta$ that matters for determining what happens to input signals as they propagate forwards and backwards through a randomly initialized model; for details, see e.g., [Bah+20].
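Equations (17.10) and (17.11) can be checked numerically. In this minimal NumPy sketch (layer sizes and the values of $\alpha_l$ are arbitrary), a bias-free ReLU MLP with per-layer scales $\alpha_l$ produces exactly $\prod_l \alpha_l$ times the output of the same network with all $\alpha_l = 1$:

```python
import numpy as np

def relu(a):
    return np.maximum(a, 0.0)

def random_relu_mlp(x, alphas, etas):
    # Equation (17.9) with beta = 0: f(x) = alpha_L eta_L relu(... relu(alpha_1 eta_1 x))
    h = x
    for l, (alpha, eta) in enumerate(zip(alphas, etas)):
        h = alpha * eta @ h
        if l < len(etas) - 1:      # no nonlinearity after the linear output layer
            h = relu(h)
    return h

rng = np.random.default_rng(0)
etas = [rng.normal(size=(8, 3)), rng.normal(size=(8, 8)), rng.normal(size=(1, 8))]  # eta_l ~ N(0, I)
x = rng.normal(size=3)
alphas = [2.0, 0.5, 3.0]
lhs = random_relu_mlp(x, alphas, etas)
rhs = np.prod(alphas) * random_relu_mlp(x, [1.0, 1.0, 1.0], etas)   # Equation (17.11)
```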

We see that initializing the model's parameters at a particular random value is like sampling a point from this prior over functions. In the limit of infinitely wide neural networks, we can derive this prior distribution analytically: this is known as a neural network Gaussian process, and is explained in Section 18.7.


17.2.2 Sparsity-promoting priors 


Although Gaussian priors are simple and widely used, they are not the only option. For some applications, it is useful to use sparsity-promoting priors, such as the Laplace prior, which encourage most of the weights (or channels in a CNN) to be zero (cf. Section 15.2.5). For details, see [Hoe+21].


17.2.3 Learning the prior 


We have seen how different priors for the parameters correspond to different priors over functions. 
We could in principle set the hyperparameters (e.g., the $\alpha$ and $\beta$ parameters of the Gaussian prior) using grid search to optimize cross-validation loss. However, cross-validation can be slow, particularly if we allow different priors for each layer of the network, since the size of the grid grows exponentially with the number of hyperparameters we wish to determine.

An alternative is to use gradient based methods to optimize the marginal likelihood

$$\log p(\mathcal{D}|\alpha, \beta) = \log \int p(\mathcal{D}|\theta)\, p(\theta|\alpha, \beta)\, d\theta \qquad (17.12)$$


This approach is known as empirical Bayes (Section 3.8) or evidence maximization, since $\log p(\mathcal{D}|\alpha, \beta)$ is also called the evidence [Mac92a; WS93; Mac99b]. This can give rise to sparse models, as we discussed in the context of automatic relevancy determination (Section 15.2.7). Unfortunately, computing the marginal likelihood is computationally difficult for large neural networks.


Learning the prior is more meaningful if we can do it on a separate, but related dataset. In [SZ+22] they propose to train a model on an initial, large dataset $\mathcal{D}_1$ (possibly unsupervised) to get a point estimate, $\hat{\theta}_1$, from which they can derive an approximate low-rank Gaussian posterior, using the SWAG method (Section 17.3.8). They then use this informative prior when fine-tuning the model on a downstream dataset $\mathcal{D}_2$. The fine-tuning can either be a MAP estimate $\hat{\theta}_2$ or some approximate posterior, $p(\theta_2|\mathcal{D}_2, \mathcal{D}_1)$, e.g., computed using MCMC (Section 17.3.7). They call this technique "Bayesian transfer learning". (See Section 19.5.1 for more details on transfer learning.)


17.2.4 Priors in function space 


Typically, the relationship between the prior distribution over parameters and the functions preferred by the prior is not transparent. In some cases, it can be possible to pick more informative priors based on principles such as desired invariances that we want the function to satisfy (see e.g., [Nal18]). [FBW21] introduces residual pathway priors, providing a mechanism for encoding high level concepts into prior distributions, such as locality, independencies, and symmetries, without constraining model flexibility. A different approach to encoding interpretable priors over functions leverages kernel methods such as Gaussian processes (e.g., [Sun+19a]), as we discuss in Section 18.1.


Draft of “Probabilistic Machine Learning: Advanced Topics”. August 15, 2022 






17.2.5 Architectural priors 


Beyond specifying the parametric prior, it is important to note that the architecture of the model
can have an even larger effect on the induced distribution over functions, as argued in Wilson and
Izmailov [WI20] and Izmailov et al. [Izm+21b]. For example, a CNN architecture encodes prior
knowledge about translation equivariance, due to its use of convolution, and hierarchical structure,
due to its use of multiple layers. Other forms of inductive bias are induced by different architectures,
such as RNNs. (Models such as transformers have weaker inductive bias, but consequently often
need more data to perform well.) Thus we can think of the field of neural architecture search
(reviewed in [EMH19]) as a form of structural prior learning.

In fact, with a suitable architecture, we can often get good results using random (untrained) models. 
For example, Ulyanov, Vedaldi, and Lempitsky [UVL18] showed that an untrained CNN with random 
parameters (sampled from a Gaussian) often works very well for low-level image processing tasks, 
such as image denoising, super-resolution and image inpainting. The resulting prior over functions 
has been called the deep image prior. Similarly, Pinto and Cox [PC12] showed that untrained 
CNNs with the right structure can do well at face recognition. Moreover, Zhang et al. [Zha+17]
showed that randomly initialized CNNs can process data to provide features that greatly improve the
performance of other models, such as kernel methods.


17.3 Posteriors for BNNs 


There is a large number of different approximate inference schemes that have been applied to
Bayesian neural networks, with different strengths and limitations. In the sections below, we briefly
describe some of these.


17.3.1 Monte Carlo dropout 


Monte Carlo dropout (MCD) [GG16; KG17] is a very simple and widely used method for
approximating the Bayesian predictive distribution. Usually stochastic dropout layers are added as a
form of regularization, and are “turned off” at test time, as described in Section 16.2.6. However, the
idea in MCD is to also perform random sampling at test time. More precisely, we drop out each
hidden unit by sampling from a Bernoulli(p) distribution; we repeat this procedure S times, to create
S distinct models. We then create an equally weighted average of the predictive distributions for
each of these models:

$$p(y|x, \mathcal{D}) \approx \frac{1}{S} \sum_{s=1}^{S} p(y|x, \theta^s) \tag{17.13}$$

where $\theta^s$ is a version of the MAP parameter estimate where we randomly drop out some connections.
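Equation (17.13) can be sketched in a few lines of NumPy. The network (a one-hidden-layer ReLU MLP whose random weights stand in for a trained model), the function names, and the choice to treat p as the keep probability with inverted-dropout rescaling are illustrative assumptions, not details fixed by the text.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def mlp_probs(x, params, mask):
    """Class probabilities of a one-hidden-layer ReLU MLP whose hidden units are multiplied by `mask`."""
    W1, b1, W2, b2 = params
    h = np.maximum(0.0, x @ W1 + b1) * mask
    return softmax(h @ W2 + b2)

def mc_dropout_predict(x, params, keep_prob, num_samples, rng):
    """Equal-weight average over num_samples randomly thinned networks, as in Eq. (17.13)."""
    num_hidden = params[0].shape[1]
    samples = []
    for _ in range(num_samples):
        # Keep each hidden unit with probability keep_prob (inverted-dropout rescaling keeps E[mask] = 1).
        mask = rng.binomial(1, keep_prob, size=num_hidden) / keep_prob
        samples.append(mlp_probs(x, params, mask))
    samples = np.stack(samples)             # shape (S, N, C)
    return samples.mean(axis=0), samples

rng = np.random.default_rng(0)
D, H, C = 4, 32, 3
# Random weights stand in for a trained (MAP) network in this sketch.
params = (rng.normal(size=(D, H)), np.zeros(H), rng.normal(size=(H, C)), np.zeros(C))
x = rng.normal(size=(5, D))
p_mean, p_samples = mc_dropout_predict(x, params, keep_prob=0.9, num_samples=100, rng=rng)
```

In this sketch one mask is shared by every input in the batch, so each of the S passes corresponds to a single thinned network $\theta^s$; drawing a fresh mask per input is a common alternative.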

We give an example of this process in action in Figure 17.2. We see that it successfully captures
uncertainty due to “out of distribution” inputs. (See Section 19.3.2 for more discussion of OOD
detection.)

One drawback of MCD is that it is slow at test time. However this can be overcome by “distilling”
the model’s predictions into a deterministic “student” network, as we discuss in Section 17.3.10.2.

A more fundamental problem is that MCD does not give proper uncertainty estimates, as argued in
[Osb16; LF+21]. The problem is the following. Although MCD can be viewed as a form of variational
inference [GG16], this is only true under a degenerate posterior approximation, corresponding to a
mixture of two delta functions, one at 0 (for dropped out nodes) and one at the MLE. This posterior
will not converge to the true posterior (which is a delta function at the MLE) even as the training
set size goes to infinity, since we are always dropping out hidden nodes with a constant probability p
[Osb16]. Fortunately this pathology can be fixed if the noise rate is optimized [GHK17]. For more
details, see e.g., [HGMG18; NHLS19; LF+21].

Figure 17.2: Illustration of MC dropout applied to the LeNet architecture. The inputs are some rotated images
of the digit 1 from the MNIST dataset. (a) shows the softmax inputs (logits) and (b) shows the softmax
outputs (probabilities). We see that the inputs are classified as digit 7 for the last three images (as shown by
the probabilities), even though the model has high uncertainty (as shown by the logits). Adapted from Figure 4
of [GG16]. Generated by mnist_classification_mc_dropout.ipynb


17.3.2 Laplace approximation


In Section 7.4.3, we introduced the Laplace approximation, which computes a Gaussian approximation
to the posterior, $p(\theta|\mathcal{D})$, centered at the MAP estimate, $\theta^*$. The posterior precision matrix is equal
to the Hessian of the negative log joint computed at the mode. The benefits of this approach are that
it is simple, and it can be used to derive a Bayesian estimate from a pretrained model. A detailed
explanation of the method in the context of DNNs can be found in [Dax+21]; here we just give a
summary.




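Let $f(x_n, \theta) \in \mathbb{R}^C$ be the prediction function with $C$ outputs, and $\theta \in \mathbb{R}^P$ be the parameter
vector. Let $r(y; f) = \nabla_f \log p(y|f)$ be the residual¹, and $\Lambda(y; f) = -\nabla_f^2 \log p(y|f)$ be the per-input
noise term. In addition, let $J \in \mathbb{R}^{C \times P}$ be the Jacobian, $[J_\theta(x)]_{ci} = \frac{\partial f_c(x, \theta)}{\partial \theta_i}$, and $H \in \mathbb{R}^{C \times P \times P}$ be
the Hessian, $[H_\theta(x)]_{cij} = \frac{\partial^2 f_c(x, \theta)}{\partial \theta_i \partial \theta_j}$. Then the gradient and Hessian of the log likelihood are given by
the following [IKB21]:

$$\nabla_\theta \log p(y|f(x, \theta)) = J_\theta(x)^\top r(y; f) \tag{17.14}$$

$$\nabla_\theta^2 \log p(y|f(x, \theta)) = H_\theta(x)^\top r(y; f) - J_\theta(x)^\top \Lambda(y; f) J_\theta(x) \tag{17.15}$$

¹ In the Gaussian case with unit variance, $\log p(y|f) = -\frac{1}{2}\|y - f\|^2 + \text{const}$, so this term becomes $r(y; f) = y - f$, which can be interpreted as a residual error.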


Since the network Hessian $H$ is usually intractable to compute, it is usually dropped, leaving only the
Jacobian term. This is called the generalized Gauss-Newton or GGN approximation [Sch02;
Mar20]. The GGN approximation is guaranteed to be positive semi-definite whenever $\Lambda$ is (as it is for
likelihoods, such as the Gaussian and categorical, that are log-concave in $f$). By contrast, this is not
true for the original Hessian in Equation (17.15), since the objective is not convex. Furthermore,
the Jacobian term is cheaper to compute than the Hessian.

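Putting it all together, for a Gaussian prior, $p(\theta) = \mathcal{N}(\theta|m_0, S_0)$, the Laplace approximation
becomes $p(\theta|\mathcal{D}) \approx \mathcal{N}(\theta|\theta^*, \Sigma_{\text{GGN}})$, where

$$\Sigma_{\text{GGN}}^{-1} = \sum_{n=1}^{N} J_{\theta^*}(x_n)^\top \Lambda(y_n; f_n) J_{\theta^*}(x_n) + S_0^{-1} \tag{17.16}$$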


Unfortunately inverting this matrix takes $O(P^3)$ time, so for models with many parameters, further
approximations are usually used. The simplest is to use a diagonal approximation, which takes $O(P)$
time and space. A more sophisticated approach is presented in [RBB18a], which leverages the KFAC
(Kronecker-factored approximate curvature) approximation of [MG15]. This approximates the covariance of each
layer using a Kronecker product.
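For a model whose output is linear in $\theta$, the GGN precision in Equation (17.16) can be written down directly, which makes a compact sanity check. The sketch assumes binary logistic regression with a single logit, so $J_n = x_n^\top$ and $\Lambda_n = p_n(1 - p_n)$, and an isotropic prior $S_0 = I/\lambda$; the function names are illustrative. For an actual BNN, $J$ would come from automatic differentiation (e.g., `jax.jacfwd`), and the GGN would no longer equal the exact Hessian.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def ggn_precision(X, theta, prior_prec):
    """Eq. (17.16) for binary logistic regression: J_n = x_n^T, Lambda_n = p_n (1 - p_n), S0 = I / prior_prec."""
    p = sigmoid(X @ theta)
    lam = p * (1.0 - p)
    return (X * lam[:, None]).T @ X + prior_prec * np.eye(X.shape[1])

def fit_map(X, y, prior_prec, num_iters=50):
    """MAP estimate by Newton's method; here the GGN equals the exact Hessian because f is linear in theta."""
    theta = np.zeros(X.shape[1])
    for _ in range(num_iters):
        grad = X.T @ (y - sigmoid(X @ theta)) - prior_prec * theta   # sum_n J_n^T r_n - S0^{-1} theta
        theta = theta + np.linalg.solve(ggn_precision(X, theta, prior_prec), grad)
    return theta

rng = np.random.default_rng(0)
X = np.hstack([rng.normal(size=(200, 2)), np.ones((200, 1))])   # two features plus a bias column
y = rng.binomial(1, sigmoid(X @ np.array([2.0, -1.0, 0.5]))).astype(float)

theta_map = fit_map(X, y, prior_prec=1.0)
prec = ggn_precision(X, theta_map, prior_prec=1.0)
cov_full = np.linalg.inv(prec)        # full Laplace covariance: O(P^3)
var_diag = 1.0 / np.diag(prec)        # diagonal approximation to the precision: O(P)
samples = rng.multivariate_normal(theta_map, cov_full, size=1000)
```

Since $(A^{-1})_{ii} \geq 1/A_{ii}$ for any positive definite $A$, inverting only the diagonal of the precision never overestimates the marginal variances, so this particular diagonal approximation errs on the overconfident side.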

A limitation of the Laplace approximation is that the posterior covariance is derived from the 
Hessian evaluated at the MAP parameters. This means Laplace forms a highly local approximation: 
even if the non-Gaussian posterior could be well-described by a Gaussian distribution, the Gaussian 
distribution formed using Laplace only captures the local characteristics of the posterior at the 
MAP parameters — and may therefore suffer badly from local optima, providing overly compact 
or diffuse representations. In addition, the curvature information is only used after the model has 
been estimated, and not during the model optimization process. By contrast, variational inference 
(Section 17.3.3) can provide more accurate approximations for comparable cost. 


17.3.3 Variational inference 


In fixed-form variational inference (Section 10.3), we choose a distribution for the posterior approximation
$q_\psi(\theta)$ and minimize $D_{\mathbb{KL}}(q \,\|\, p)$ with respect to $\psi$. We often choose a Gaussian approximate
posterior, $q_\psi(\theta) = \mathcal{N}(\theta|\mu, \Sigma)$, which lets us use the reparameterization trick to create a low variance
estimator of the gradient of the ELBO (see Section 10.3.3). Despite the use of a Gaussian, the
parameters that minimize the KL objective are often different from what we would find with the Laplace
approximation (Section 17.3.2).
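The reparameterization trick can be illustrated on a one-parameter problem whose exact posterior is known. The sketch assumes observations $y_i \sim \mathcal{N}(\theta, 1)$, a prior $\theta \sim \mathcal{N}(0, 1)$, and $q(\theta) = \mathcal{N}(\mu, e^{2\rho})$. Gradients are derived by hand rather than by autodiff, and this is the basic estimator, not any of the specific methods cited below. Because the exact posterior here is Gaussian, the optimal $q$ coincides with it (and with the Laplace approximation).

```python
import numpy as np

rng = np.random.default_rng(0)
n = 20
y = rng.normal(loc=1.5, scale=1.0, size=n)   # data: y_i ~ N(theta, 1)

def grad_log_joint(theta):
    """d/dtheta [log p(y | theta) + log p(theta)] with prior N(0, 1); theta may be an array."""
    return (y.sum() - n * theta) - theta

mu, rho = 0.0, 0.0          # q(theta) = N(mu, exp(rho)^2)
lr, S = 0.01, 64            # step size and Monte Carlo samples per step
for _ in range(2000):
    eps = rng.normal(size=S)
    sigma = np.exp(rho)
    theta = mu + sigma * eps                    # reparameterization: theta is a differentiable function of (mu, rho)
    g = grad_log_joint(theta)
    grad_mu = g.mean()                          # Monte Carlo estimate of d ELBO / d mu
    grad_rho = (g * eps * sigma).mean() + 1.0   # chain rule, plus d entropy / d rho = 1
    mu += lr * grad_mu
    rho += lr * grad_rho

post_var = 1.0 / (n + 1)                        # exact conjugate posterior
post_mean = y.sum() * post_var
print(mu, post_mean, np.exp(rho), np.sqrt(post_var))
```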

Variational methods for neural networks date back to at least Hinton and Camp [HC93]. In deep 
learning, [Gra11] revisited variational methods, using a Gaussian approximation with a diagonal
covariance matrix. This approximates the distribution of every parameter in the model by a univariate 
Gaussian, where the mean is the point estimate, and the variance captures the uncertainty, as shown in 
Figure 17.3. This approach was improved further in [Blu+15], who used the reparameterization trick 
to compute lower variance estimates of the ELBO; they called their method Bayes by backprop 
(BBB). 
Figure 17.3: Illustration of an MLP with (Left) point estimate for each weight, (Right) a marginal distribution
for each weight, corresponding to a fully factored posterior approximation. 


In [Blu+15], they used a diagonal Gaussian posterior and vanilla SGD. In [Osa+19a], they used 
the variational online Gauss-Newton (VOGN) method of [Kha+18], for improved scalability. 
VOGN is a noisy version of natural gradient descent, where the extra noise emulates the effect 
of variational inference. In [Mis+18], they replaced the diagonal approximation with a low-rank 
plus diagonal approximation, and used VOGN for fitting. In [Tra+20b], they use a rank-one plus 
diagonal approximation known as NAGVAC (see Section 10.3.4.2). In this case, there are only 3 
times as many parameters as when computing a point estimate (for the variational mean, variance, 
and rank-one vector), making the approach very scalable. In addition, in this case it is possible to 
analytically compute the natural gradient, which speeds up model fitting (see Section 6.4). Many 
other variational methods have also been proposed (see e.g., [LW16; Zha+18; Wu+19a; HHK19]). 
See also Section 17.5.4 for a discussion of online VI for DNNs. 


17.3.4 Expectation propagation


Expectation propagation (EP) is similar to variational inference, except it locally optimizes $D_{\mathbb{KL}}(p \,\|\, q)$
instead of $D_{\mathbb{KL}}(q \,\|\, p)$, where p is the exact posterior and q is the approximate posterior. For details,
see Section 10.7.

A special case of EP is the assumed density filtering (ADF) algorithm of Section 8.9, which is
equivalent to the first pass of EP. In Section 8.9.3 we show how to apply ADF to online logistic
regression. In [HLA15a], they extend ADF to the case of BNNs; they called their method probabilistic
backpropagation or PBP. They approximate every parameter in the model by a Gaussian factor, as
in Figure 17.3. See Section 17.5.3 for the details.


17.3.5 Last layer methods and SNGP


A very simple approximation is to only “be Bayesian” about the weights in the final layer, and to
use MAP estimates for all the other parameters. This is called the neural-linear approximation
[RTS18]. In [KHH20] they show this can reduce overconfidence in predictions for inputs that are far
from the training data. However, this approach ignores uncertainty introduced by the earlier feature
extraction layers, where most of the parameters reside. We discuss a solution to this in Section 17.3.6.
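The neural-linear approximation reduces to Bayesian linear regression on the penultimate-layer features. The sketch assumes a regression task with Gaussian noise and replaces the MAP-trained lower layers with a fixed random ReLU layer (the helper `features` is hypothetical). Only the last-layer weights get a posterior, so, as noted above, uncertainty in the feature extractor is ignored.

```python
import numpy as np

rng = np.random.default_rng(0)
H = 50
W1, b1 = rng.normal(size=H), rng.uniform(-3, 3, size=H)

def features(x):
    """Hypothetical stand-in for the MAP-trained lower layers: a fixed random ReLU layer, (N,) -> (N, H)."""
    return np.maximum(0.0, np.outer(x, W1) + b1)

x_train = rng.uniform(-2, 2, size=40)
y_train = np.sin(2 * x_train) + 0.1 * rng.normal(size=40)
noise_var, prior_prec = 0.01, 1.0

# Bayesian linear regression on the last layer only: Gaussian posterior N(w_mean, w_cov).
Phi = features(x_train)
w_cov = np.linalg.inv(Phi.T @ Phi / noise_var + prior_prec * np.eye(H))
w_mean = w_cov @ Phi.T @ y_train / noise_var

def predict(x):
    """Posterior predictive mean and variance, treating the features as fixed."""
    phi = features(x)
    mean = phi @ w_mean
    var = noise_var + np.einsum("nh,hk,nk->n", phi, w_cov, phi)
    return mean, var

mean, var = predict(np.array([0.0, 6.0]))   # one input inside the training range, one far outside
```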
Figure 17.4: Illustration of an MLP fit to the two-moons dataset using HMC. (a) Posterior mean. (b)
Posterior standard deviation. The uncertainty increases as we move away from the training data. Generated
by bnn_mlp_2d_hmc.ipynb.


17.3.6 SNGP 


It is possible to combine DNNs with Gaussian Process (GP) models (Chapter 18), by using the DNN 
to act as a feature extractor, which is then fed into the kernel in the final layer. This is called “deep 
kernel learning” (see Section 18.6.6). 

One problem with this is that the feature extractor may lose information which is not needed for 
classification accuracy, but which is needed for robust performance on out-of-distribution inputs (see 
Section 17.4.6.2). The basic problem is that, in a classification problem, there is no reduction in 
training accuracy (log likelihood) if points which are far away are projected close together, as long as 
they are on the correct side of the decision boundary. Thus the distances between two inputs can be 
erased by the feature extraction layers, so that OOD inputs appear to the final layer to be close to 
the training set. 

One solution to this is to use the SNGP (spectrally normalized Gaussian process) method of 
[Liu+20d]. This constrains the feature extraction layers to be “distance preserving”, so that two 
inputs that are far apart in input space remain far apart after many layers of feature extraction, by
using spectral normalization of the weights to bound the Lipschitz constant of the feature extractor. 
The overall approach ensures that information that is relevant for computing the confidence of a 
prediction, but which might be irrelevant to computing the label of a prediction, is not lost. This 
can help performance in tasks such as out-of-distribution detection (Section 17.4.6.2). 
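Spectral normalization is usually implemented with power iteration, because computing a full SVD of every weight matrix at every step would be expensive. The sketch below estimates the largest singular value of a weight matrix and rescales the matrix so its spectral norm is at most a target `c`. The value 0.95, the iteration count, and the function names are illustrative assumptions, and the rest of SNGP (the GP output layer) is omitted.

```python
import numpy as np

def spectral_norm(W, num_iters=500, seed=0):
    """Estimate the largest singular value of W by power iteration on W^T W."""
    u = np.random.default_rng(seed).normal(size=W.shape[0])
    for _ in range(num_iters):
        v = W.T @ u
        v /= np.linalg.norm(v)
        u = W @ v
        u /= np.linalg.norm(u)
    return u @ W @ v

def spectrally_normalize(W, c=0.95):
    """Rescale W so that its spectral norm (the Lipschitz constant of x -> W x) is at most about c."""
    return W * min(1.0, c / spectral_norm(W))

rng = np.random.default_rng(1)
W = rng.normal(size=(64, 32))
W_sn = spectrally_normalize(W, c=0.95)
```

Practical implementations often run a single power-iteration step per training update and reuse the vector `u` across updates; that detail is likewise an assumption here, not something the text specifies.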


17.3.7 MCMC methods 


Some of the earliest work on inference for BNNs was done by Radford Neal, who proposed to use 
Hamiltonian Monte Carlo (Section 12.5) to approximate the posterior [Nea96]. This is generally 
considered the gold standard method, since it does not make strong assumptions about the form of 
the posterior. For more recent work on scaling up HMC for BNNs, see e.g., [Izm+21b; CJ21]. 


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 




[Figure 17.5 plot: test error (%)]

Figure 17.5: Illustration of stochastic weight averaging (SWA). The three crosses represent different SGD 
solutions. The star in the middle is the average of these parameter values. From Figure 1 of [Izm+18]. Used 
with kind permission of Andrew Wilson. 


We give a simple example of vanilla HMC in Figure 17.4, where we fit a shallow MLP to a small
2d binary dataset. We plot the mean and standard deviation of the posterior predictive distribution,
$p(y = 1|x, \mathcal{D})$. We see that the uncertainty is higher as we move away from the training data.
(Compare to Bayesian logistic regression in 1d in Figure 15.8a.)

However, a significant limitation of standard MCMC procedures, including HMC, is that they 
require access to the full training set at each step. Stochastic gradient MCMC methods, such as 
SGLD, operate instead using mini-batches of data, offering a scalable alternative, as we discuss in 
Section 12.7.1. For an example of SGLD applied to an MLP, see Section 19.3.3.1. 
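The SGLD update adds Gaussian noise to a half-step of minibatch gradient ascent on the log posterior. The sketch uses a conjugate toy model (a Gaussian mean with a Gaussian prior, not an MLP) so the samples can be compared with the exact posterior; the step size, batch size, and burn-in length are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)
N, B = 100, 10
y = rng.normal(loc=1.0, scale=1.0, size=N)   # model: y_i ~ N(theta, 1), prior theta ~ N(0, 1)

def sgld(num_steps, step_size, rng):
    """Stochastic gradient Langevin dynamics with a constant step size."""
    theta, samples = 0.0, np.empty(num_steps)
    for t in range(num_steps):
        batch = y[rng.choice(N, size=B, replace=False)]
        # Unbiased minibatch estimate of the gradient of the log posterior.
        grad = -theta + (N / B) * np.sum(batch - theta)
        # Half a gradient step plus Gaussian noise with variance step_size.
        theta = theta + 0.5 * step_size * grad + np.sqrt(step_size) * rng.normal()
        samples[t] = theta
    return samples

samples = sgld(20_000, 1e-3, rng)[2_000:]      # discard burn-in
post_var = 1.0 / (N + 1)                       # exact posterior, available because the model is conjugate
post_mean = y.sum() * post_var
```

With a constant step size, the minibatch gradient noise adds to the injected noise, so the samples are somewhat over-dispersed relative to the exact posterior; shrinking the step size reduces this bias at the cost of slower mixing.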


17.3.8 Methods based on the SGD trajectory


In [MHB17; SL18; CS18], it was shown that, under some assumptions, the iterates produced by
stochastic gradient descent (SGD), when run at a fixed learning rate, correspond to samples from
a Gaussian approximation to the posterior centered at a local mode, $p(\theta|\mathcal{D}) \approx \mathcal{N}(\theta|\hat{\theta}, \Sigma)$. We can
therefore use SGD to generate approximate posterior samples. This is similar to SG-MCMC methods,
except we do not add explicit gradient noise, and the learning rate is held constant.

In [Izm+18], they noted that these SGD solutions (with fixed learning rate) surround the periphery
of points of good generalization, as shown in Figure 17.5. This is in part because SGD does not
converge to a local optimum unless the learning rate is annealed to 0. They therefore proposed to
compute the average of several SGD samples, each one collected after a certain interval (e.g., one
epoch of training), to get $\bar{\theta} = \frac{1}{S}\sum_{s=1}^{S} \theta_s$. They call this stochastic weight averaging (SWA).
They showed that the resulting point tends to correspond to a broader local minimum than the SGD
solutions (cf. Figure 17.9), resulting in better generalization performance.
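SWA itself is a running equal-weight average of iterates collected at fixed intervals. The sketch below applies it to constant-learning-rate minibatch SGD on a least-squares problem, a convex stand-in chosen so that the training-loss minimizer is known; the learning rate, batch size, and number of epochs are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D = 1000, 5
X = rng.normal(size=(N, D))
y = X @ rng.normal(size=D) + 0.5 * rng.normal(size=N)
theta_opt = np.linalg.lstsq(X, y, rcond=None)[0]     # exact minimizer of the training loss

def sgd_epoch(theta, lr, batch_size, rng):
    """One epoch of constant-learning-rate minibatch SGD on half the mean squared error."""
    perm = rng.permutation(N)
    for start in range(0, N, batch_size):
        idx = perm[start:start + batch_size]
        grad = X[idx].T @ (X[idx] @ theta - y[idx]) / len(idx)
        theta = theta - lr * grad
    return theta

theta = np.zeros(D)
for _ in range(5):                                   # warm-up epochs before averaging starts
    theta = sgd_epoch(theta, lr=0.1, batch_size=10, rng=rng)

theta_swa, snapshots = np.zeros(D), []
for k in range(1, 21):                               # collect one snapshot per epoch
    theta = sgd_epoch(theta, lr=0.1, batch_size=10, rng=rng)
    snapshots.append(theta)
    theta_swa += (theta - theta_swa) / k             # running equal-weight average

dist_snapshots = np.mean([np.linalg.norm(t - theta_opt) for t in snapshots])
dist_swa = np.linalg.norm(theta_swa - theta_opt)
```

On this problem the individual snapshots bounce around the minimizer while their average lands much closer to it; for DNNs, the claim above is that the averaged point also lies in a flatter region of the loss.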

The SWA approach is related to Polyak-Ruppert averaging, which is often used in convex optimization.
The difference is that Polyak-Ruppert typically assumes the learning rate decays to zero, and
uses an exponential moving average (EMA) of iterates, rather than an equal average; Polyak-Ruppert
averaging is mainly used to reduce variance in the SGD estimate, rather than as a method to find
points of better generalization.

The SWA approach is also related to snapshot ensembles [Hua+17a], and fast geometric
ensembles [Gar+18c]; these methods save the parameters $\theta_s$ after increasing and decreasing the
learning rate multiple times in a cyclical fashion, and then compute the average of the predictions
using $p(y|x, \mathcal{D}) \approx \frac{1}{S}\sum_{s=1}^{S} p(y|x, \theta_s)$, rather than computing the average of the parameters and
predicting with a single model (which is faster). Moreover, by finding a flat region, representing a
“center of mass” in the posterior, SWA can be seen as approximating the Bayesian model average in
Equation (17.1) with a single model.

Figure 17.6: Cartoon illustration of the NLL as it varies across the parameter space. Subspace methods
(red) model the local neighborhood around a local mode, whereas ensemble methods (blue) approximate the
posterior using a set of distinct modes. From Figure 1 of [FHL19]. Used with kind permission of Balaji
Lakshminarayanan.

In [Mad+19], they proposed to fit a Gaussian distribution to the set of samples produced by SGD
near a local mode. They use the SWA solution as the mean of the Gaussian. For the covariance
matrix, they use a low-rank plus diagonal approximation of the form $p(\theta|\mathcal{D}) \approx \mathcal{N}(\theta|\bar{\theta}, \Sigma)$, where
$\Sigma = (\Sigma_{\text{diag}} + \Sigma_{\text{lr}})/2$, $\Sigma_{\text{diag}} = \text{diag}(\overline{\theta^2} - \bar{\theta}^2)$, $\bar{\theta} = \frac{1}{S}\sum_{s=1}^{S} \theta_s$, $\overline{\theta^2} = \frac{1}{S}\sum_{s=1}^{S} \theta_s^2$, and $\Sigma_{\text{lr}} = \frac{1}{K-1}\Delta\Delta^\top$
is the sample covariance matrix of the last $K$ samples of $\Delta_i = (\theta_i - \bar{\theta}_i)$, where $\bar{\theta}_i$ is the running
average of the parameters from the first $i$ samples (and $\Delta$ is the matrix whose $K$ columns are these deviations). They call this method SWAG, which stands for
“stochastic weight averaging with Gaussian posterior”. This can be used to generate an arbitrary
number of posterior samples at prediction time. They show that SWAG scales to large residual
networks with millions of parameters, and large datasets such as ImageNet, with improved accuracy
and calibration over conventional SGD training, and no additional training overhead.
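The SWAG moments and the corresponding sampler can be sketched as follows. The sketch follows the formulas above, with $\Sigma_{\text{lr}}$ built from the last $K$ deviations $\Delta_i$. The input snapshots are synthetic stand-ins for SGD iterates, and the helper names are hypothetical. Sampling combines two independent Gaussian draws, so no $P \times P$ matrix is ever formed.

```python
import numpy as np

def swag_fit(snapshots, max_rank):
    """SWAG moments from parameter vectors theta_1..theta_S: mean, diagonal variance, last-K deviations."""
    thetas = np.stack(snapshots)                               # (S, P)
    mean = thetas.mean(axis=0)                                 # theta_bar
    sq_mean = (thetas ** 2).mean(axis=0)                       # average of theta^2
    diag_var = np.clip(sq_mean - mean ** 2, 1e-12, None)       # Sigma_diag, clipped against round-off
    # Delta_i = theta_i - (running mean of the first i samples); keep the last K of them.
    running = np.cumsum(thetas, axis=0) / np.arange(1, len(thetas) + 1)[:, None]
    dev = (thetas - running)[-max_rank:]                       # (K, P)
    return mean, diag_var, dev

def swag_sample(mean, diag_var, dev, num_samples, rng):
    """Draw theta ~ N(mean, (Sigma_diag + Sigma_lr) / 2) with Sigma_lr = dev^T dev / (K - 1)."""
    K, P = dev.shape
    z1 = rng.normal(size=(num_samples, P))
    z2 = rng.normal(size=(num_samples, K))
    return mean + np.sqrt(diag_var / 2.0) * z1 + (z2 @ dev) / np.sqrt(2.0 * (K - 1))

rng = np.random.default_rng(0)
# Hypothetical SGD iterates: 30 snapshots of a 4-parameter model.
snapshots = list(rng.normal(size=(30, 4)) @ np.diag([1.0, 0.5, 0.2, 0.1]) + 3.0)
mean, diag_var, dev = swag_fit(snapshots, max_rank=10)
samples = swag_sample(mean, diag_var, dev, num_samples=200_000, rng=rng)
```

Storing the $K$ deviation vectors costs $O(KP)$ memory, which is what allows the method to scale to networks with millions of parameters.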


17.3.9 Deep ensembles 


Many conventional approximate inference methods focus on approximating the posterior $p(\theta|\mathcal{D})$ in a
local neighborhood around one of the posterior modes. While this is often not a major limitation in 
classical machine learning, modern deep neural networks have highly multi-modal posteriors, with 
parameters in different modes giving rise to very different functions. On the other hand, the functions 
in a neighborhood of a single mode may make fairly similar predictions. So using such a local 
approximation to compute the posterior predictive will underestimate uncertainty and generalize 
more poorly. 

A simple alternative method is to train multiple models, and then to approximate the posterior
using an equally weighted mixture of delta functions,

$$p(\theta|\mathcal{D}) \approx \frac{1}{M}\sum_{m=1}^{M} \delta(\theta - \hat{\theta}_m) \tag{17.17}$$

where $M$ is the number of models, and $\hat{\theta}_m$ is the MAP estimate for model $m$. See Figure 17.6 for a
sketch. This approach is called deep ensembles [LPB17; FHL19].

The models can differ in terms of their random seed used for initialization [LPB17], or hyperparameters
[Wen+20c], or architecture [Zai+20], or all of the above. In addition, [DF21; TB22]
discuss how to add an explicit repulsive term to ensure functional diversity between the ensemble
members. This way, each member corresponds to a distinct prediction function. Combining these is 
more effective than combining multiple samples from the same basin of attraction, especially in the 
presence of dataset shift [Ova+19]. 
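Once the $M$ members are trained, the ensemble predictive implied by Equation (17.17) is an equal-weight average of their predictive distributions. The member probabilities below are made up (a hypothetical stand-in for trained networks). As one way to see the functional diversity discussed above, the sketch also splits the ensemble's predictive entropy into the average member entropy plus a disagreement term, which is non-negative because entropy is concave and is zero when the members agree.

```python
import numpy as np

def ensemble_predict(member_probs):
    """Equal-weight mixture (1/M) sum_m p(y | x, theta_m); member_probs has shape (M, N, C)."""
    return member_probs.mean(axis=0)

def entropy(p, axis=-1):
    return -np.sum(p * np.log(np.clip(p, 1e-12, None)), axis=axis)

def disagreement_split(member_probs):
    """Entropy of the ensemble = mean member entropy + disagreement (>= 0)."""
    total = entropy(ensemble_predict(member_probs))
    member_entropy = entropy(member_probs).mean(axis=0)
    return total, member_entropy, total - member_entropy

# Hypothetical predictions of M = 3 members on N = 2 inputs with C = 2 classes:
# the members agree on the first input and disagree on the second.
member_probs = np.array([
    [[0.9, 0.1], [0.9, 0.1]],
    [[0.9, 0.1], [0.1, 0.9]],
    [[0.9, 0.1], [0.5, 0.5]],
])
total, member_entropy, disagreement = disagreement_split(member_probs)
```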


17.3.9.1 Multi-SWAG 


We can further improve on this approach by fitting a Gaussian to each local mode using the SWAG 
method from Section 17.3.8 to get a mixture of Gaussians approximation: 


$$p(\theta|\mathcal{D}) \approx \frac{1}{M}\sum_{m=1}^{M} \mathcal{N}(\theta|\mu_m, \Sigma_m) \tag{17.18}$$

This approach is known as MultiSWAG [WI20]. MultiSWAG performs a Bayesian model average
both across multiple basins of attraction, like deep ensembles, and within each basin, and provides
an easy way to generate an arbitrary number of posterior samples, $S > M$, in an anytime fashion.


17.3.9.2 Deep ensembles with random priors


The standard way to fit each member of a deep ensemble is to initialize them each with a different
random set of parameters, but then to train them all on the same data. Unfortunately this can
result in the predictions from each ensemble member being rather similar, which reduces the benefit
of the approach. One way to increase diversity is to train each member on a different subset of the
data; this is called bootstrap sampling. Another approach is to define the $i$th ensemble member
$g_i(x)$ to be the addition of a trainable model $t_i(x)$ and a fixed, but random, prior network, $p_i(x)$,
to get

$$g_i(x; \theta_i) = t_i(x; \theta_i) + \beta p_i(x) \tag{17.19}$$

where $\beta > 0$ controls the amount of data-independent variation between the members. The trainable
network learns to model the residual error between the true output and the value predicted by the
prior. This is called a random prior deep ensemble [OAC18]. See Figure 17.7 for an illustration.
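The construction in Equation (17.19) can be illustrated with a deliberately simple trainable part. The sketch assumes, hypothetically, that $t_i$ is a ridge regression on Gaussian bumps centered at the training inputs (so it is fit in closed form and is identical for every member), and that each $p_i$ is a fixed random tanh network; all names and sizes are illustrative. Because the members differ only through their priors, the spread of the ensemble far from the data comes entirely from $\beta p_i(x)$.

```python
import numpy as np

def rbf_features(x, centers, lengthscale=0.3):
    """Gaussian bumps centered on the training inputs (a hypothetical, closed-form trainable model)."""
    return np.exp(-0.5 * ((x[:, None] - centers[None, :]) / lengthscale) ** 2)

def random_prior_net(rng, hidden=50):
    """A fixed, randomly initialized tanh network p_i(x); it is never trained."""
    W1, b1 = rng.normal(size=hidden), rng.normal(size=hidden)
    w2 = rng.normal(size=hidden) / np.sqrt(hidden)
    return lambda x: np.tanh(np.outer(x, W1) + b1) @ w2

def fit_member(x, y, beta, prior, l2=1e-6):
    """Return g(x) = t(x) + beta * p(x), where t is fit by ridge regression to the residual y - beta * p(x)."""
    F = rbf_features(x, x)
    w = np.linalg.solve(F.T @ F + l2 * np.eye(len(x)), F.T @ (y - beta * prior(x)))
    return lambda z: rbf_features(z, x) @ w + beta * prior(z)

rng = np.random.default_rng(0)
x_train = np.linspace(-1, 1, 20)
y_train = np.sin(3 * x_train)
priors = [random_prior_net(rng) for _ in range(10)]    # one fixed prior per ensemble member
x_far = np.array([4.0])                                 # far from the training inputs

spread = {}
for beta in [0.0, 1.0, 3.0]:
    members = [fit_member(x_train, y_train, beta, p) for p in priors]
    spread[beta] = np.std([g(x_far)[0] for g in members])
```

With $\beta = 0$ the members collapse onto one function, a caricature of the lack of diversity described above, while larger $\beta$ increases their spread away from the data.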


17.3.9.3 Deep ensembles as approximate Bayesian inference


The posterior predictive distribution for a Bayesian neural network cannot be expressed in closed
form. Therefore all Bayesian inference approaches in deep learning are approximate. In this context,
all approximate inference procedures fall onto a spectrum, representing how closely they approximate
the true posterior predictive distribution. Deep ensembles can provide better approximations to a
Bayesian model average than a single basin marginalization approach, because point masses from
different basins of attraction represent greater functional diversity than standard Bayesian approaches
which sample within a single basin.

Figure 17.7: Deep ensemble with random priors. (a) Individual predictions from each member. Blue is the
fixed random prior function, orange is the trainable function, green is the combination of the two. (b) Overall
prediction from the ensemble, for increasingly large values of $\beta$. On the left we show (in red) the posterior
mean and pointwise standard deviation, and on the right we show samples from the posterior. As $\beta$ increases,
we trust the random priors more, and pay less attention to the data, thus getting a more diffuse posterior.
Generated by randomized_priors.ipynb.


17.3.9.4 Deep ensembles vs classical ensembles 


Note that deep ensembles are slightly different from classical ensemble methods (see e.g., [Die00]),
such as bagging and random forests, which obtain diversity of their predictors by training them on
different subsets of the data (created using bootstrap resampling), or on different features. This data
perturbation is necessary to get diversity when the base learner is fit by solving a convex problem (as for a linear
model), or is a shallow decision tree. In the deep ensemble approach, every model is trained on the same
data, and the same input features. The diversity arises due to different starting parameters, different
random seeds, and SGD noise, which induces different solutions due to the nonconvex loss. It is
also possible to explicitly enforce diversity of the ensemble members, which can provably improve
performance [TB22].


17.3.9.5 Deep ensembles vs mixtures of experts and stacking 


If we use weighted combinations of the models, $p(\theta|\mathcal{D}) = \sum_m p(m|\mathcal{D})\, p(\theta|m, \mathcal{D})$, where $p(m|\mathcal{D}) \propto p(\mathcal{D}|m)p(m)$ is
the posterior probability of model $m$ (proportional to its marginal likelihood times its prior), then, in the large sample limit, this mixture will concentrate on
the MAP model, so only one component will be selected. By contrast, in deep ensembles, we always
use $M$ equally weighted models. Thus we see that Bayes model averaging is not the same as model
ensembling [Min00b]. Indeed, ensembling can enlarge the expressive power of the posterior predictive
distribution compared to BMA [OCM21].

Figure 17.8: Illustration of batch ensemble with 2 ensemble members. From Figure 2 of [WTB20]. Used with
kind permission of Paul Vicol.

We can also make the mixing weights be conditional on the inputs:

$$p(y|x, \mathcal{D}) = \sum_m w_m(x)\, p(y|x, \theta_m) \tag{17.20}$$

If we constrain the weights to be non-negative and sum to one, this is called a mixture of experts.
However, if we allow a general positive weighted combination, the approach is called stacking [Wol92;
Bre96; Yao+18a; CATI20]. In stacking, the weights $w(x)$ are usually estimated on hold-out data,
to make the method more robust to model misspecification.


17.3.9.6 Batch ensemble 


Deep ensembles require $M$ times more memory and time than a single model. One way to reduce
the memory cost is to share most of the parameters, which we call slow weights, $W$, and then
let each ensemble member $m$ estimate its own local perturbation, which we will call fast weights,
$F_m$. We then define $W_m = W \odot F_m$. For efficiency, we can define $F_m$ to be a rank-one matrix,
$F_m = s_m r_m^\top$, as illustrated in Figure 17.8. This is called batch ensemble [WTB20].

It is clear that the memory overhead is very small compared to naive ensembles, since we just need
to store $2M$ vectors ($s_m^l$ and $r_m^l$) for every layer $l$, which is negligible compared to the quadratic cost
of storing the shared weight matrix $W^l$.

In addition to memory savings, batch ensemble can reduce the inference time by a constant factor
by leveraging within-device parallelism. To see this, consider the output of one layer using ensemble
member $m$ on example $n$:

$$y_n^m = \varphi\left(W_m^\top x_n\right) = \varphi\left((W \odot s_m r_m^\top)^\top x_n\right) = \varphi\left((W^\top (x_n \odot s_m)) \odot r_m\right) \tag{17.21}$$
We can vectorize this for a minibatch of inputs $X$ by replicating $r_m$ and $s_m$ along the $B$ rows in the
batch to form matrices $R_m$ and $S_m$, giving

$$Y_m = \varphi\left(((X \odot S_m) W) \odot R_m\right) \tag{17.22}$$

This applies the same ensemble member parameters $(s_m, r_m)$ to every example in the minibatch of size $B$. To
achieve diversity during training, we can divide the minibatch into $M$ sub-batches, and use sub-batch
$m$ to train $W_m$. (Note that this reduces the batch size for training each ensemble member to $B/M$.) At test
time, when we want to average over $M$ models, we can replicate each input $M$ times, leading to a
batch size of $BM$.

In [WTB20], they show that this method outperforms MC dropout at negligible extra memory
cost. However, the best results were obtained by combining batch ensemble with MC dropout; in some
cases, this approached the performance of naive ensembles.
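The identities in Equations (17.21) and (17.22) are easy to check numerically: materializing the rank-one-modulated weights and using the cheaper vectorized form give the same layer outputs. The sizes, random weights, and $\tanh$ nonlinearity below are illustrative assumptions, and `all_members` shows the test-time trick of replicating the batch $M$ times so that a single matrix product with the shared $W$ serves every member.

```python
import numpy as np

rng = np.random.default_rng(0)
B, D_in, D_out, M = 8, 5, 3, 4
W = rng.normal(size=(D_in, D_out))          # shared slow weights
S = rng.normal(size=(M, D_in))              # fast weights s_m, one row per member
R = rng.normal(size=(M, D_out))             # fast weights r_m, one row per member
X = rng.normal(size=(B, D_in))              # minibatch, one example per row
phi = np.tanh                               # any elementwise nonlinearity

def member_naive(m):
    """Materialize W_m = W o (s_m r_m^T) and apply it to every row of X."""
    W_m = W * np.outer(S[m], R[m])
    return phi(X @ W_m)

def member_fast(m):
    """Same layer without forming W_m: phi(((X o S_m) W) o R_m), as in Eq. (17.22)."""
    return phi(((X * S[m]) @ W) * R[m])

def all_members(X):
    """Replicate the batch M times so one product with W serves every member; output (M, B, D_out)."""
    X_rep = np.repeat(X[None], M, axis=0) * S[:, None, :]
    return phi((X_rep @ W) * R[:, None, :])
```

The shared $W$ costs $D_{\text{in}} \times D_{\text{out}}$ numbers per layer, while the fast weights add only $M(D_{\text{in}} + D_{\text{out}})$, which is the memory saving described above.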


17.3.10 Approximating the posterior predictive distribution


Once we have approximated the parameter posterior, $q(\theta) \approx p(\theta|\mathcal{D})$, we can use it to approximate
the posterior predictive distribution:

$$p(y|x, \mathcal{D}) \approx \int q(\theta)\, p(y|x, \theta)\, d\theta \tag{17.23}$$

We usually approximate this integral using Monte Carlo:

$$p(y|x, \mathcal{D}) \approx \frac{1}{S}\sum_{s=1}^{S} p(y|f(x, \theta^s)) \tag{17.24}$$

where $\theta^s \sim q(\theta)$. We discuss some extensions of this approach below.


17.3.10.1 A linearized approximation 


In [IKB21] they point out that samples from an approximate posterior, $q(\theta)$, can result in bad
predictions when plugged into the model if the posterior puts probability density “in the wrong
places”. This is because $f(x; \theta)$ is a highly nonlinear function of $\theta$ that might behave quite differently
when $\theta$ is far from the MAP estimate on which $q(\theta)$ is centered. To avoid this problem, they propose
to replace $f(x; \theta)$ with a linear approximation centered at the MAP estimate $\theta^*$:

$$f_{\text{lin}}^{\theta^*}(x, \theta) = f(x, \theta^*) + J_{\theta^*}(x)(\theta - \theta^*) \tag{17.25}$$

Such a model is well behaved around $\theta^*$, and so the approximation

$$p(y|x, \mathcal{D}) \approx \frac{1}{S}\sum_{s=1}^{S} p(y|f_{\text{lin}}^{\theta^*}(x, \theta^s)) \tag{17.26}$$

often works better than Equation (17.24).

Note that $f_{\text{lin}}^{\theta^*}(x, \theta)$ is a linear function of the parameters $\theta$, but a nonlinear function of the
inputs $x$. Thus $p(y|f_{\text{lin}}^{\theta^*}(x, \theta))$ is a generalized linear model (Section 15.1), so [IKB21] call this
approximation the GLM predictive distribution.
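The difference between Equations (17.24) and (17.26) lies only in which function the posterior samples are plugged into. The sketch uses a hypothetical two-parameter model $f(x, \theta) = \theta_2 \tanh(\theta_1 x)$ with a hand-written Jacobian, a Bernoulli likelihood with a sigmoid link, and a made-up Gaussian $q(\theta)$; for a real network, the Jacobian-vector product would come from automatic differentiation.

```python
import numpy as np

def f(x, theta):
    """Hypothetical tiny 'network': theta[..., 1] * tanh(theta[..., 0] * x) (0-based indexing)."""
    return theta[..., 1] * np.tanh(theta[..., 0] * x)

def jac(x, theta):
    """Jacobian df/dtheta at a single parameter vector, shape (2,)."""
    t = np.tanh(theta[0] * x)
    return np.array([theta[1] * x * (1 - t ** 2), t])

def f_lin(x, theta, theta_star):
    """First-order Taylor expansion of f around theta_star, Eq. (17.25); linear in theta."""
    return f(x, theta_star) + (theta - theta_star) @ jac(x, theta_star)

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

rng = np.random.default_rng(0)
theta_star = np.array([1.5, 2.0])                  # hypothetical MAP estimate
cov = np.array([[0.5, 0.1], [0.1, 0.3]])           # hypothetical approximate posterior covariance
thetas = rng.multivariate_normal(theta_star, cov, size=5000)

x = 0.8
p_nn = sigmoid(f(x, thetas)).mean()                    # Eq. (17.24): samples plugged into f
p_glm = sigmoid(f_lin(x, thetas, theta_star)).mean()   # Eq. (17.26): samples plugged into f_lin
```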
17.3.10.2 Distillation


The MC approximation to the posterior predictive is $S$ times slower than a standard, deterministic
plug-in approximation. One way to speed this up is to use distillation to approximate the
semi-parametric “teacher” model $p_t$ from Equation (17.24) by a parametric “student” model $p_s$
by minimizing $\mathbb{E}_x\left[D_{\mathbb{KL}}(p_t(y|x) \,\|\, p_s(y|x))\right]$ with respect to $p_s$. This approach was first proposed in [HVD14],
who called the technique “dark knowledge”, because the teacher has “hidden” information in its
predictive probabilities (logits) that is not apparent in the raw one-hot labels.

In [Kor+15], this idea was used to distill the predictions from a teacher whose parameter posterior 
was computed using HMC; this is called “Bayesian dark knowledge”. A similar idea was used in 
[BPK16; GBP18], who distilled the predictive distribution derived from MC dropout (Section 17.3.1). 

Since the parametric student is typically less flexible than the semi-parametric teacher, it may be
overconfident, and lack diversity in its predictions. To avoid this overconfidence, it is safer to make
the student be a mixture distribution (cf. [SG05]). See also [Tra+20a].
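The distillation objective can be sketched with a softmax-linear student trained on soft targets. The teacher below is a hypothetical average of $S$ perturbed softmax-linear models standing in for the Monte Carlo average in Equation (17.24). The student is fit by plain gradient descent, using the fact that the gradient of $D_{\mathbb{KL}}(p_t \,\|\, p_s)$ with respect to the student's logits is $p_s - p_t$.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def kl(p, q):
    """Row-wise KL(p || q) for categorical distributions."""
    return np.sum(p * (np.log(np.clip(p, 1e-12, None)) - np.log(np.clip(q, 1e-12, None))), axis=-1)

rng = np.random.default_rng(0)
N, D, C, S = 500, 2, 3, 20
X = rng.normal(size=(N, D))
# Hypothetical teacher: average of S softmax-linear models with perturbed weights.
W_mean = rng.normal(size=(D, C))
teacher = np.mean([softmax(X @ (W_mean + 0.5 * rng.normal(size=(D, C)))) for _ in range(S)], axis=0)

# Student: one softmax-linear model fit by gradient descent on the mean KL(teacher || student).
W_s, b_s = np.zeros((D, C)), np.zeros(C)
initial_kl = kl(teacher, softmax(X @ W_s + b_s)).mean()
for _ in range(2000):
    G = (softmax(X @ W_s + b_s) - teacher) / N     # gradient of the mean KL w.r.t. the student logits
    W_s -= 1.0 * X.T @ G
    b_s -= 1.0 * G.sum(axis=0)
final_kl = kl(teacher, softmax(X @ W_s + b_s)).mean()
```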


17.3.11 Tempered and cold posteriors 
When working with BNNs for classification problems, the likelihood is usually taken to be 


$$p(y|x, \theta) = \text{Cat}(y|\text{softmax}(f(x; \theta))) \tag{17.27}$$

where $f(x; \theta) \in \mathbb{R}^C$ returns the logits over the $C$ class labels. This is the same as in multinomial
logistic regression (Section 15.3.2); the only difference is that $f$ is a nonlinear function of $\theta$.

However, in practice, it is often found (see e.g., [Zha+18; Wen+20b; LST21; Noc+21]) that BNNs
give better predictive accuracy if the likelihood function is scaled by some power $\alpha$. That is, instead
of targeting the posterior $p(\theta|\mathcal{D}) \propto p(y|X, \theta)p(\theta)$, these methods target the tempered posterior,
$p_{\text{tempered}}(\theta|\mathcal{D}) \propto p(y|X, \theta)^\alpha p(\theta)$. In log space, we have

$$\log p_{\text{tempered}}(\theta|\mathcal{D}) = \alpha \log p(y|X, \theta) + \log p(\theta) + \text{const} \tag{17.28}$$

This is also called an $\alpha$-posterior or power posterior [Med+21].


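Another common method is to target the cold posterior, $p_{\text{cold}}(\theta|\mathcal{D}) \propto p(\theta|X, y)^{1/T}$, or, in log
space,

$$\log p_{\text{cold}}(\theta|\mathcal{D}) = \frac{1}{T}\log p(y|X, \theta) + \frac{1}{T}\log p(\theta) + \text{const} \tag{17.29}$$

If $T < 1$, we say that the posterior is “cold”. Note that, in the case of a Gaussian prior, using the
cold posterior is the same as using the tempered posterior (with $\alpha = 1/T$) with a different prior hyperparameter, since

$$\frac{1}{T}\log \mathcal{N}(\theta|0, \sigma^2_{\text{cold}} I) = -\frac{1}{2T\sigma^2_{\text{cold}}}\|\theta\|^2 + \text{const} = \log \mathcal{N}(\theta|0, \sigma^2_{\text{tempered}} I) + \text{const} \tag{17.30}$$

where $\sigma^2_{\text{tempered}} = T\sigma^2_{\text{cold}}$. Thus both methods are effectively the same, and just reweight the
likelihood.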
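Equation (17.30) can be checked numerically: the cold log posterior and a tempered log posterior with $\alpha = 1/T$ and prior variance $T\sigma^2_{\text{cold}}$ should differ only by a constant that does not depend on $\theta$. The sketch assumes a Gaussian linear-regression log likelihood purely for illustration; any log likelihood would do, since it is scaled by $1/T$ in both cases.

```python
import numpy as np

def log_gauss(theta, var):
    """log N(theta | 0, var * I) for a vector theta."""
    d = theta.size
    return -0.5 * (d * np.log(2 * np.pi * var) + theta @ theta / var)

def log_cold(theta, loglik, T, var_cold):
    """Eq. (17.29) without the constant: (1/T) [log p(y|X,theta) + log N(theta|0, var_cold I)]."""
    return (loglik(theta) + log_gauss(theta, var_cold)) / T

def log_tempered(theta, loglik, alpha, var_tempered):
    """Eq. (17.28) without the constant: alpha * log p(y|X,theta) + log N(theta|0, var_tempered I)."""
    return alpha * loglik(theta) + log_gauss(theta, var_tempered)

rng = np.random.default_rng(0)
X, y = rng.normal(size=(30, 3)), rng.normal(size=30)
loglik = lambda th: -0.5 * np.sum((y - X @ th) ** 2)   # hypothetical unit-noise Gaussian log likelihood
T, var_cold = 0.5, 2.0

# Differences at several random thetas; Eq. (17.30) says they should all be equal.
diffs = [log_cold(th, loglik, T, var_cold) - log_tempered(th, loglik, 1 / T, T * var_cold)
         for th in rng.normal(size=(5, 3))]
```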


Cold posteriors in Bayesian neural network classifiers are a consequence of underrepresenting
aleatoric (label) uncertainty, as shown by [Kap+22]. On benchmarks such as CIFAR-100, we should
have essentially no uncertainty about the labels of the training images, yet Bayesian classifiers with
softmax likelihoods have very high uncertainty for these points. Moreover, [Izm+21b] showed that the
cold posterior effect disappears in all the examples of [Wen+20b] when data augmentation is removed. [Kap+22]
show that with the SGLD inference in [Wen+20b], data augmentation has the effect of raising the
likelihood to a power $1/K$ for minibatches of size $K$. Cold posteriors exactly counteract this effect,
more honestly representing our beliefs about aleatoric uncertainty, by sharpening the likelihood.
However, tempering is not required, and [Kap+22] show that by using a Dirichlet observation
model to explicitly represent (lack of) label noise, there is no cold posterior effect, even with data
augmentation. The curation hypotheses of [Ait21] can be considered a special case of the above
explanation, where curation has the effect of increasing our confidence about training labels.

Figure 17.9: Flat vs sharp minima. From Figures 1 and 2 of [HS97]. Used with kind permission of Jürgen
Schmidhuber.

In Section 14.1.3, we discuss generalized variational inference, which gives a general framework for 
understanding whether and how the likelihood or prior could benefit from tempering. Tempering is 
particularly useful if (as is usually the case) the model is misspecified [KJD21]. 


17.4 Generalization in Bayesian deep learning 


In this section, we discuss why “being Bayesian” can improve predictive accuracy and generalization 
performance. 


17.4.1 Sharp vs flat minima 


Some optimization methods (in particular, second-order batch methods) are able to find “needles 
in haystacks”, that is, narrow but deep “holes” in the loss landscape, corresponding to 
parameter settings with very low loss. These are known as sharp minima; see Figure 17.9(right). 
From the point of view of minimizing the empirical loss, the optimizer has done a good job. However, 
such solutions generally correspond to a model that has overfit the data. It is better to find points 
that correspond to flat minima, as shown in Figure 17.9(left); such solutions are more robust and 
generalize better. To see why, note that flat minima correspond to regions in parameter space where 
there is a lot of posterior uncertainty, and hence samples from this region are less able to precisely 
memorize irrelevant details about the training set [AS17]. Put another way, the description length for 
sharp minima is large, meaning you need to use many bits of precision to specify the exact location 
in parameter space to avoid incurring large loss, whereas the description length for flat minima is 
less, resulting in better generalization [Mac03]. 

Figure 17.10: Diversity of high performing functions sampled from the posterior. Top row: we show predictions 
on the 1d input domain for 4 different functions. We see that they extrapolate in different ways outside of the 
support of the data. Bottom row: we show a 2d subspace spanning two distinct modes (MAP estimates), 
connected by a low-loss curved path computed as in [Gar+18c]. From Figure 8 of [WI20]. Used with kind 
permission of Andrew Wilson. 

SGD often finds such flat minima by virtue of the addition of noise, which prevents it from 
“entering” narrow regions of the loss landscape (see Section 12.5.7). In addition, in higher dimensional 
spaces, flat regions occupy a much greater volume, and are thus much more easily discoverable by 
optimization procedures. More precisely, the analysis in [SL18] shows that the probability of entering 
any given basin of attraction $A$ around a minimum is given by $p_{\text{SGD}}(\boldsymbol{\theta} \in A) \propto \int_A e^{-\mathcal{L}(\boldsymbol{\theta})} d\boldsymbol{\theta}$. Note 
that this is integrating over the volume of space corresponding to $A$, and hence is proportional to 
the model evidence (marginal likelihood) for that region, as explained in Section 3.9.1. Since the 
evidence is parameterization invariant (since we marginalize out the parameters), this means that 
SGD will avoid regions that have low evidence (corresponding to sharp minima) regardless of how we 
parameterize the model (contrary to the claims in [Din+17]). 
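To make the [SL18] argument concrete, the sketch below uses a hypothetical 1d loss with a sharp minimum that is slightly deeper than a nearby flat one. For each basin it approximates $\int_A e^{-\mathcal{L}(\boldsymbol{\theta})}d\boldsymbol{\theta}$ on a grid. It assumes the idealized stationary behavior described above, with unit temperature and each grid point assigned to whichever quadratic is lower. It does not simulate SGD itself. The loss functions and constants are our own choices.

```python
import numpy as np

# Hypothetical 1d loss with two minima: a sharp, slightly deeper one at -1
# and a flat, slightly shallower one at +1.
def loss_sharp(theta):
    return 0.0 + 500.0 * (theta + 1.0) ** 2

def loss_flat(theta):
    return 0.5 + 0.5 * (theta - 1.0) ** 2

theta = np.linspace(-6.0, 6.0, 200_001)
dtheta = theta[1] - theta[0]
la, lb = loss_sharp(theta), loss_flat(theta)
loss = np.minimum(la, lb)  # overall loss landscape
in_sharp = la < lb         # basin of attraction of the sharp minimum

# Riemann-sum approximation of the integral of exp(-L(theta)) over each basin
mass_sharp = np.sum(np.exp(-loss[in_sharp])) * dtheta
mass_flat = np.sum(np.exp(-loss[~in_sharp])) * dtheta
p_sharp = mass_sharp / (mass_sharp + mass_flat)

print(f"sharp basin: min loss {la.min():.2f}, probability {p_sharp:.3f}")
print(f"flat basin:  min loss {lb.min():.2f}, probability {1 - p_sharp:.3f}")
```

Even though the sharp minimum has the lower loss, almost all of the probability mass lies in the flat basin, because that basin occupies far more volume.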
In fact, several papers have shown that we can view SGD as approximately sampling from the 
Bayesian posterior (see Section 17.3.8). The SWA method (Section 17.3.8) can be seen as finding a 
center of mass in the posterior based on these SGD samples, finding solutions that generalize better 
than picking a single SGD point. 
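Here is a minimal sketch of the SWA idea. It uses a toy least-squares problem in place of a DNN, which keeps the loss convex and the effect easy to check. The sketch runs SGD with a constant learning rate and averages the iterates collected after a burn-in period. The learning rate, batch size, and burn-in length are arbitrary illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D = 500, 5
X = rng.normal(size=(N, D))
w_true = rng.normal(size=D)
y = X @ w_true + 0.5 * rng.normal(size=N)

def full_loss(w):
    """Mean squared error on the whole training set (convex in w)."""
    return np.mean((X @ w - y) ** 2)

w = np.zeros(D)
lr, batch, burn_in = 0.05, 10, 1000
swa_sum, iterate_losses = np.zeros(D), []
for step in range(3000):
    idx = rng.integers(0, N, size=batch)
    grad = 2 * X[idx].T @ (X[idx] @ w - y[idx]) / batch  # minibatch gradient
    w = w - lr * grad                                    # constant-step SGD
    if step >= burn_in:  # after burn-in, accumulate the iterates
        swa_sum += w
        iterate_losses.append(full_loss(w))
w_swa = swa_sum / len(iterate_losses)  # SWA: average of the SGD iterates

print(f"mean loss of individual iterates: {np.mean(iterate_losses):.4f}")
print(f"loss of last iterate:             {full_loss(w):.4f}")
print(f"loss of averaged (SWA) solution:  {full_loss(w_swa):.4f}")
```

Because this toy loss is convex, Jensen's inequality guarantees that the averaged solution is no worse than the average iterate. For DNNs, the benefit of averaging is empirical rather than guaranteed.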


If we must use a single solution, a flat one will help us better approximate the Bayesian model 
average in the integral of Equation (17.1). However, by attempting to perform a more complete 
Bayesian model average, we will select for flatness without having to worry about flatness 
definitions, the effects of reparametrization, or unknown implicit regularization, as the model 
average will automatically weight regions with the greatest volume. 


17.4.2 Mode connectivity and the loss landscape 


In DNNs there are often many low-loss solutions, which provide complementary explanations of 
the data. Moreover, in [Gar+18c] they showed that two independently trained SGD solutions can 
be connected by a curve in a subspace, along which the training loss remains near-zero, a phenomenon known as 
mode connectivity. Despite having the same training loss, these different parameter settings give 
rise to very different functions, as illustrated in Figure 17.10, where we show predictions on a 1d 
regression problem coming from different points in parameter space obtained by interpolating along 
a mode connecting curve between two distinct MAP estimates. Using a Bayesian model average, we 
can combine these functions together to provide much better performance than a single flat solution 
[Izm+19]. 

Recently, it has been discovered [Ben+21b] that there are in fact large multidimensional simplexes 
of low loss solutions, which can be combined together for significantly improved performance. These 
results further motivate the Bayesian approach (Equation (17.1)), where we perform a posterior 
weighted model average. 


17.4.3 Effective dimensionality of a model 


Modern DNNs have millions of parameters, but these parameters are often not well-determined 
by the data, i.e., there can be a lot of posterior uncertainty. By averaging over the posterior, we 
reduce the chance of overfitting, because we do not use “degrees of freedom” that are not needed or 
warranted. 

To quantify the number of degrees of freedom, or effective dimensionality [Mac92b], we follow 
[MBW20] and define 

$$N_{\text{eff}}(\mathbf{H}, c) = \sum_{i} \frac{\lambda_i}{\lambda_i + c} \tag{17.31}$$

where $\lambda_i$ are the eigenvalues of the Hessian matrix $\mathbf{H}$ computed at a local mode, and $c > 0$ is a 
regularization parameter. Intuitively, the effective dimension counts the number of well-determined 
parameters. A “flat minimum” will have many directions in parameter space that are not well-determined, 
and hence will have low effective dimensionality. This means that we can perform 
Bayesian inference in a low dimensional subspace [Izm+19]: since there is functional homogeneity 
in all directions but those defining the effective dimension, neural networks can be significantly 
compressed. 
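Equation (17.31) is straightforward to compute once the Hessian (or an approximation of it) is available. The sketch below builds a hypothetical 100-parameter Hessian with only three high-curvature directions and counts its effective dimensions. Clipping tiny negative eigenvalues to zero is our own safeguard against numerical noise and is not part of the definition.

```python
import numpy as np

def effective_dimensionality(hessian, c=1.0):
    """N_eff(H, c) = sum_i lambda_i / (lambda_i + c), Eq. (17.31)."""
    eigvals = np.linalg.eigvalsh(hessian)  # H is symmetric
    eigvals = np.clip(eigvals, 0.0, None)  # assumption: drop tiny negative eigenvalues
    return float(np.sum(eigvals / (eigvals + c)))

# Hypothetical 100-parameter model: 3 well-determined directions
# (large curvature) and 97 nearly flat ones.
rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.normal(size=(100, 100)))  # random orthogonal basis
spectrum = np.concatenate([np.full(3, 100.0), np.full(97, 1e-4)])
H = Q @ np.diag(spectrum) @ Q.T
print(effective_dimensionality(H, c=1.0))  # about 2.98, far fewer than 100 parameters
```

Each high-curvature direction contributes nearly 1 to the sum, and each flat direction contributes nearly 0. The regularizer $c$ sets the curvature scale that counts as "well determined".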

This compression perspective can also be used to understand why the effective dimension can be a 
good proxy for generalization. If two models have similar training loss, but one has lower effective 
dimension, then it is providing a better compression for the data at the same fidelity. In Figure 17.11 
we show that for CNNs with low training loss (above the green partition), the effective dimensionality 
closely tracks generalization performance. We also see that the number of parameters alone is not a 
strong determinant of generalization. Indeed, models with more parameters can have a lower number 
of effective parameters. We also see that wide but shallow models overfit, while depth helps provide 
lower effective dimensionality, leading to a better compression of the data. It is depth that makes 
modern neural networks distinctive, providing hierarchical inductive biases making it possible to 
discover more regularity in the data. 


17.4.4 The hypothesis space of DNNs 


Zhang et al. [Zha+17] showed that CNNs can fit CIFAR-10 images with random labels with zero 
training error, but can still generalize well on the noise-free test set. It has been claimed that this 
result contradicts a classical understanding of generalization, because it shows that neural networks 
are capable of significantly overfitting the data, but can still generalize well on structured inputs. 

Figure 17.11: Left: Effective dimensionality as a function of model width and depth for a CNN on CIFAR-100. 
Center: Test loss as a function of model width and depth. Right: Train loss as a function of model width and 
depth. Yellow level curves represent equal parameter counts (1e5, 2e5, 4e5, 1.6e6). The green curve separates 
models with near-zero training loss. Effective dimensionality serves as a good proxy for generalization for 
models with low train loss. We see wide but shallow models overfit, providing low train loss, but high test 
loss and high effective dimensionality. For models with the same train loss, lower effective dimensionality 
can be viewed as a better compression of the data at the same fidelity. Thus depth provides a mechanism for 
compression, which leads to better generalization. From Figure 2 of [MBW20]. Used with kind permission of 
Andrew Wilson. 

We can resolve this paradox by taking a Bayesian perspective. In particular, we know that modern 
CNNs are very flexible, so they can fit almost any pattern (since they are in fact universal approximators). 
However, their architecture encodes a prior over what kinds of patterns they expect to see in the data 
(see Section 17.2.5). Image datasets with random labels can be represented by this function class, but 
such solutions receive very low marginal likelihood, since they strongly violate the prior assumptions 
[WI20]. By contrast, image datasets where the output labels are consistent with patterns in the 
input get much higher marginal likelihood. 

This phenomenon is not unique to DNNs. For example, it also occurs with Gaussian processes 
(Chapter 18). Such models are also universal approximators, but they allocate most of their probability 
mass to a small range of solutions (depending on the chosen kernel). They can also fit image datasets 
with random labels, but such data receives a low marginal likelihood [WI20]. 
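The following sketch illustrates the same effect with GP regression rather than image classification, a simplification on our part. An RBF-kernel GP can fit both a smooth dataset and the same target values shuffled relative to their inputs. However, it assigns a much higher log marginal likelihood to the smooth one. The kernel hyperparameters are fixed by hand rather than learned.

```python
import numpy as np

def rbf_kernel(x1, x2, lengthscale=1.0, variance=1.0):
    sq = (x1[:, None] - x2[None, :]) ** 2
    return variance * np.exp(-0.5 * sq / lengthscale**2)

def gp_log_marginal_likelihood(x, y, noise_var=0.01, **kern):
    """log N(y | 0, K + noise_var I), computed via a Cholesky factorization."""
    n = x.size
    K = rbf_kernel(x, x, **kern) + noise_var * np.eye(n)
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))  # K^{-1} y
    return float(-0.5 * y @ alpha - np.sum(np.log(np.diag(L)))
                 - 0.5 * n * np.log(2 * np.pi))

rng = np.random.default_rng(0)
x = np.sort(rng.uniform(-3, 3, size=50))
y_structured = np.sin(x) + 0.1 * rng.normal(size=x.size)
y_random = rng.permutation(y_structured)  # same values, pairing with inputs scrambled

lml_s = gp_log_marginal_likelihood(x, y_structured)
lml_r = gp_log_marginal_likelihood(x, y_random)
print(f"structured: {lml_s:.1f}, scrambled: {lml_r:.1f}")
```

The scrambled targets are still representable, but the kernel prior concentrates its mass on smooth functions, so the evidence strongly favors the structured dataset.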

In general, we can distinguish the support of a model, i.e., the set of functions it can represent, 
from the distribution over that support, i.e., the inductive bias which leads it to prefer some functions 
over others. We would like to use models where the support is large, so we can capture the complexity 
of real-world data, but also where the inductive bias places probability mass on the kinds of functions 
we expect to see. If we succeed at this, the posterior will quickly converge on the true function after 
seeing a small amount of data. This idea is sketched in Figure 17.12. 


17.4.5 PAC-Bayes 


PAC-Bayes [McA99; LC02; Gue19; Alq21; GSZ21] provides a promising mechanism to derive 
non-vacuous generalization bounds for large stochastic networks [Ney+17; NBS18; DR17], with 
parameters sampled from a probability distribution. In particular, the difference between the train 


Draft of “Probabilistic Machine Learning: Advanced Topics”. August 15, 2022 




[Figure 17.12, panels (a) to (d): “Prior Hypothesis Space” over datasets ranging from “Corrupted CIFAR-10” to “Structured Image Datasets”, with the “Posterior” and the “True Model” marked.] 


Figure 17.12: Illustration of the behavior of different kinds of model families and the prior distribution they 
induce over datasets. (a) The purple model is a simple linear model that has small support, and can only 
represent a few kinds of datasets. The pink model is an unstructured MLP: this has support over a large 
range of datasets with a fairly uninformative (broad) prior. Finally the green model is a CNN; this has 
support over a large range of datasets but the prior is more concentrated on certain kinds of datasets that have 
compositional structure. (b) The posterior for the green model (CNN) rapidly collapses to the true model, 
since it is consistent with the data. (c) The posterior for the purple model (linear) also rapidly collapses, but 
to a solution which cannot represent the true model. (d) The posterior for the pink model (MLP) collapses 
very slowly (as a function of dataset size). From Figure 2 of [WI20]. Used with kind permission of Andrew 
Wilson. 


error and the generalization error can be bounded by 

$$\sqrt{\frac{D_{\mathbb{KL}}(Q \,\|\, P) + c}{2N - 1}} \tag{17.32}$$

where $c$ is a constant and $N$ is the number of training points. $P$ is the prior distribution over the 
parameters and $Q$ is an arbitrary distribution, which can be chosen to optimize the bound. 

The perspective in this chapter is largely complementary, and in some ways orthogonal, to the 
PAC-Bayes literature. Our focus has been on Bayesian marginalization, particularly multi-modal 
marginalization, and a prescriptive approach to model construction. In contrast, PAC-Bayes bounds 
are about bounding the empirical risk of a single sample, rather than marginalization, and are not 
currently prescriptive: what we would do to improve the bounds, such as reducing the number 
of model parameters, or using highly compact priors, does not typically improve generalization. 
Moreover, while we have seen Bayesian model averaging over multimodal posteriors has a significant 
effect on generalization, it has a minimal logarithmic effect on PAC-Bayes bounds. In general, 
because the bounds are loose, albeit non-vacuous in some cases, there is often room to make modeling 
choices that improve PAC-Bayes bounds without improving generalization, making it hard to derive 
a prescription for model construction from the bounds. 
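A small experiment shows why multimodal averaging has only a logarithmic effect on the bound. The sketch takes a standard normal prior $P$. It compares $D_{\mathbb{KL}}(Q \| P)$, the quantity in Equation (17.32), for a single narrow Gaussian mode with the same divergence for an equal-weight mixture of $K$ well-separated modes. All sizes and distributions are our own illustrative choices, and the mixture KL is estimated by Monte Carlo. The mixture's KL comes out smaller than the average single-mode KL by roughly $\log K$, which is tiny compared with the KL itself.

```python
import numpy as np

def log_iso_gauss(x, mu, s):
    """log N(x | mu, s^2 I), evaluated for each row of x."""
    d = x.shape[-1]
    return (-0.5 * np.sum((x - mu) ** 2, axis=-1) / s**2
            - d * np.log(s) - 0.5 * d * np.log(2 * np.pi))

def kl_iso_gauss_vs_std_normal(mu, s):
    """Closed-form KL(N(mu, s^2 I) || N(0, I))."""
    d = mu.size
    return 0.5 * (d * s**2 + mu @ mu - d - 2 * d * np.log(s))

rng = np.random.default_rng(0)
D, K, s, n_per = 50, 10, 0.1, 2000
mus = rng.normal(size=(K, D))  # K well-separated modes

# Q is an equal-weight mixture of N(mu_k, s^2 I); P = N(0, I) is the prior.
samples = np.concatenate([m + s * rng.normal(size=(n_per, D)) for m in mus])
log_comp = np.stack([log_iso_gauss(samples, m, s) for m in mus], axis=1)
mx = log_comp.max(axis=1)
log_q = mx + np.log(np.mean(np.exp(log_comp - mx[:, None]), axis=1))  # log-mean-exp
log_p = log_iso_gauss(samples, np.zeros(D), 1.0)
kl_mix = np.mean(log_q - log_p)  # Monte Carlo estimate of KL(Q || P)

kl_single = np.mean([kl_iso_gauss_vs_std_normal(m, s) for m in mus])
print(f"average KL of a single mode: {kl_single:.2f}")
print(f"KL of the K-mode mixture:    {kl_mix:.2f}")
print(f"reduction: {kl_single - kl_mix:.2f} vs log K = {np.log(K):.2f}")
```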


17.4.6 Out-of-Distribution generalization for BNNs 


Bayesian methods are often assumed to be more robust in the context of distribution shift (discussed 
in Chapter 19), because they capture more uncertainty than methods based on point estimation. 
However, there are some subtleties, some of which we discuss below. 
Figure 17.13: Bayesian neural networks under covariate shift. (a): Performance of a ResNet-20 
on the pixelate corruption in CIFAR-10-C. For the highest degree of corruption, a Bayesian model average 
underperforms a MAP solution by 25% (44% against 69%) accuracy. See Izmailov et al. [Izm+21b] for details. 
(b): Visualization of the weights in the first layer of a Bayesian fully-connected network on MNIST sampled 
via HMC. (c): The corresponding MAP weights. We visualize the weights connecting the input pixels to a 
neuron in the hidden layer as a 28 x 28 image, where each weight is shown in the location of the input pixel it 
interacts with. This is Figure 1 of Izmailov et al. [Izm+21a]. 


17.4.6.1 BMA can give poor results with default priors 


Many approximate inference methods, especially deep ensembles, are significantly less overconfident 
(better calibrated) in the presence of some kinds of covariate shifts [Ova+19]. However, in 
[Izm+21b], it was noted that HMC, which arguably offers the most accurate approximation to the 
posterior, often works poorly under distribution shift. 

Rather than an idiosyncrasy of HMC, Izmailov et al. [Izm+21a] show this lack of robustness 
is a foundational issue of Bayesian model averaging under covariate shift, caused by degeneracies 
in the training data, and a poor choice of prior. As an illustrative special case, MNIST digits all 
have black corner pixels. Weights in the first layer of a neural network connected to these pixels 
are multiplied by zero, and thus can take any value without affecting the outputs of the network. 
Classical MAP training or deep ensembles of MAP solutions with a Gaussian prior will therefore 
drive these parameters to zero, since they don’t help with the data fit, and the resulting network 
will be robust to corruptions on these pixels. On the other hand, the posterior for these parameters 
will be the same as the prior, and so a Bayesian model average will multiply corruptions by random 
numbers sampled from the prior, leading to degraded predictive performance. 
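The mechanism can be reproduced with exact Bayesian linear regression, used here as a stand-in for the first layer of a network (our simplification). One input is zero in every training example. As a result, the posterior over its weight equals the prior, while the MAP weight is zero. When that input is corrupted at test time, the MAP prediction is unaffected, but every posterior sample adds the corruption multiplied by a random prior draw. In this linear toy, the posterior-mean prediction still coincides with MAP. The argument above concerns nonlinear networks, where averaging such corrupted samples does not simply cancel out. All sizes and constants are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D = 200, 5
noise_sd, prior_sd = 0.1, 1.0
w_true = np.array([1.0, -2.0, 0.5, 1.5, 0.0])

X = rng.normal(size=(N, D))
X[:, -1] = 0.0  # a "corner pixel" that is zero in every training example
y = X @ w_true + noise_sd * rng.normal(size=N)

# Exact posterior N(mean, cov) for Bayesian linear regression with prior
# N(0, prior_sd^2 I); the posterior mean is also the MAP (ridge) estimate.
prec = X.T @ X / noise_sd**2 + np.eye(D) / prior_sd**2
cov = np.linalg.inv(prec)
mean = cov @ (X.T @ y) / noise_sd**2

print("MAP weight on the dead input:   ", mean[-1])              # ~0
print("posterior sd of that weight:    ", np.sqrt(cov[-1, -1]))  # = prior_sd

# Covariate shift: corrupt only the dead input at test time.
X_test = rng.normal(size=(500, D))
X_test[:, -1] = 3.0
y_test = X_test @ w_true  # unaffected, since the true weight is 0

map_mse = np.mean((X_test @ mean - y_test) ** 2)
w_samples = rng.multivariate_normal(mean, cov, size=200)
sample_mse = np.mean((X_test @ w_samples.T - y_test[:, None]) ** 2, axis=0)
print(f"MAP test MSE: {map_mse:.4f}, mean posterior-sample MSE: {sample_mse.mean():.2f}")
```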
Figure 17.13(b, c) visualizes this example, showing the first-layer weights of a fully-connected 
network for the MAP solution and a BNN posterior sample, on MNIST. The MAP weights corresponding 
to zero intensity pixels near the boundary are near zero, while the BNN weights look noisy, 
sampled from a Gaussian prior. 


Izmailov et al. [Izm+21a] prove that this issue is a special case of a much more general problem, 
which arises whenever there are linear dependencies in the input features of the training data, both for fully-connected 
and convolutional networks. In this case, the data live on a hyperplane. If a covariate or 
domain shift moves orthogonal to this hyperplane, the posterior will be the same as the prior in 
the direction of the shift. The posterior model average will thus be highly vulnerable to shifts that 
do not particularly affect the underlying semantic structure of the problem (such as corruptions), 
whereas the MAP solution will be entirely robust to such shifts. 


Figure 17.14: Predictions made by various (B)NNs when presented with the training data shown in blue and 
red. The green blob is an example of some OOD inputs. Methods are: (a) Standard SGD; (b) Deep Ensemble 
of 10 models with different random initializations; (c) MC Dropout with 50 samples; (d) Bootstrap training, 
where each of the 10 models is initialized identically but given different versions of the data, obtained by 
resampling with replacement; (e) MCMC using NUTS algorithm with 3000 warmup steps and 3000 samples; 
(f) Variational inference; (g) Gaussian process classifier using RBF kernel; (h) SNGP. The model is an MLP 
with 8,16,16,8 units in the hidden layers and ReLU activation. The output layer has 1 neuron with sigmoid 
activation. Generated by makemoons_comparison.ipynb 

By introducing a prior over parameters which is aligned with the principal components of the 
training inputs, we can substantially improve the generalization accuracy of Bayesian neural networks 
in out-of-distribution settings. Izmailov et al. [Izm+21a] propose the following EmpCov prior: 
$p(\mathbf{w}^1) = \mathcal{N}(\mathbf{0}, \alpha\boldsymbol{\Sigma} + \epsilon\mathbf{I})$, where $\mathbf{w}^1$ are the first layer weights, $\boldsymbol{\Sigma} = \frac{1}{N}\sum_{i=1}^{N} \mathbf{x}_i\mathbf{x}_i^\top$ is the empirical 
covariance of the training input features $\mathbf{x}_i$, $\alpha > 0$ determines the scale of the prior, and $\epsilon$ is a small 
positive constant to ensure the covariance matrix is positive definite. With this improved prior they 
are able to obtain a method that is much more robust to distribution shift. 
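A minimal sketch of the EmpCov prior covariance follows. It assumes that the incoming weight vector of each first-layer unit is drawn independently from $\mathcal{N}(\mathbf{0}, \alpha\boldsymbol{\Sigma} + \epsilon\mathbf{I})$. The function name and the values of $\alpha$ and $\epsilon$ are our own choices. Weights attached to an input that never varies in training get prior standard deviation of only about $\sqrt{\epsilon}$, so a corruption of that input is barely amplified.

```python
import numpy as np

def empcov_prior_cov(X, alpha=1.0, eps=1e-3):
    """Covariance alpha * Sigma + eps * I of the EmpCov prior, with
    Sigma = (1/N) sum_i x_i x_i^T computed from the training inputs X (N x D)."""
    N, D = X.shape
    Sigma = X.T @ X / N
    return alpha * Sigma + eps * np.eye(D)

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 4))
X[:, -1] = 0.0  # an input that is always zero in training
C = empcov_prior_cov(X, alpha=1.0, eps=1e-3)

# Incoming weights of 5000 hypothetical first-layer units, drawn i.i.d. from the prior
W1 = rng.multivariate_normal(np.zeros(4), C, size=5000)
stds = W1.std(axis=0)
print(stds)  # last entry is about sqrt(eps) ~ 0.03; the others are about 1
```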


17.4.6.2 BNNs can be overconfident on OOD inputs 


An important problem in practice is how a predictive model will behave when it is given an input that 
is “out of distribution” or OOD. Ideally we would like the model to express that it is not confident 
in its prediction, so that the system can abstain from predicting (see Section 19.3.3). Using “exact” 
inference methods, such as MCMC, for BNNs can give this behavior in some cases. For example, 
in Section 19.3.3.1 we showed that an MLP which was fit to MNIST using SGLD would be less 
overconfident than a point estimate (computed using SGD) when presented with inputs from fashion 
MNIST. However, this behavior does not always occur reliably. 

To illustrate the problem, consider the 2d nonlinear binary classification dataset shown in Figure 17.14. 
In addition to the two training classes, we have highlighted (in green) a set of OOD inputs 
that are far from the support of the training set. Intuitively we would expect the model to predict 
a probability of 0.5 (corresponding to “don’t know”) for such inputs. However, we see that the only 
methods that do so are the Gaussian process (GP) classifier (see Section 18.4) and the SNGP model 
(Section 17.3.6), which contains a GP layer on top of the feature extractor. 

The lesson we learn from this simple example is that “being Bayesian” only helps if we are using a 
good hypothesis class. If we only consider a single MLP classifier, with standard Gaussian priors on 
the weights, it is extremely unlikely that we will learn the kind of compact decision boundary shown 
in Figure 17.14g, because that function has negligible support under our prior (cf. Section 17.4.4). 
Instead we should embrace the power of Bayes to avoid overfitting and use as complex a model class 
as we can afford. 


17.4.7 Model Selection for BNNs 


Historically, the marginal likelihood (aka Bayesian evidence) has been used for model selection 
problems, such as choosing neural architectures or hyperparameter values [Mac92a]. Recent methods 
based on the Laplace approximation, such as [Imm+21; Dax+21], have made this scalable to large 
BNNs. However, [Lot+22] argue that it is much better to use the conditional marginal likelihood, 
which we discuss in Section 3.9.4. 


17.5 Online inference 


In Section 17.3, we have focused on batch or offline inference. However, an important application 
of Bayesian inference is in sequential settings, where the data arrives in a continuous stream, and 
the model has to “keep up”. This is called sequential Bayesian inference, and is one approach to 
online learning (see Section 19.7.5). In this section, we discuss some algorithmic approaches to 
this problem in the context of DNNs. These methods are widely used for continual learning, which 
we discuss in Section 19.7. 


17.5.1 Sequential Laplace for DNNs 


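In [RBB18b], they extended the Laplace method of Section 17.3.2 to the sequential setting. Specifically, 
let $p(\boldsymbol{\theta}|\mathcal{D}_{1:t-1}) \approx \mathcal{N}(\boldsymbol{\theta}|\boldsymbol{\mu}_{t-1}, \boldsymbol{\Lambda}_{t-1}^{-1})$ be the approximate posterior from the previous step; we assume 
the precision matrix is Kronecker factored. We now compute the new mean by solving the MAP 
problem 

$$\boldsymbol{\mu}_t = \arg\max_{\boldsymbol{\theta}} \log p(\mathcal{D}_t|\boldsymbol{\theta}) + \log p(\boldsymbol{\theta}|\mathcal{D}_{1:t-1}) \tag{17.33}$$
$$= \arg\max_{\boldsymbol{\theta}} \log p(\mathcal{D}_t|\boldsymbol{\theta}) - \frac{1}{2}(\boldsymbol{\theta} - \boldsymbol{\mu}_{t-1})^\top \boldsymbol{\Lambda}_{t-1} (\boldsymbol{\theta} - \boldsymbol{\mu}_{t-1}) \tag{17.34}$$

Once we have computed $\boldsymbol{\mu}_t$, we compute the approximate Hessian $\mathbf{H}(\boldsymbol{\mu}_t)$ of the negative log likelihood of $\mathcal{D}_t$ at this point, and get the new 
posterior precision 

$$\boldsymbol{\Lambda}_t = \lambda \mathbf{H}(\boldsymbol{\mu}_t) + \boldsymbol{\Lambda}_{t-1} \tag{17.35}$$

where $\lambda > 0$ is a weighting factor that trades off how much the model pays attention to the new data 
vs old data. 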
Now suppose we use a diagonal approximation to the posterior precision matrix. From Equation 
(17.34), we see that this amounts to adding a quadratic penalty to each new MAP estimate, to 
encourage it to remain close to the parameters from previous tasks. This approach is called elastic 
weight consolidation (EWC) [Kir+17]. 
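Here is a minimal sketch of Equations (17.34) and (17.35). It assumes a linear-Gaussian model, so the MAP step has a closed form and the Laplace approximation is exact. A real DNN would instead need an optimizer and a Kronecker-factored or diagonal Hessian approximation. With $\lambda = 1$, processing three tasks in sequence recovers the batch posterior. The `diagonal` flag, a hypothetical name of ours, keeps only the diagonal of the previous precision in the penalty, giving an EWC-style variant.

```python
import numpy as np

def laplace_step(mu_prev, prec_prev, X, y, noise_var=1.0, lam=1.0, diagonal=False):
    """One sequential-Laplace update (Eqs. 17.34-17.35) for a linear-Gaussian model.

    For this model the MAP problem has a closed form and the Hessian of the
    negative log likelihood is X^T X / noise_var. diagonal=True keeps only the
    diagonal of the previous precision in the quadratic penalty (EWC-style)."""
    penalty = np.diag(np.diag(prec_prev)) if diagonal else prec_prev
    H = X.T @ X / noise_var
    mu = np.linalg.solve(H + penalty, X.T @ y / noise_var + penalty @ mu_prev)
    prec = lam * H + penalty
    return mu, prec

rng = np.random.default_rng(0)
D, noise_var = 4, 0.25
w_true = rng.normal(size=D)
tasks = []
for _ in range(3):
    X = rng.normal(size=(50, D))
    tasks.append((X, X @ w_true + np.sqrt(noise_var) * rng.normal(size=50)))

mu, prec = np.zeros(D), np.eye(D)  # prior N(0, I)
for X, y in tasks:
    mu, prec = laplace_step(mu, prec, X, y, noise_var=noise_var)

# Batch posterior computed from all tasks at once, for comparison
X_all = np.vstack([Xt for Xt, _ in tasks])
y_all = np.concatenate([yt for _, yt in tasks])
prec_batch = np.eye(D) + X_all.T @ X_all / noise_var
mu_batch = np.linalg.solve(prec_batch, X_all.T @ y_all / noise_var)
print(np.allclose(mu, mu_batch), np.allclose(prec, prec_batch))  # True True
```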


17.5.2 Extended Kalman Filtering for DNNs 


In Section 29.7.2, we showed how Kalman filtering can be used to incrementally compute the 
exact posterior for the weights of a linear regression model with known variance, i.e., we compute 
$p(\boldsymbol{\theta}|\mathcal{D}_{1:t}, \sigma^2)$, where $\mathcal{D}_{1:t} = \{(\mathbf{u}_i, y_i) : i = 1:t\}$ is the data seen so far, and 

$$p(y_t|\mathbf{u}_t, \boldsymbol{\theta}, \sigma^2) = \mathcal{N}(y_t|\boldsymbol{\theta}^\top\mathbf{u}_t, \sigma^2) \tag{17.36}$$

is the linear regression likelihood. The application of KF to this model is known as recursive least 
squares. 
Now consider the case of nonlinear regression: 

$$p(y_t|\mathbf{u}_t, \boldsymbol{\theta}, \sigma^2) = \mathcal{N}(y_t|f(\boldsymbol{\theta}, \mathbf{u}_t), \sigma^2) \tag{17.37}$$

where $f(\boldsymbol{\theta}, \mathbf{u}_t)$ is some nonlinear function, such as an MLP. We can use the extended Kalman filter 
(Section 8.5.2) to approximately compute $p(\boldsymbol{\theta}_t|\mathcal{D}_{1:t}, \sigma^2)$, where $\boldsymbol{\theta}_t$ is the hidden state (see e.g., [SW89; 
PF03]). To see this, note that we can set the dynamics model to the identity function, $\mathbf{f}(\boldsymbol{\theta}_t) = \boldsymbol{\theta}_t$, so 
the parameters are propagated through unchanged, and the observation model to the input-dependent 
function $\mathbf{h}_t(\boldsymbol{\theta}_t) = f(\boldsymbol{\theta}_t, \mathbf{u}_t)$. We set the observation noise to $\mathbf{R}_t = \sigma^2$, and the dynamics noise to 
$\mathbf{Q}_t = q\mathbf{I}$, where $q$ is a small constant, to allow the parameters to slowly drift according to artificial 
process noise. (In practice it can be useful to anneal $q$ from a large initial value to something near 
0.) 


17.5.2.1 Example 


We now give an example of this process in action. We sample a synthetic dataset from the true 
function 

$$h^*(u) = u - 10\cos(u)\sin(u) + u^3 \tag{17.38}$$

and add Gaussian noise with $\sigma = 3$. We then fit this with an MLP with one hidden layer with $H$ 
hidden units, so the model has the form 

$$f(\boldsymbol{\theta}, u) = \mathbf{W}_2 \tanh(\mathbf{W}_1 u + \mathbf{b}_1) + \mathbf{b}_2 \tag{17.39}$$

where $\mathbf{W}_1 \in \mathbb{R}^{H \times 1}$, $\mathbf{b}_1 \in \mathbb{R}^H$, $\mathbf{W}_2 \in \mathbb{R}^{1 \times H}$, $\mathbf{b}_2 \in \mathbb{R}^1$. We set $H = 6$, so there are $D = 19$ parameters 
in total. 

Given the data, we sequentially compute the posterior, starting from a vague Gaussian prior, 
$p(\boldsymbol{\theta}) = \mathcal{N}(\boldsymbol{\theta}|\mathbf{0}, \boldsymbol{\Sigma}_0)$, where $\boldsymbol{\Sigma}_0 = 100\mathbf{I}$. (In practice we cannot start from the prior mean, which is 
$\boldsymbol{\theta}_0 = \mathbf{0}$, since linearizing the model around this point results in a zero gradient, so we use an initial 
random sample for $\boldsymbol{\theta}_0$.) The results are shown in Figure 17.15. We can see that the model adapts 
to the data, without having to specify any learning rate. In addition, we see that the predictions 
become gradually more confident, as the posterior concentrates on the MLE. 
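Below is a compact NumPy sketch of this EKF, written for this explanation rather than taken from the book's notebook. The Jacobian of the MLP in Equation (17.39) is computed by hand. The dynamics are the identity with process noise $q\mathbf{I}$, and $\mathbf{R}_t = \sigma^2 = 9$. The input range $[-3, 3]$, the value of $q$, the number of observations, and the $\mathcal{N}(\mathbf{0}, \mathbf{I})$ draw used as the initial mean are our assumptions.

```python
import numpy as np

H_UNITS = 6
D = 3 * H_UNITS + 1  # W1, b1, W2 (H each) plus b2: 19 parameters

def mlp_and_jacobian(theta, u):
    """f(theta, u) = W2 tanh(W1 u + b1) + b2 for scalar u, and df/dtheta."""
    W1, b1 = theta[:H_UNITS], theta[H_UNITS:2 * H_UNITS]
    W2, b2 = theta[2 * H_UNITS:3 * H_UNITS], theta[-1]
    z = np.tanh(W1 * u + b1)
    dz = W2 * (1.0 - z**2)  # df / d(pre-activation)
    jac = np.concatenate([dz * u, dz, z, [1.0]])  # same order as theta
    return W2 @ z + b2, jac

def ekf_step(mu, Sigma, u, y, obs_var, q):
    # Predict: identity dynamics plus artificial process noise Q = q I
    Sigma = Sigma + q * np.eye(D)
    # Update: linearize the observation model around the predicted mean
    y_hat, h_row = mlp_and_jacobian(mu, u)
    S = h_row @ Sigma @ h_row + obs_var  # scalar innovation variance
    K = Sigma @ h_row / S                # Kalman gain, shape (D,)
    mu = mu + K * (y - y_hat)
    Sigma = Sigma - np.outer(K, K) * S
    return mu, 0.5 * (Sigma + Sigma.T)   # keep the covariance symmetric

def true_fn(u):
    return u - 10.0 * np.cos(u) * np.sin(u) + u**3

rng = np.random.default_rng(0)
T = 200
us = rng.uniform(-3.0, 3.0, size=T)
ys = true_fn(us) + 3.0 * rng.normal(size=T)  # sigma = 3

Sigma0 = 100.0 * np.eye(D)  # vague prior N(0, 100 I)
mu = rng.normal(size=D)     # random initial mean instead of the prior mean 0
Sigma = Sigma0.copy()
for u, y in zip(us, ys):
    mu, Sigma = ekf_step(mu, Sigma, u, y, obs_var=9.0, q=1e-4)

preds = np.array([mlp_and_jacobian(mu, u)[0] for u in us])
print("train RMSE:", np.sqrt(np.mean((preds - ys) ** 2)))
print("posterior trace:", np.trace(Sigma), "vs prior:", np.trace(Sigma0))
```

No learning rate appears anywhere. The Kalman gain `K` plays that role, and it shrinks automatically as `Sigma` contracts.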
Figure 17.15: Sequential Bayesian inference for the parameters of an MLP using the extended Kalman 
filter. We show results after seeing the first 10, 20, 80 and 200 observations. (For a video of this, see 
https://bit.ly/3wXnWal.) Generated by ekf_mlp.ipynb. 


17.5.2.2 Setting the variance terms 


In the above example, we set the variance terms by hand. In general we need to estimate the noise 
variance $\sigma^2$, which determines $\mathbf{R}_t$ and hence the learning rate, as well as the strength of the prior $\boldsymbol{\Sigma}_0$, 
which controls the amount of regularization. Some methods for doing this are discussed in [FNG00]. 


17.5.2.3 Reducing the computational complexity 


The naive EKF method described above takes $O(N^3)$ time, which is prohibitive for large neural 
networks. A simple approximation, known as the decoupled EKF, was proposed in [PF91; SPD92] 
(see [PF03] for a review). This partitions the weights into $G$ groups or blocks, and estimates the 
relevant matrices for each group $g$ independently. If $G = 1$, this reduces to the standard global EKF. 
If we put each weight into its own group, we get a fully diagonal approximation. In practice this 
does not work any better than SGD, since it ignores correlations between the parameters. A useful 
compromise is to put all the weights corresponding to each neuron into their own group; this is called 
“node decoupled EKF”. 
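A minimal sketch of the decoupled EKF measurement update for a scalar observation follows. It keeps one covariance block per group and drops cross-group covariances. The groups share a single innovation variance, built from each group's contribution. The function and variable names are ours. The usage below checks that a single group holding every weight ($G = 1$) reproduces the global EKF update.

```python
import numpy as np

def decoupled_ekf_update(mu, blocks, groups, h_row, y, y_hat, obs_var):
    """Measurement update of a decoupled EKF for a scalar observation (sketch).

    groups[g] holds the parameter indices of group g and blocks[g] their
    covariance; covariances between different groups are dropped.
    h_row is the Jacobian of the prediction y_hat w.r.t. all parameters."""
    # Innovation variance shared by all groups: sum of each group's contribution
    S = obs_var + sum(h_row[idx] @ P @ h_row[idx] for idx, P in zip(groups, blocks))
    mu = mu.copy()
    new_blocks = []
    for idx, P in zip(groups, blocks):
        K = P @ h_row[idx] / S  # gain for this group only
        mu[idx] += K * (y - y_hat)
        new_blocks.append(P - np.outer(K, K) * S)
    return mu, new_blocks

rng = np.random.default_rng(0)
D = 6
A = rng.normal(size=(D, D))
Sigma = A @ A.T + np.eye(D)
h_row, mu0 = rng.normal(size=D), np.zeros(D)
y, y_hat, obs_var = 1.0, 0.2, 0.5

# G = 1 (one group holding every weight) reproduces the global EKF update.
mu_g1, (P_g1,) = decoupled_ekf_update(mu0, [Sigma], [np.arange(D)], h_row, y, y_hat, obs_var)
S = h_row @ Sigma @ h_row + obs_var
K = Sigma @ h_row / S
print(np.allclose(mu_g1, mu0 + K * (y - y_hat)), np.allclose(P_g1, Sigma - np.outer(K, K) * S))

# Two groups (e.g., the incoming weights of two neurons): block-diagonal covariance.
groups = [np.arange(0, 3), np.arange(3, 6)]
blocks = [Sigma[np.ix_(g, g)] for g in groups]
mu_g2, blocks_g2 = decoupled_ekf_update(mu0, blocks, groups, h_row, y, y_hat, obs_var)
print(mu_g2)
```

Storing $G$ small blocks instead of one $D \times D$ matrix is where the savings come from. The price is that the update ignores correlations between groups.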


Another approach to increasing computational efficiency is to leverage the fact that the effective 
dimensionality of a DNN is often quite low (see Section 17.4.3). Indeed we can approximate the 
model parameters by using a low dimensional vector of coefficients that specify the point in a linear 
manifold corresponding to weight space; the basis set defining this linear manifold can either be 
chosen randomly [Li+18b; GARD18; Lar+22], or can be estimated using PCA applied to the SGD 
iterates [Izm+19]. We can exploit this observation to perform EKF in this low-dimensional subspace, 
which significantly speeds up inference, as discussed in [DMKM22]. 


17.5.3 Assumed Density Filtering for DNNs 


In Section 8.9.3, we discussed how to use assumed density filtering (ADF) to perform online (binary) 
logistic regression. In this section, we generalize this to nonlinear predictive models, such as DNNs. 
The key is to perform Gaussian moment matching of the hidden activations at each layer of the model. 
This provides an alternative to the EKF approach in Section 17.5.2, which is based on linearization 
of the network. 

We will assume the following likelihood: 

$$p(y_t|\mathbf{u}_t, \mathbf{w}_t) = \text{Expfam}(y_t|\ell^{-1}(f(\mathbf{u}_t; \mathbf{w}_t))) \tag{17.40}$$

where $f(\mathbf{u}; \mathbf{w})$ is the DNN, $\ell^{-1}$ is the inverse link function, and Expfam() is some exponential family 
distribution. For example, if $f$ is linear and we are solving a binary classification problem, we can 
write 

$$p(y_t|\mathbf{u}_t, \mathbf{w}_t) = \text{Ber}(y_t|\sigma(\mathbf{u}_t^\top\mathbf{w}_t)) \tag{17.41}$$

We discussed using ADF to fit this model in Section 8.9.3. 

In [HLA15b], they propose Probabilistic backpropagation (PBP), which is an instance of 
ADF applied to MLPs. The basic idea is to approximate the posterior over the weights in each layer 
using a fully factorized distribution 

$$p(\mathbf{w}_t|\mathcal{D}_{1:t}) \approx p_t(\mathbf{w}) = \prod_{l=1}^{L} \prod_{i=1}^{D_l} \prod_{j=1}^{D_{l-1}+1} \mathcal{N}(w_{ijl}|\mu_{ijl}^t, \tau_{ijl}^t) \tag{17.42}$$

where $L$ is the number of layers, and $D_l$ is the number of neurons in layer $l$. (The expectation 
backpropagation algorithm of [SHM14] is a special case of this, where the variances are fixed to 
$\tau = 1$.) 

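Suppose the parameters are static, so $\mathbf{w}_t = \mathbf{w}_{t-1}$. Then the new posterior, after conditioning on 
the $t$'th observation, is given by 

$$\hat{p}_t(\mathbf{w}) = \frac{1}{Z_t} p(y_t|\mathbf{u}_t, \mathbf{w}) \mathcal{N}(\mathbf{w}|\boldsymbol{\mu}^{t-1}, \boldsymbol{\Sigma}^{t-1}) \tag{17.43}$$

where $\boldsymbol{\Sigma}^{t-1} = \text{diag}(\boldsymbol{\tau}^{t-1})$. We then project $\hat{p}_t(\mathbf{w})$ into the space of factored Gaussians to compute 
the new (approximate) posterior, $p_t(\mathbf{w})$. This can be done by computing the following means and 
variances [Min01a]: 

$$\mu_{ijl}^t = \mu_{ijl}^{t-1} + \tau_{ijl}^{t-1} \frac{\partial \ln Z_t}{\partial \mu_{ijl}^{t-1}} \tag{17.44}$$
$$\tau_{ijl}^t = \tau_{ijl}^{t-1} - (\tau_{ijl}^{t-1})^2 \left[ \left(\frac{\partial \ln Z_t}{\partial \mu_{ijl}^{t-1}}\right)^2 - 2\frac{\partial \ln Z_t}{\partial \tau_{ijl}^{t-1}} \right] \tag{17.45}$$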
In the forwards pass, we compute $Z_t$ by propagating the input $\mathbf{u}_t$ through the model. Since we have 
a Gaussian distribution over the weights, instead of a point estimate, this induces an (approximately) 
Gaussian distribution over the values of the hidden units. For certain kinds of activation functions 
(such as ReLU), the relevant integrals (to compute the means and variances) can be solved analytically, 
as in GP-neural networks (Section 18.7). The result is that we get a Gaussian distribution over the 
final layer of the form $\mathcal{N}(\boldsymbol{\eta}_t|\boldsymbol{\mu}, \boldsymbol{\Sigma})$, where $\boldsymbol{\eta}_t = f(\mathbf{u}_t; \mathbf{w}_t)$ is the output of the neural network before 
the GLM link function induced by $p_t(\mathbf{w}_t)$. Hence we can approximate the partition function using 

$$Z_t \approx \int p(y_t|\boldsymbol{\eta}_t) \mathcal{N}(\boldsymbol{\eta}_t|\boldsymbol{\mu}, \boldsymbol{\Sigma}) d\boldsymbol{\eta}_t \tag{17.46}$$


We now discuss how to compute this integral. In the case of probit classification, with $y \in \{-1, +1\}$, 
we have $p(y|\mathbf{x}, \mathbf{w}) = \Phi(y\eta)$, where $\Phi$ is the cdf of the standard normal. We can then use the following 
analytical result 

$$\int \Phi(y\eta) \mathcal{N}(\eta|\mu, \sigma^2) d\eta = \Phi\left(\frac{y\mu}{\sqrt{1 + \sigma^2}}\right) \tag{17.47}$$

In the case of logistic classification, with $y \in \{0, 1\}$, we have $p(y|\mathbf{x}, \mathbf{w}) = \text{Ber}(y|\sigma(\eta))$; in this case, 
we can use the probit approximation from Section 15.3.5. For the multiclass case, where $\mathbf{y} \in \{0, 1\}^C$ 
(one-hot encoding), we have $p(\mathbf{y}|\mathbf{x}, \mathbf{w}) = \text{Cat}(\mathbf{y}|\text{softmax}(\boldsymbol{\eta}))$. A variational lower bound to $\log Z_t$ for 
this case is given in [GDFY16]. 

Once we have computed Z;, we can take gradients and update the Gaussian posterior moments, 
before moving to the next step. 
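The moment-matching step can be checked in the simplest possible case. The sketch uses a single Gaussian weight, a scalar input $x$, and a probit likelihood, so that $\eta = wx$. It computes $Z_t$ with Equation (17.47), differentiates $\ln Z_t$ analytically, and applies Equations (17.44) and (17.45). Brute-force numerical integration of the tilted distribution then confirms the result. This one-weight setting is our simplification; PBP applies the same update to every weight after propagating moments through the network.

```python
import math
import numpy as np

Phi = np.vectorize(lambda z: 0.5 * (1.0 + math.erf(z / math.sqrt(2.0))))  # standard normal cdf

def norm_pdf(z):
    return np.exp(-0.5 * z**2) / np.sqrt(2 * np.pi)

def adf_probit_update(m, v, x, y):
    """ADF update of N(w | m, v) for one weight after observing y in {-1, +1}
    with likelihood Phi(y * w * x), using Eqs. (17.44), (17.45) and (17.47)."""
    s = np.sqrt(1.0 + x**2 * v)  # eta = w x has mean m x and variance x^2 v
    a = y * x * m / s
    Z = float(Phi(a))            # Eq. (17.47)
    r = norm_pdf(a) / Z
    dlnZ_dm = r * y * x / s
    dlnZ_dv = -r * a * x**2 / (2 * s**2)
    m_new = m + v * dlnZ_dm                        # Eq. (17.44)
    v_new = v - v**2 * (dlnZ_dm**2 - 2 * dlnZ_dv)  # Eq. (17.45)
    return Z, m_new, v_new

m, v, x, y = 0.3, 2.0, 1.5, -1.0
Z, m_new, v_new = adf_probit_update(m, v, x, y)

# Brute-force moments of the tilted distribution Phi(y w x) N(w | m, v) / Z
w = np.linspace(m - 12 * np.sqrt(v), m + 12 * np.sqrt(v), 20001)
dw = w[1] - w[0]
tilted = Phi(y * w * x) * norm_pdf((w - m) / np.sqrt(v)) / np.sqrt(v)
Z_grid = tilted.sum() * dw
mean_grid = (w * tilted).sum() * dw / Z_grid
var_grid = ((w - mean_grid) ** 2 * tilted).sum() * dw / Z_grid
print(Z, Z_grid)
print(m_new, mean_grid)
print(v_new, var_grid)
```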


17.5.4 Online variational inference for DNNs 




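A natural approach to online learning is to use variational inference, where the prior is the posterior 
from the previous step. This is known as streaming variational Bayes [Bro+13]. In more detail, 
at step $t$, we compute 

$$\boldsymbol{\psi}_t = \arg\min_{\boldsymbol{\psi}} \underbrace{\mathbb{E}_{q(\boldsymbol{\theta}|\boldsymbol{\psi})}[\ell_t(\boldsymbol{\theta})] + D_{\mathbb{KL}}(q(\boldsymbol{\theta}|\boldsymbol{\psi}) \,\|\, q(\boldsymbol{\theta}|\boldsymbol{\psi}_{t-1}))}_{-\mathcal{L}_t(\boldsymbol{\psi})} \tag{17.48}$$
$$= \arg\min_{\boldsymbol{\psi}} \mathbb{E}_{q(\boldsymbol{\theta}|\boldsymbol{\psi})}[\ell_t(\boldsymbol{\theta}) + \log q(\boldsymbol{\theta}|\boldsymbol{\psi}) - \log q(\boldsymbol{\theta}|\boldsymbol{\psi}_{t-1})] \tag{17.49}$$

where $\ell_t(\boldsymbol{\theta}) = -\log p(\mathcal{D}_t|\boldsymbol{\theta})$ is the negative log likelihood (or, more generally, some loss function) of 
the data batch at step $t$. 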
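A minimal sketch of the objective in Equation (17.48) follows. It assumes a linear-Gaussian likelihood for the batch and a mean-field Gaussian $q$ parameterized by $\boldsymbol{\psi} = (\mathbf{m}, \mathbf{s})$. The expected loss is estimated with reparameterized Monte Carlo samples, and the KL to the previous step's $q$ is computed in closed form. In practice one would minimize this with stochastic gradients. Here we only evaluate it and compare against the closed-form expectation that this toy model admits. All names are our own.

```python
import numpy as np

def kl_diag_gauss(m, s, m0, s0):
    """KL(N(m, diag(s^2)) || N(m0, diag(s0^2)))."""
    return 0.5 * np.sum(s**2 / s0**2 + (m - m0) ** 2 / s0**2 - 1.0 - 2.0 * np.log(s / s0))

def streaming_vb_objective(m, s, m_prev, s_prev, X, y, noise_var, rng, n_samples=1000):
    """Monte Carlo estimate of Eq. (17.48) for a linear-Gaussian likelihood and
    mean-field Gaussian q: E_q[l_t(theta)] + KL(q || q_prev)."""
    thetas = m + s * rng.normal(size=(n_samples, m.size))  # reparameterized samples
    resid = y[None, :] - thetas @ X.T                       # (n_samples, N)
    nll = (0.5 * np.sum(resid**2, axis=1) / noise_var
           + 0.5 * y.size * np.log(2 * np.pi * noise_var))
    return nll.mean() + kl_diag_gauss(m, s, m_prev, s_prev)

rng = np.random.default_rng(0)
N, D, noise_var = 20, 3, 0.25
X = rng.normal(size=(N, D))
w_true = rng.normal(size=D)
y = X @ w_true + np.sqrt(noise_var) * rng.normal(size=N)
m_prev, s_prev = np.zeros(D), np.ones(D)  # q from the previous step
m, s = w_true.copy(), np.full(D, 0.2)     # candidate q for this step

mc = streaming_vb_objective(m, s, m_prev, s_prev, X, y, noise_var, rng, n_samples=20000)
# For this model, E_q[l_t] is available in closed form.
exact_nll = ((0.5 / noise_var) * (np.sum((y - X @ m) ** 2) + np.sum((X**2) @ s**2))
             + 0.5 * N * np.log(2 * np.pi * noise_var))
exact = exact_nll + kl_diag_gauss(m, s, m_prev, s_prev)
print(mc, exact)
```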


When applied to DNNs, this approach is called variational continual learning or VCL [Ngu+18].
(We discuss continual learning in Section 19.7.) An efficient implementation of this, known as FOO-VB
(“Fixed-point Operator for Online Variational Bayes”), is given in [Zen+21].


One problem with the VCL objective in Equation (17.48) is that the KL term can cause the
model to become too sparse, which can prevent the model from adapting or learning new tasks.
This problem is called variational overpruning [TT17]. More precisely, the reason this happens
is as follows: some weights might not be needed to fit a given dataset, so their posterior will be
equal to the prior; but sampling from these high-variance weights will add noise to the likelihood. To
reduce this noise, the optimization method will prefer to set the bias term to a large negative value, so
the corresponding unit is “turned off”, and thus has no effect on the likelihood. Unfortunately, these
“dead units” become stuck, so there is not enough network capacity to learn the next task.

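In [LST21], they propose a solution to this, known as generalized variational continual
learning or GVCL. The first step is to downweight the KL term by a factor $\beta < 1$ to get

$$\mathcal{L}_t = \mathbb{E}_{q(\theta|\psi)}\left[\ell_t(\theta)\right] + \beta D_{\mathbb{KL}}\left(q(\theta|\psi) \,\|\, q(\theta|\psi_{t-1})\right) \tag{17.50}$$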


Interestingly, one can show that in the limit of $\beta \to 0$, this recovers several standard methods that
use a Laplace approximation based on the Hessian. In particular, if we use a diagonal variational
posterior, this reduces to the online EWC method of [Sch+18]; if we use a block-diagonal and
Kronecker-factored posterior, this reduces to the online structured Laplace method of [RBB18b]; and if we use
a low-rank posterior precision matrix, this reduces to the SOLA method of [Yin+20].

The second step is to temper the prior and posterior, which is useful when
the model is misspecified, as discussed in Section 17.3.11. In the case of Gaussians, raising the
distribution to the power $\lambda$ is equivalent to tempering with a temperature of $\tau = 1/\lambda$, which is the
same as scaling the covariance by $\lambda^{-1}$. Thus the GVCL objective becomes

$$\mathcal{L}_t = \mathbb{E}_{q(\theta|\psi)}\left[\ell_t(\theta)\right] + \beta D_{\mathbb{KL}}\left(q(\theta|\psi)^{\lambda} \,\|\, q(\theta|\psi_{t-1})^{\lambda}\right) \tag{17.51}$$


This can be optimized using SGD, assuming the posterior is reparameterizable (see Section 10.3.3). 
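For diagonal Gaussian posteriors, tempering by $\lambda$ (raising to the power $\lambda$ and renormalizing) keeps the mean and divides the variance by $\lambda$, so the KL term in Equation (17.51) has a closed form in which $\lambda$ only rescales the mean-difference term. The sketch below evaluates the GVCL objective with a reparameterized Monte Carlo estimate of the expected loss. It is a NumPy sketch that only evaluates the loss; in practice one would differentiate it with an autodiff library such as JAX. `nll_fn` is a hypothetical stand-in for $\ell_t$.

```python
import numpy as np


def tempered_gaussian_kl(mu_q, var_q, mu_p, var_p, lam):
    """KL( N(mu_q, var_q/lam) || N(mu_p, var_p/lam) ) for diagonal Gaussians.

    Both tempered distributions have their variances divided by lam, so the
    variance ratios are unchanged and only the mean term picks up a factor lam.
    """
    return 0.5 * float(np.sum(var_q / var_p + lam * (mu_q - mu_p) ** 2 / var_p
                              - 1.0 + np.log(var_p / var_q)))


def gvcl_loss(mu, log_sigma, mu_prev, var_prev, nll_fn, beta, lam,
              num_samples=8, seed=0):
    """Monte Carlo estimate of Eq. (17.51) with a diagonal Gaussian q(theta|psi)."""
    rng = np.random.default_rng(seed)
    sigma = np.exp(log_sigma)
    eps = rng.standard_normal((num_samples, mu.size))
    thetas = mu + sigma * eps                      # reparameterization trick
    expected_nll = np.mean([nll_fn(theta) for theta in thetas])
    kl = tempered_gaussian_kl(mu, sigma ** 2, mu_prev, var_prev, lam)
    return float(expected_nll + beta * kl)
```

Setting `beta=0` drops the regularizer entirely, and with `lam=1` the KL term reduces to the ordinary one in Equation (17.50).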


17.6 Hierarchical Bayesian neural networks 


In some problems, we have multiple related datasets, such as a set of medical images from different 
hospitals. Some aspects of the data (e.g., the shape of healthy vs diseased cells) are generally the same
across datasets, but other aspects may be unique or idiosyncratic (e.g., each hospital may use a
different colored dye for staining). To model this, we can use a hierarchical Bayesian model, in which
we allow the parameters for each dataset to be different (to capture random effects), while coming from
a common prior (to capture shared effects). This is the setup we considered in Section 15.5, where
we discussed hierarchical Bayesian GLMs. In this section, we extend this to nonlinear predictors based
on neural networks. (The setup is very similar to domain generalization, discussed in Section 19.6.2, 
except here we care about performance on all the domains, not just a held-out target domain.) 


17.6.1 Example: multi-moons classification 


In this section, we consider an example² where we want to solve multiple related nonlinear binary
classification problems coming from $J$ different environments or distributions. We assume that each
environment has its own unique decision boundary $p(y|x, w^j)$, so this is a form of concept shift (see
Section 19.2.3). However, we assume the overall shape of each boundary is similar to a common shared
boundary, denoted $p(y|x, w^0)$. We only have a small number $N_j$ of examples from each environment,
$\mathcal{D}^j = \{(x_n^j, y_n^j) : n = 1 : N_j\}$, but we can utilize their common structure to do better than fitting $J$
separate models.


2. This example is from https://twiecki.io/blog/2018/08/13/hierarchical_bayesian_neural_network/. For a
real-world example of a similar approach applied to a gesture recognition task, see [Jos+17].
Figure 17.16: (a) Two moons synthetic dataset. (b) Multi-task version, where we rotate the data to create
18 related tasks (groups). Each dataset has 50 training and 50 test points. Here we show the first 4 tasks.
Generated by bnn_hierarchical.ipynb.


To illustrate this, we create some synthetic 2d data for the $J = 18$ tasks. We start with the
two-moons dataset, illustrated in Figure 17.16a. Each task is obtained by rotating the 2d inputs by a
different amount, to create 18 related classification problems. Figure 17.16b shows the training data
for the first 4 tasks.
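A rough sketch of how such multi-task data could be generated is shown below. It assumes a hand-written two-moons generator (mirroring the usual construction, as in scikit-learn's `make_moons`), evenly spaced rotation angles over $[0, \pi)$, and a noise level of 0.1; the actual angles, noise and generator used in bnn_hierarchical.ipynb may differ, and all names here are hypothetical.

```python
import numpy as np


def make_two_moons(n, noise, rng):
    """Two interleaving half circles with labels 0 and 1, plus isotropic Gaussian noise."""
    n0 = n // 2
    n1 = n - n0
    t0 = rng.uniform(0.0, np.pi, n0)
    t1 = rng.uniform(0.0, np.pi, n1)
    upper = np.stack([np.cos(t0), np.sin(t0)], axis=1)
    lower = np.stack([1.0 - np.cos(t1), 0.5 - np.sin(t1)], axis=1)
    X = np.concatenate([upper, lower]) + noise * rng.standard_normal((n, 2))
    y = np.concatenate([np.zeros(n0, dtype=int), np.ones(n1, dtype=int)])
    return X, y


def rotation(angle):
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[c, -s], [s, c]])


def make_multi_moons(num_tasks=18, n_train=50, n_test=50, noise=0.1, seed=0):
    """Hypothetical construction: task j rotates the 2d inputs by angle_j."""
    rng = np.random.default_rng(seed)
    angles = np.linspace(0.0, np.pi, num_tasks, endpoint=False)
    tasks = []
    for angle in angles:
        X, y = make_two_moons(n_train + n_test, noise, rng)
        X = X @ rotation(angle).T            # rotate every input point
        perm = rng.permutation(len(y))       # random train/test split
        train, test = perm[:n_train], perm[n_train:]
        tasks.append(((X[train], y[train]), (X[test], y[test])))
    return tasks
```

Because each rotation is orthogonal, every task keeps the same "two moons" geometry while its decision boundary points in a different direction, which is exactly the shared-but-shifted structure the hierarchical model is meant to exploit.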

To handle the nonlinear decision boundary, we use a multilayer perceptron. Since the dataset is 
low-dimensional (2d input), we use a shallow model with just 2 hidden layers, each with 5 neurons. 
We could fit a separate MLP to each task, but since we have limited data per task ($N_j = 50$ examples
for training), this works poorly, as we show below. We could also pool all the data and fit a single
model, but this does even worse, since the datasets come from different underlying distributions, so
mixing the data together from different “concepts” confuses the model. Instead we adopt a hierarchical
Bayesian approach.

Our modeling assumptions are shown in Figure 17.17. In particular, we assume the weight from
unit $i$ to unit $k$ in layer $l$ for environment $j$, denoted $w^j_{i,k,l}$, comes from a common prior value $w^0_{i,k,l}$
with a random offset. We use the non-centered parameterization from Section 12.6.5 to write

$$w^j_{i,k,l} = w^0_{i,k,l} + \sigma^0_l \times \epsilon^j_{i,k,l} \tag{17.52}$$

where $\epsilon^j_{i,k,l} \sim \mathcal{N}(0, 1)$. By allowing a different $\sigma^0_l$ per layer $l$, we let the model control the degree
of shrinkage to the prior for each layer separately. (We could also make the $\sigma^0_l$ parameters
environment specific, which would allow for different amounts of distribution shift from the common
parent.) For the hyper-parameters, we put $\mathcal{N}(0, 1)$ priors on $w^0_{i,k,l}$, and $\mathcal{N}_+(1)$ priors on $\sigma^0_l$.
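The non-centered form of Equation (17.52) is easy to simulate from. The sketch below draws per-environment weights from the priors just described and runs one environment's network forward. The layer sizes (2-5-5-1), the tanh activation and the omission of biases are our simplifying assumptions; in the actual model, HMC samples $\epsilon$, $w^0$ and $\sigma^0$ jointly from the posterior rather than drawing them from the prior as done here.

```python
import numpy as np


def sample_hierarchical_weights(layer_shapes, num_envs, rng):
    """Sample per-environment weights via the non-centered form of Eq. (17.52).

    For each layer l: w0 ~ N(0, 1) elementwise, sigma0_l ~ N_+(1) (half-normal),
    eps^j ~ N(0, 1), and w^j = w0 + sigma0_l * eps^j.
    """
    params = []
    for shape in layer_shapes:
        w0 = rng.standard_normal(shape)
        sigma0 = abs(rng.standard_normal())          # half-normal draw
        eps = rng.standard_normal((num_envs,) + tuple(shape))
        w_env = w0[None] + sigma0 * eps              # shape (num_envs, *shape)
        params.append({"w0": w0, "sigma0": sigma0, "eps": eps, "w": w_env})
    return params


def mlp_logits(x, weights):
    """Bias-free MLP with tanh hidden units (a simplifying assumption)."""
    h = x
    for W in weights[:-1]:
        h = np.tanh(h @ W)
    return h @ weights[-1]


rng = np.random.default_rng(0)
params = sample_hierarchical_weights([(2, 5), (5, 5), (5, 1)], num_envs=18, rng=rng)
env_j = [p["w"][3] for p in params]                  # weights of environment j = 3
logits = mlp_logits(rng.standard_normal((50, 2)), env_j)
```

The sampler works with the independent standard normals $\epsilon$ rather than the correlated $w^j$ directly, which is the point of the non-centered parameterization: it avoids the funnel-shaped geometry that arises when $\sigma^0_l$ is small.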


We compute the posterior $p(\epsilon^{1:J}, w^0, \sigma^0_{1:L}|\mathcal{D})$ using HMC (Section 12.5). We then evaluate this
model using a fresh set of labeled samples from each environment. The average classification accuracy
on the train and test sets for the non-hierarchical model (one MLP per environment, fit separately)
is 86% and 83%, respectively. For the hierarchical model, this improves to 91% and 89%, respectively.

Figure 17.17: Illustration of a hierarchical Bayesian MLP with 2 hidden layers. There are $J$ different models,
each with $N_j$ observed samples, and a common set of global shared parent parameters denoted with the 0
superscript. Nodes which are shaded are observed. Nodes with double ringed circles are deterministic functions
of their parents.

Figure 17.18: Results of fitting separate MLPs on each dataset (panels: Dataset 1 to Dataset 4). Generated by bnn_hierarchical.ipynb.

To see why the hierarchical model works better, we will plot the posterior predictive distribution 
in 2d. Figure 17.18 shows the results for the non-hierarchical models; we see that the method fails to 
learn the common underlying Z-shaped decision boundary. By contrast, Figure 17.19 shows that the 
hierarchical method has correctly recovered the common pattern, while still allowing group variation. 
Figure 17.19: Results of fitting hierarchical MLP on all datasets jointly (panels: Dataset 1 to Dataset 4). Generated by bnn_hierarchical.ipynb.
18 Gaussian processes


This chapter is co-authored with Andrew Wilson. 


18.1 Introduction 


Deep neural networks are a family of flexible function approximators of the form $f(x; \theta)$, where the
dimensionality of $\theta$ (i.e., the number of parameters) is fixed, and independent of the size $N$ of the
training set. However, such parametric models can overfit when N is small, and can underfit when N 
is large, due to their fixed capacity. In order to create models whose capacity automatically adapts 
to the amount of data, we turn to nonparametric models. 

There are many approaches to building nonparametric models for classification and regression (see 
e.g., [Was06]). In this chapter, we consider a Bayesian approach in which we represent uncertainty 
about the input-output mapping f by defining a prior distribution over functions, and then updating 
it given data. In particular, we will use a Gaussian process to represent the prior p(f); we 
then use Bayes rule to derive the posterior p(f|D), which is another GP, as we explain below. 
More details on GPs can be found in the excellent book [RW06], as well as the interactive tutorial at
https://distill.pub/2019/visual-exploration-gaussian-processes. See also Chapter 31 for
other examples of Bayesian nonparametric models.


18.1.1 GPs: What and why? 


To explain GPs in more detail, recall that a Gaussian random vector of length $N$, $f = [f_1, \ldots, f_N]$,
is defined by its mean $\mu = \mathbb{E}[f]$ and its covariance $\Sigma = \mathrm{Cov}[f]$. Now consider a function $f : \mathcal{X} \to \mathbb{R}$
evaluated at a set of inputs, $X = \{x_n \in \mathcal{X}\}_{n=1}^N$. Let $f_X = [f(x_1), \ldots, f(x_N)]$ be the set of unknown
function values at these points. If $f_X$ is jointly Gaussian for any set of $N \geq 1$ points, then we
say that $f : \mathcal{X} \to \mathbb{R}$ is a Gaussian process. Such a process is defined by its mean function
$m(x) \in \mathbb{R}$ and a covariance function $K(x, x')$, which can be any positive definite Mercer kernel
(see Section 18.2). For example, we might use an RBF kernel of the form $K(x, x') \propto \exp(-\|x - x'\|^2)$
(see Section 18.2.1.1 for details).
We denote the corresponding GP by

$$f(x) \sim \mathcal{GP}(m(x), K(x, x')) \tag{18.1}$$




Figure 18.1: A Gaussian process for 2 training points, $x_1$ and $x_2$, and 1 testing point, $x_*$, represented as
a graphical model representing $p(y, f_X|X) = \mathcal{N}(f_X|m(X), K(X)) \prod_i p(y_i|f_i)$. The hidden nodes $f_i = f(x_i)$
represent the value of the function at each of the data points. These hidden nodes are fully interconnected
by undirected edges, forming a Gaussian graphical model; the edge strengths represent the covariance terms
$\Sigma_{ij} = K(x_i, x_j)$. If the test point $x_*$ is similar to the training points $x_1$ and $x_2$, then the value of the hidden
function $f_*$ will be similar to $f_1$ and $f_2$, and hence the predicted output $y_*$ will be similar to the training
values $y_1$ and $y_2$.


where

$$m(x) = \mathbb{E}[f(x)] \tag{18.2}$$
$$K(x, x') = \mathbb{E}\left[(f(x) - m(x))(f(x') - m(x'))^\top\right] \tag{18.3}$$

This means that, for any finite set of points $X = \{x_1, \ldots, x_N\}$, we have

$$p(f_X|X) = \mathcal{N}(f_X|\mu_X, K_{X,X}) \tag{18.4}$$

where $\mu_X = (m(x_1), \ldots, m(x_N))$ and $K_{X,X}(i, j) \triangleq K(x_i, x_j)$.
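Equation (18.4) is all we need to draw sample functions from a GP prior on a finite grid: form $\mu_X$ and $K_{X,X}$, then sample from the resulting multivariate Gaussian. The sketch below does this with a Cholesky factor, assuming an RBF kernel with length scale $\ell$ (in the form of Section 18.2.1.1), a zero mean function, and a small "jitter" term on the diagonal for numerical stability (a common practical choice, not part of the definition).

```python
import numpy as np


def rbf_kernel(X1, X2, length_scale=1.0):
    """K(x, x') = exp(-||x - x'||^2 / (2 * length_scale^2)) for rows of X1, X2."""
    sq_dists = (np.sum(X1 ** 2, axis=1)[:, None] + np.sum(X2 ** 2, axis=1)[None, :]
                - 2.0 * X1 @ X2.T)
    return np.exp(-0.5 * np.maximum(sq_dists, 0.0) / length_scale ** 2)


def sample_gp_prior(X, mean_fn, kernel_fn, num_samples=3, jitter=1e-6, seed=0):
    """Draw f_X ~ N(mu_X, K_XX), Eq. (18.4), via K_XX + jitter*I = L L^T."""
    mu = mean_fn(X)
    K = kernel_fn(X, X) + jitter * np.eye(len(X))
    L = np.linalg.cholesky(K)
    z = np.random.default_rng(seed).standard_normal((num_samples, len(X)))
    return mu[None, :] + z @ L.T     # each row has covariance L L^T


X = np.linspace(-5.0, 5.0, 100)[:, None]
samples = sample_gp_prior(X, lambda X: np.zeros(len(X)), rbf_kernel)
```

Each row of `samples` is one function evaluated on the 100 grid points; plotting the rows against `X` gives pictures like Figure 18.4(f).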


A GP can be used to define a prior over functions. We can evaluate this prior at any set of points
we choose. However, to learn about the function from data, we have to update this prior with a
likelihood function. We typically assume we have a set of $N$ iid observations $\mathcal{D} = \{(x_i, y_i) : i = 1 : N\}$,
where $y_i \sim p(y|f(x_i))$, as shown in Figure 18.1. If we use a Gaussian likelihood, we can compute the
posterior $p(f|\mathcal{D})$ in closed form, as we discuss in Section 18.3. For other kinds of likelihoods, we will
need to use approximate inference, as we discuss in Section 18.4. In many cases $f$ is not directly
observed, and instead forms part of a latent variable model, both in supervised and unsupervised
settings such as in Section 28.3.7.


The generalization properties of a Gaussian process are controlled by its covariance function
(kernel), which we describe in Section 18.2. Each such kernel defines a reproducing kernel Hilbert space
(RKHS), described in Section 18.3.7.1.


GPs were originally designed for spatial data analysis, where the input is 2d. This special case
is called kriging. However, they can be applied to higher dimensional inputs. In addition, while
they have been traditionally limited to small datasets, it is now possible to apply GPs to problems
with millions of points, with essentially exact inference. We discuss these scalability advances in
Section 18.5.


Moreover, while Gaussian processes have historically been considered smoothing interpolators, GPs
now routinely perform representation learning, through covariance function learning and multilayer
models. These advances have clearly illustrated that GPs and neural networks are not competing,
but complementary, and can be combined for better performance than would be achieved by deep
learning alone. We describe GPs for representation learning in Section 18.6.

Draft of “Probabilistic Machine Learning: Advanced Topics”. August 15, 2022

The connections between Gaussian processes and neural networks can also be further understood 
by considering infinite limits of neural networks that converge to Gaussian processes with particular 
covariance functions, which we describe in Section 18.7. 

So Gaussian processes are non-parametric models which can scale and do representation learning. 
But why, in the age of deep learning, should we want to use a Gaussian process? There are several 
compelling reasons to prefer a GP, including: 


• Gaussian processes typically provide well-calibrated predictive distributions, with a good
characterization of epistemic (model) uncertainty, i.e., uncertainty arising from not knowing which of
many solutions is correct. For example, as we move away from the data, there is a greater variety
of consistent solutions, and so we expect greater uncertainty.

• Gaussian processes are often state-of-the-art for continuous regression problems, especially
spatiotemporal problems, such as weather interpolation and forecasting. In regression, Gaussian
process inference can also typically be performed in closed form.

• The marginal likelihood of a Gaussian process provides a powerful mechanism for flexible kernel
learning. Kernel learning enables us to provide long-range extrapolations, but also tells us
interpretable properties of the data that we didn’t know before, towards scientific discovery.

• Gaussian processes are often used as a probabilistic surrogate for objectives in optimization, in a
procedure known as Bayesian optimization (Section 6.8). To maximize an objective, we wish
to move where there is a high expected value, but also to explore where we have large uncertainty.
The ability of a Gaussian process to provide closed form inference in regression, in conjunction
with high quality uncertainty representations, makes GPs particularly impactful in this setting.
Bayesian optimization has a wide range of applications, including A/B testing, experimental
design, protein engineering, hyperparameter tuning, and AutoML. See Section 6.8 for details.


18.2 Mercer kernels 


The generalization properties of Gaussian processes boil down to how we encode prior knowledge 
about the similarity of two input vectors. If we know that $x_i$ is similar to $x_j$, then we can encourage
the model to make the predicted outputs at both locations (i.e., $f(x_i)$ and $f(x_j)$) similar.

To define similarity, we introduce the notion of a kernel function. The word “kernel” has many
different meanings in mathematics; here we consider a Mercer kernel, also called a positive
definite kernel. This is any symmetric function $K : \mathcal{X} \times \mathcal{X} \to \mathbb{R}$ such that

$$\sum_{i=1}^N \sum_{j=1}^N K(x_i, x_j)\, c_i c_j \geq 0 \tag{18.5}$$

for any set of $N$ (unique) points $x_i \in \mathcal{X}$, and any choice of numbers $c_i \in \mathbb{R}$. If, in addition, equality
holds only when $c_i = 0$ for all $i$, the kernel is called strictly positive definite.
Another way to understand this condition is the following. Given a set of $N$ datapoints, let us
define the Gram matrix as the following $N \times N$ similarity matrix:

$$\mathbf{K} = \begin{pmatrix} K(x_1, x_1) & \cdots & K(x_1, x_N) \\ \vdots & \ddots & \vdots \\ K(x_N, x_1) & \cdots & K(x_N, x_N) \end{pmatrix} \tag{18.6}$$

We say that $K$ is a Mercer kernel iff the Gram matrix is positive semidefinite for any set of (distinct)
inputs $\{x_i\}_{i=1}^N$.
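The Gram-matrix characterization suggests a quick numerical test: build $\mathbf{K}$ for some inputs and check its smallest eigenvalue. The sketch below does this for scalar inputs; the helper names and tolerance are our own choices. Note the asymmetry: passing the test on one input set does not prove that a function is a Mercer kernel, but a single failure proves that it is not.

```python
import numpy as np


def gram_matrix(kernel, xs):
    """The N x N matrix of Eq. (18.6) for a kernel on scalar inputs."""
    return np.array([[kernel(a, b) for b in xs] for a in xs])


def is_psd(K, rtol=1e-10):
    """True if the symmetric matrix K has no eigenvalue below -rtol * max|K|."""
    K = 0.5 * (K + K.T)                         # remove round-off asymmetry
    scale = max(1.0, float(np.abs(K).max()))
    return bool(np.linalg.eigvalsh(K).min() >= -rtol * scale)


xs = np.linspace(-3.0, 3.0, 25)
rbf = lambda a, b: np.exp(-0.5 * (a - b) ** 2)
neg_dist = lambda a, b: -abs(a - b)            # symmetric, but not a Mercer kernel
```

The second function fails because its Gram matrix has a zero diagonal, hence zero trace, so any nonzero such matrix must have a negative eigenvalue.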

We discuss several popular Mercer kernels below. More details can be found in [Wil14] and
https://www.cs.toronto.edu/~duvenaud/cookbook/. See also Section 18.6 where we discuss how 
to learn kernels from data. 


18.2.1 Stationary kernels 


For real-valued inputs, $\mathcal{X} = \mathbb{R}^D$, it is common to use stationary kernels (also called shift-invariant
kernels), which are functions of the form $K(x, x') = K(r)$, where $r = x - x'$; thus the output
only depends on the relative difference between the inputs. (See Section 18.2.2 for a discussion 
of non-stationary kernels.) Furthermore, in many cases, all that matters is the magnitude of the 
difference: 


$$r = \|r\|_2 = \|x - x'\| \tag{18.7}$$

We give some examples below. (See also Figure 18.3 and Figure 18.4 for some visualizations of these
kernels.)


18.2.1.1 Squared exponential (RBF) kernel

The squared exponential (SE) kernel, also sometimes called the exponentiated quadratic kernel
or the radial basis function (RBF) kernel, is defined as

$$K(r; \ell) = \exp\left(-\frac{r^2}{2\ell^2}\right) \tag{18.8}$$

Here $\ell$ corresponds to the length-scale of the kernel, i.e., the distance over which we expect
differences to matter.

From Equation (18.7) we can rewrite this kernel as

$$K(x, x'; \ell) = \exp\left(-\frac{\|x - x'\|^2}{2\ell^2}\right) \tag{18.9}$$

This is the RBF kernel we encountered earlier. It is also sometimes called the Gaussian kernel.

See Figure 18.3(f) and Figure 18.4(f) for a visualization in 1D.


18.2.1.2 ARD kernel

We can generalize the RBF kernel by replacing Euclidean distance with Mahalanobis distance, as
follows:

$$K(r; \Sigma, \sigma^2) = \sigma^2 \exp\left(-\frac{1}{2} r^\top \Sigma^{-1} r\right) \tag{18.10}$$


Figure 18.2: Function samples from a GP with an ARD kernel. (a) $\ell_1 = \ell_2 = 1$. Both dimensions contribute
to the response. (b) $\ell_1 = 1$, $\ell_2 = 5$. The second dimension is essentially ignored. Adapted from Figure 5.1 of
[RW06]. Generated by gpr_demo_ard.ipynb.


where $r = x - x'$. If $\Sigma$ is diagonal, this can be written as

$$K(r; \ell, \sigma^2) = \sigma^2 \exp\left(-\frac{1}{2} \sum_{d=1}^D \frac{r_d^2}{\ell_d^2}\right) = \prod_{d=1}^D K(r_d; \ell_d, \sigma^{2/D}) \tag{18.11}$$

where

$$K(r; \ell, \tau^2) = \tau^2 \exp\left(-\frac{1}{2} \frac{r^2}{\ell^2}\right) \tag{18.12}$$

We can interpret $\sigma^2$ as the overall variance, and $\ell_d$ as defining the characteristic length scale
of dimension $d$. If $d$ is an irrelevant input dimension, we can set $\ell_d = \infty$, so the corresponding
dimension will be ignored. This is known as Automatic Relevance Determination or ARD
(Section 15.2.7). Hence the corresponding kernel is called the ARD kernel. See Figure 18.2 for an
illustration of some 2d functions sampled from a GP using this prior.
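Both forms of Equation (18.11), and the effect of an infinite length scale, can be checked in a few lines. The sketch below (hypothetical helper names) evaluates the ARD kernel directly and as a product of 1d kernels. With all $\ell_d$ equal it reduces to the RBF kernel of Equation (18.9), scaled by $\sigma^2$.

```python
import numpy as np


def ard_kernel(X1, X2, length_scales, variance=1.0):
    """Eq. (18.11): sigma^2 exp(-0.5 * sum_d r_d^2 / ell_d^2) for all pairs of rows."""
    ls = np.asarray(length_scales, dtype=float)
    diff = X1[:, None, :] / ls - X2[None, :, :] / ls   # r_d / ell_d; an infinite ell_d gives 0
    return variance * np.exp(-0.5 * np.sum(diff ** 2, axis=-1))


def ard_kernel_as_product(X1, X2, length_scales, variance=1.0):
    """Right-hand side of Eq. (18.11): a product of 1d kernels with variance sigma^(2/D)."""
    D = X1.shape[1]
    K = np.ones((X1.shape[0], X2.shape[0]))
    for d in range(D):
        r = X1[:, None, d] - X2[None, :, d]
        K *= variance ** (1.0 / D) * np.exp(-0.5 * r ** 2 / length_scales[d] ** 2)
    return K


rng = np.random.default_rng(0)
X = rng.standard_normal((6, 2))
```

Passing `np.inf` as a length scale makes the corresponding input dimension drop out of the kernel, which is the ARD mechanism described above.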


18.2.1.3 Matern kernels 


The SE kernel gives rise to functions that are infinitely differentiable, and therefore are very smooth. 

For many applications, it is better to use the Matern kernel, which gives rise to “rougher” functions, 

which can better model local “wiggles” without having to make the overall length scale very small. 
The Matern kernel has the following form: 


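$$K(r; \nu, \ell) = \frac{2^{1-\nu}}{\Gamma(\nu)} \left(\frac{\sqrt{2\nu}\, r}{\ell}\right)^{\nu} K_{\nu}\left(\frac{\sqrt{2\nu}\, r}{\ell}\right) \tag{18.13}$$

where $K_\nu$ is a modified Bessel function and $\ell$ is the length scale. Functions sampled from this GP
are $k$-times differentiable iff $\nu > k$. As $\nu \to \infty$, this approaches the SE kernel.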
Figure 18.3: GP kernels evaluated at $k(x, 0)$ as a function of $x$. Generated by gpKernelPlot.ipynb.
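For values $\nu \in \{\frac{1}{2}, \frac{3}{2}, \frac{5}{2}\}$, the function simplifies as follows:

$$K(r; \tfrac{1}{2}, \ell) = \exp\left(-\frac{r}{\ell}\right) \tag{18.14}$$
$$K(r; \tfrac{3}{2}, \ell) = \left(1 + \frac{\sqrt{3}\, r}{\ell}\right) \exp\left(-\frac{\sqrt{3}\, r}{\ell}\right) \tag{18.15}$$
$$K(r; \tfrac{5}{2}, \ell) = \left(1 + \frac{\sqrt{5}\, r}{\ell} + \frac{5 r^2}{3 \ell^2}\right) \exp\left(-\frac{\sqrt{5}\, r}{\ell}\right) \tag{18.16}$$

See Figure 18.3(a-c) and Figure 18.4(a-c) for a visualization.
The value $\nu = \frac{1}{2}$ corresponds to the Ornstein-Uhlenbeck process, which describes the velocity
of a particle undergoing Brownian motion. The corresponding function is continuous but not
differentiable, and hence is very “jagged”.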
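The closed forms in Equations (18.14) to (18.16) are easy to code directly. The sketch below covers only these three half-integer cases; for general $\nu$, one would evaluate Equation (18.13) with a modified Bessel function routine (e.g., `scipy.special.kv`), which we do not do here.

```python
import numpy as np


def matern(r, nu, length_scale=1.0):
    """Matern kernel for nu in {1/2, 3/2, 5/2}, Eqs. (18.14)-(18.16)."""
    s = np.abs(np.asarray(r, dtype=float)) / length_scale
    if nu == 0.5:
        return np.exp(-s)
    if nu == 1.5:
        a = np.sqrt(3.0) * s
        return (1.0 + a) * np.exp(-a)
    if nu == 2.5:
        a = np.sqrt(5.0) * s
        return (1.0 + a + a ** 2 / 3.0) * np.exp(-a)   # a^2 / 3 = 5 r^2 / (3 ell^2)
    raise ValueError("this sketch only handles nu in {0.5, 1.5, 2.5}")
```

At a fixed distance such as $r = \ell$, the correlation increases with $\nu$ toward the SE value $e^{-1/2}$, consistent with the SE kernel being the $\nu \to \infty$ limit.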
Figure 18.4: GP samples drawn using different kernels: (a) Matern12, (b) Matern32, (c) Matern52, (d) Periodic, (e) Cosine, (f) RBF, (g) Rational quadratic, (h) Constant, (i) Linear, (j) Quadratic, (k) Polynomial, (l) White noise. Generated by gpKernelPlot.ipynb.


18.2.1.4 Periodic kernels 


One way to create a periodic 1d random function is to map $x$ to the 2d space $u(x) = (\cos(x), \sin(x))$,
and then use an SE kernel in $u$-space:

$$K(x, x') = \exp\left(-\frac{2 \sin^2((x - x')/2)}{\ell^2}\right) \tag{18.17}$$

which follows since $(\cos(x) - \cos(x'))^2 + (\sin(x) - \sin(x'))^2 = 4 \sin^2((x - x')/2)$. We can generalize this
by specifying the period $p$ to get the periodic kernel, also called the exp-sine-squared kernel:

$$K_{\text{per}}(r; \ell, p) = \exp\left(-\frac{2}{\ell^2} \sin^2\left(\frac{\pi r}{p}\right)\right) \tag{18.18}$$

where $p$ is the period and $\ell$ is the length scale. See Figure 18.3(d-e) and Figure 18.4(d-e) for a
visualization.
A related kernel is the cosine kernel:

$$K(r; p) = \cos\left(\frac{2\pi r}{p}\right) \tag{18.19}$$
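The derivation of Equation (18.17) can be verified numerically by applying an SE kernel to the mapped points $u(x)$ and comparing with the periodic kernel at period $2\pi$. The sketch below does this and also implements Equations (18.18) and (18.19); the function names are our own.

```python
import numpy as np


def periodic_kernel(r, length_scale=1.0, period=2.0 * np.pi):
    """Exp-sine-squared kernel, Eq. (18.18)."""
    r = np.asarray(r, dtype=float)
    return np.exp(-2.0 * np.sin(np.pi * r / period) ** 2 / length_scale ** 2)


def cosine_kernel(r, period=2.0 * np.pi):
    """Cosine kernel, Eq. (18.19)."""
    return np.cos(2.0 * np.pi * np.asarray(r, dtype=float) / period)


def se_on_circle(x, xp, length_scale=1.0):
    """SE kernel applied to u(x) = (cos x, sin x), the construction behind Eq. (18.17)."""
    du = np.array([np.cos(x) - np.cos(xp), np.sin(x) - np.sin(xp)])
    return np.exp(-np.sum(du ** 2) / (2.0 * length_scale ** 2))
```

Both kernels repeat exactly after one period, so GP samples drawn with them are periodic functions.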
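18.2.1.5 Rational quadratic kernel

We define the rational quadratic kernel to be

$$K_{\text{RQ}}(r; \ell, \alpha) = \left(1 + \frac{r^2}{2\alpha\ell^2}\right)^{-\alpha} \tag{18.20}$$

We recognize this is proportional to a Student T density. Hence it can be interpreted as a scale
mixture of SE kernels of different characteristic lengths. In particular, let $\tau$ denote the inverse squared
length scale of each SE component, so $K_{\text{SE}}(r|\tau) = \exp(-\tau r^2/2)$, and assume
$\tau \sim \mathrm{Ga}(\alpha, \alpha\ell^2)$ (shape $\alpha$, rate $\alpha\ell^2$, so that $\mathbb{E}[\tau] = 1/\ell^2$). Then one can show that

$$K_{\text{RQ}}(r) = \int p(\tau|\alpha, \ell)\, K_{\text{SE}}(r|\tau)\, d\tau \tag{18.21}$$

As $\alpha \to \infty$, this reduces to a SE kernel.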
See Figure 18.3(g) and Figure 18.4(g) for a visualization. 
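The scale-mixture interpretation of Equation (18.21) can be checked by Monte Carlo: sample $\tau$ from the Gamma distribution (NumPy parameterizes it by shape and scale, so the scale is one over the rate $\alpha\ell^2$) and average the SE kernel values. This is a sketch under the shape-rate convention stated above; the function names are our own.

```python
import numpy as np


def rational_quadratic(r, length_scale=1.0, alpha=1.0):
    """Eq. (18.20)."""
    return (1.0 + np.asarray(r) ** 2 / (2.0 * alpha * length_scale ** 2)) ** (-alpha)


def rq_as_gamma_mixture(r, length_scale=1.0, alpha=1.0, num_samples=200_000, seed=0):
    """Monte Carlo version of Eq. (18.21): E_tau[exp(-tau r^2 / 2)], tau ~ Ga(alpha, rate=alpha*ell^2)."""
    rng = np.random.default_rng(seed)
    tau = rng.gamma(shape=alpha, scale=1.0 / (alpha * length_scale ** 2), size=num_samples)
    return float(np.mean(np.exp(-0.5 * tau * np.asarray(r) ** 2)))
```

For very large $\alpha$ the Gamma distribution concentrates at $1/\ell^2$ and the kernel matches the SE value $\exp(-r^2/(2\ell^2))$.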


18.2.1.6 Kernels from Spectral Densities 


Consider the case of a stationary kernel which satisfies $K(x, x') = K(\delta)$, where $\delta = x - x'$, for
$x, x' \in \mathbb{R}^d$. Let us further assume that $K(\delta)$ is positive definite. In this case, Bochner’s theorem
tells us that we can represent $K(\delta)$ by its Fourier transform:

$$K(\delta) = \int p(\omega)\, e^{j\omega^\top \delta}\, d\omega \tag{18.22}$$

where $j = \sqrt{-1}$, $e^{j\theta} = \cos(\theta) + j\sin(\theta)$, and $p(\omega)$ is the spectral density (a distribution over
frequencies).
We can easily derive and gain intuitions into several kernels from spectral densities. If we take
the Fourier transform of an RBF kernel $K(\delta) = \exp(-\delta^2/(2\ell^2))$ (in 1d), we find the spectral density
$p(\omega) = \mathcal{N}(\omega|0, \ell^{-2}) = \frac{\ell}{\sqrt{2\pi}} \exp(-\ell^2\omega^2/2)$ under the convention of Equation (18.22);
equivalently, in terms of the ordinary frequency $s = \omega/(2\pi)$, $p(s) = \sqrt{2\pi\ell^2} \exp(-2\pi^2 s^2 \ell^2)$.
Thus the spectral density is also Gaussian, but with a bandwidth inversely proportional to the
length-scale hyperparameter $\ell$. That is, as $\ell$ becomes large, the spectral density collapses onto a
point mass. This result is intuitive: as we increase the length-scale, our model treats points as
correlated over large distances, and becomes very smooth and slowly varying, and thus low-frequency.
In general, since the Gaussian distribution has relatively light tails, we can see that RBF kernels
won’t generally support high frequency solutions.
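Bochner's representation can be checked numerically for the RBF kernel: integrate the Gaussian spectral density against $\cos(\omega\delta)$ on a fine grid and compare with $\exp(-\delta^2/(2\ell^2))$. The sketch below does this in 1d, using the angular-frequency convention of Equation (18.22); the grid limits are our own choices.

```python
import numpy as np


def rbf_spectral_density(omega, length_scale):
    """N(omega | 0, 1/length_scale^2): spectral density of exp(-delta^2 / (2 ell^2)) in 1d."""
    return length_scale / np.sqrt(2.0 * np.pi) * np.exp(-0.5 * (length_scale * omega) ** 2)


def kernel_from_density(delta, density_fn, omega_max=60.0, num_points=200_001):
    """Numerically evaluate Eq. (18.22) in 1d. For a symmetric density the
    imaginary part integrates to zero, so only the cosine term is kept."""
    omega = np.linspace(-omega_max, omega_max, num_points)
    d_omega = omega[1] - omega[0]
    return float(np.sum(density_fn(omega) * np.cos(omega * delta)) * d_omega)
```

The density's standard deviation is $1/\ell$, so a longer length scale concentrates the spectral mass near zero frequency, exactly as described above.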


We can instead use a Student-t spectral density, which has heavy tails that will provide greater
support for higher frequencies. Taking the inverse Fourier transform of this spectral density, we
recover the Matern kernel, with degrees of freedom $\nu$ corresponding to the degrees of freedom in
the spectral density. Indeed, the smaller we make $\nu$, the less smooth and higher frequency are the
associated fits to data using a Matern kernel.


We can also derive spectral mixture kernels by modelling the spectral density as a scale-location
mixture of Gaussians and taking the inverse Fourier transform [WA13]. Since scale-location mixtures
of Gaussians are dense in the set of distributions, and can therefore approximate any spectral density,
this kernel can approximate any stationary kernel to arbitrary precision. The spectral mixture kernel
thus forms a powerful approach to kernel learning, which we discuss further in Section 18.6.5.
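A 1d spectral mixture kernel can be written in closed form. Take $p(\omega)$ to be a mixture of Gaussians, symmetrized so that $K$ is real. Each component $\mathcal{N}(\omega|\mu_q, s_q^2)$ then contributes $w_q \exp(-s_q^2\delta^2/2)\cos(\mu_q\delta)$ to $K(\delta)$. The sketch below uses the angular-frequency convention of Equation (18.22); [WA13] state it with ordinary frequencies, which changes the constants by factors of $2\pi$. The quadrature check and all names are our own.

```python
import numpy as np


def spectral_mixture_kernel(delta, weights, means, stds):
    """1d spectral mixture kernel in the angular-frequency convention of Eq. (18.22).

    Inverse Fourier transform of the symmetrized Gaussian mixture
    p(omega) = sum_q w_q [N(omega|mu_q, s_q^2) + N(omega|-mu_q, s_q^2)] / 2.
    """
    delta = np.asarray(delta, dtype=float)[..., None]
    w, mu, s = (np.asarray(a, dtype=float) for a in (weights, means, stds))
    return np.sum(w * np.exp(-0.5 * (s * delta) ** 2) * np.cos(mu * delta), axis=-1)


def mixture_density(omega, weights, means, stds):
    """The symmetrized mixture p(omega) itself, used only to check the closed form."""
    omega = np.asarray(omega, dtype=float)[..., None]
    w, mu, s = (np.asarray(a, dtype=float) for a in (weights, means, stds))
    norm = 1.0 / (np.sqrt(2.0 * np.pi) * s)
    pos = np.exp(-0.5 * ((omega - mu) / s) ** 2)
    neg = np.exp(-0.5 * ((omega + mu) / s) ** 2)
    return np.sum(w * norm * 0.5 * (pos + neg), axis=-1)
```

Each component behaves like a quasi-periodic kernel: $\mu_q$ sets the frequency of the oscillation and $s_q$ sets how quickly the correlation decays.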
18.2.2 Non-stationary kernels


A stationary kernel assumes the measure of similarity between two inputs is independent of their
location, i.e., $K(x, x')$ only depends on $r = x - x'$. A non-stationary kernel relaxes this assumption.
This is useful for a variety of problems, such as environmental modeling (see e.g., [Pat+22]), where 
correlations between locations can change depending on latent factors in the environment. 


18.2.2.1 Polynomial kernels 


A simple form of non-stationary kernel is the polynomial kernel (also called dot product kernel) 
of order M, defined by 


$$K(x, x') = (x^\top x')^M \tag{18.23}$$

This contains all monomials of order $M$. For example, if $M = 2$, we get the quadratic kernel; in
2d, this becomes

$$(x^\top x')^2 = (x_1 x_1' + x_2 x_2')^2 = (x_1 x_1')^2 + (x_2 x_2')^2 + 2(x_1 x_1')(x_2 x_2') \tag{18.24}$$

We can generalize this to contain all terms up to degree $M$ by using the inhomogeneous
polynomial kernel

$$K(x, x') = (x^\top x' + c)^M \tag{18.25}$$

For example, if $M = 2$, $c = 1$ and the inputs are 2d, we have

$$\begin{aligned} (x^\top x' + 1)^2 = {} & (x_1 x_1')^2 + (x_1 x_1')(x_2 x_2') + (x_1 x_1') \\ & + (x_2 x_2')(x_1 x_1') + (x_2 x_2')^2 + (x_2 x_2') \\ & + (x_1 x_1') + (x_2 x_2') + 1 \end{aligned} \tag{18.26}$$
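Grouping the nine terms of Equation (18.26) (and keeping a general offset $c$) shows that the inhomogeneous quadratic kernel is an inner product of 6-dimensional feature vectors: $[x_1^2, x_2^2, \sqrt{2}x_1x_2, \sqrt{2c}\,x_1, \sqrt{2c}\,x_2, c]$. The sketch below checks this numerically; the function names are hypothetical.

```python
import numpy as np


def poly_kernel(x, xp, c=1.0, M=2):
    """Inhomogeneous polynomial kernel, Eq. (18.25)."""
    return float((np.dot(x, xp) + c) ** M)


def quadratic_features(x, c=1.0):
    """Explicit features for M = 2 in 2d, so that (x^T x' + c)^2 = phi(x)^T phi(x').

    Grouping the cross terms of Eq. (18.26) gives 6 features rather than 9.
    """
    x1, x2 = x
    a = np.sqrt(2.0 * c)
    return np.array([x1 ** 2, x2 ** 2, np.sqrt(2.0) * x1 * x2, a * x1, a * x2, c])
```

The explicit map grows combinatorially with $M$ and the input dimension, whereas evaluating the kernel itself costs a single dot product.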


18.2.2.2 Gibbs kernel 


Consider an RBF kernel where the length scale hyper-parameter and the signal variance
hyper-parameter are both input dependent; this is called the Gibbs kernel [Gib97], and is defined by

$$K(x, x') = \sigma(x)\sigma(x') \sqrt{\frac{2\ell(x)\ell(x')}{\ell(x)^2 + \ell(x')^2}} \exp\left(-\frac{\|x - x'\|^2}{\ell(x)^2 + \ell(x')^2}\right) \tag{18.27}$$

If $\ell(x)$ and $\sigma(x)$ are constants, this reduces to the standard RBF kernel. We can model the functional
dependency of these kernel parameters on the input by using another GP (see e.g., [Hei+16]). 
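A minimal 1d sketch of the Gibbs kernel follows, assuming the length-scale and signal-scale functions are supplied as ordinary (hypothetical) Python callables rather than modeled by another GP. The check confirms that constant functions recover the RBF kernel, and that an input-dependent length scale still yields a positive semidefinite Gram matrix.

```python
import numpy as np


def gibbs_kernel(x1, x2, length_fn, sigma_fn):
    """Eq. (18.27) for scalar inputs; length_fn and sigma_fn map arrays to positive arrays."""
    x1 = np.asarray(x1, dtype=float)[:, None]
    x2 = np.asarray(x2, dtype=float)[None, :]
    l1, l2 = length_fn(x1), length_fn(x2)
    sq = l1 ** 2 + l2 ** 2
    prefac = sigma_fn(x1) * sigma_fn(x2) * np.sqrt(2.0 * l1 * l2 / sq)
    return prefac * np.exp(-(x1 - x2) ** 2 / sq)


x = np.linspace(-2.0, 2.0, 40)
K_const = gibbs_kernel(x, x, lambda z: np.full_like(z, 0.7), lambda z: np.full_like(z, 1.3))
K_var = gibbs_kernel(x, x, lambda z: 0.3 + 0.2 * z ** 2, lambda z: np.ones_like(z))
```

In `K_var`, the length scale grows away from the origin, so sampled functions wiggle quickly near $x = 0$ and vary slowly near the edges.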


18.2.2.3 Other non-stationary kernels 


Other ways to induce non-stationarity include using a neural network kernel (Section 18.7.1), non- 
stationary spectral kernels [RHK17], or a deep GP (Section 18.7.3). 
18.2.3 Kernels for non-vectorial (structured) inputs


Kernels are particularly useful when the inputs are structured objects, such as strings and graphs, 
since it is often hard to “featurize” variable-sized inputs. For example, we can define a string kernel 
which compares strings in terms of the number of n-grams they have in common [Lod+02; BC17].
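One simple member of this family, sometimes called the n-gram spectrum kernel, is the inner product of the n-gram count vectors of two strings. It is a valid Mercer kernel by construction, since it is an explicit inner product. The sketch below is illustrative only; the kernels of [Lod+02] are more elaborate (for example, they can also count non-contiguous subsequences).

```python
from collections import Counter


def ngram_counts(s, n):
    """Counts of every contiguous substring of length n."""
    return Counter(s[i:i + n] for i in range(len(s) - n + 1))


def spectrum_kernel(s, t, n=3):
    """Inner product of the n-gram count vectors of s and t."""
    cs, ct = ngram_counts(s, n), ngram_counts(t, n)
    return sum(count * ct[g] for g, count in cs.items())   # Counter returns 0 for missing keys


print(spectrum_kernel("abcab", "abc", n=2))
```

For example, "abcab" has bigram counts {ab: 2, bc: 1, ca: 1} and "abc" has {ab: 1, bc: 1}, so the printed kernel value is 2·1 + 1·1 = 3. Strings of different lengths are handled naturally, with no fixed-size featurization needed.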

We can also define kernels on graphs [KJM19]. For example, the random walk kernel conceptually 
performs random walks on two graphs simultaneously, and then counts the number of paths that 
were produced by both walks. This can be computed efficiently as discussed in [Vis+10]. For more 
details on graph kernels, see [KJM19]. 
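A minimal sketch of one common variant, the geometric random walk kernel, is below. It assumes unlabeled graphs given by adjacency matrices and a hypothetical decay parameter $\lambda$. Walks taken simultaneously in both graphs correspond to walks in the direct product graph with adjacency matrix $A_\times = A_1 \otimes A_2$. The weighted count $\sum_k \lambda^k \mathbf{1}^\top A_\times^k \mathbf{1}$ therefore sums to $\mathbf{1}^\top (I - \lambda A_\times)^{-1} \mathbf{1}$ when $\lambda$ is below the inverse spectral radius of $A_\times$. This direct solve is cubic in the size of the product graph; [Vis+10] discusses faster methods.

```python
import numpy as np


def random_walk_kernel(A1, A2, decay=0.1):
    """Geometric random walk kernel between two graphs given their adjacency matrices.

    Walks that can be taken simultaneously in both graphs are walks in the direct
    product graph with adjacency kron(A1, A2). Summing decay^k times the number of
    such length-k walks over all k gives 1^T (I - decay * A_x)^{-1} 1.
    """
    Ax = np.kron(A1, A2)
    n = Ax.shape[0]
    if decay * np.max(np.abs(np.linalg.eigvals(Ax))) >= 1.0:
        raise ValueError("decay too large: the geometric series diverges")
    ones = np.ones(n)
    return float(ones @ np.linalg.solve(np.eye(n) - decay * Ax, ones))


triangle = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]], dtype=float)
path = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)
```

The closed-form solve agrees with truncating the walk-count series after enough terms, and the kernel is symmetric in its two graph arguments.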

For a review of kernels on structured objects, see e.g., [Gar03]. 


18.2.4 Making new kernels from old 


Given two valid kernels $K_1(x, x')$ and $K_2(x, x')$, we can create a new kernel using any of the following
methods:

$$K(x, x') = cK_1(x, x'), \text{ for any constant } c > 0 \tag{18.28}$$
$$K(x, x') = f(x)K_1(x, x')f(x'), \text{ for any function } f \tag{18.29}$$
$$K(x, x') = q(K_1(x, x')), \text{ for any polynomial } q \text{ with nonnegative coefficients} \tag{18.30}$$
$$K(x, x') = \exp(K_1(x, x')) \tag{18.31}$$
$$K(x, x') = x^\top A x', \text{ for any psd matrix } A \tag{18.32}$$


For example, suppose we start with the linear kernel $K(x, x') = x^\top x'$. We know this is a valid
Mercer kernel, since the corresponding Gram matrix $XX^\top$ is positive semidefinite: $c^\top XX^\top c = \|X^\top c\|^2 \geq 0$.
From the above rules, we can see that the polynomial kernel $K(x, x') = (x^\top x')^M$ from Section 18.2.2.1
is a valid Mercer kernel.


We can also use the above rules to establish that the Gaussian kernel is a valid kernel. To see this,
note that

$$\|x - x'\|^2 = x^\top x + (x')^\top x' - 2x^\top x' \tag{18.33}$$

and hence

$$K(x, x') = \exp(-\|x - x'\|^2/2\sigma^2) = \exp(-x^\top x/2\sigma^2) \exp(x^\top x'/\sigma^2) \exp(-(x')^\top x'/2\sigma^2) \tag{18.34}$$

is a valid kernel: the middle factor is Equation (18.31) applied to the scaled linear kernel $x^\top x'/\sigma^2$, and
the outer factors have the form of Equation (18.29) with $f(x) = \exp(-x^\top x/2\sigma^2)$.
We can also combine kernels using addition or multiplication: 


$$K(x, x') = K_1(x, x') + K_2(x, x') \tag{18.35}$$
$$K(x, x') = K_1(x, x') \times K_2(x, x') \tag{18.36}$$


Multiplying two positive-definite kernels together always results in another positive definite kernel.
This is a way to get a conjunction of the individual properties of each kernel, as illustrated in
Figure 18.5.

In addition, adding two positive-definite kernels together always results in another positive definite
kernel. This is a way to get a disjunction of the individual properties of each kernel, as illustrated in
Figure 18.6.


For an example of combining kernels to forecast some timeseries data, see Section 18.8.1. 
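The closure rules can be spot-checked on a Gram matrix: each combined kernel should give a matrix with no (numerically) negative eigenvalues. The sketch below does this for the scaling, polynomial, exponential, sum and product rules. It also confirms the factorization of the Gaussian kernel in Equation (18.34). The random inputs and helper names are our own choices.

```python
import numpy as np


def min_relative_eigenvalue(K):
    """Smallest eigenvalue of the symmetric matrix K, relative to its largest |entry|."""
    K = 0.5 * (K + K.T)
    return float(np.linalg.eigvalsh(K).min() / np.abs(K).max())


rng = np.random.default_rng(0)
X = 0.5 * rng.standard_normal((30, 3))
sigma = 1.0
K_lin = X @ X.T                                   # linear kernel x^T x'
sq = np.sum(X ** 2, axis=1)
K_rbf = np.exp(-(sq[:, None] + sq[None, :] - 2.0 * K_lin) / (2.0 * sigma ** 2))

# Eq. (18.34): the Gaussian kernel as f(x) exp(x^T x' / sigma^2) f(x')
f = np.exp(-sq / (2.0 * sigma ** 2))
K_rbf_from_rules = f[:, None] * np.exp(K_lin / sigma ** 2) * f[None, :]

combined = {
    "scaled (18.28)": 2.5 * K_rbf,
    "polynomial (18.30)": (K_lin + 1.0) ** 3,
    "exp (18.31)": np.exp(K_lin),
    "sum (18.35)": K_lin + K_rbf,
    "product (18.36)": K_lin * K_rbf,     # elementwise (Schur) product of Gram matrices
}
```

Note that the product rule acts elementwise on Gram matrices (the Schur product), not as a matrix product.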
Figure 18.5: Examples of 1d structures obtained by multiplying elementary kernels. Top row shows $K(x, x' = 1)$.
Bottom row shows some functions sampled from $\mathcal{GP}(f|0, K)$. Adapted from Figure 2.2 of [Duv14]. Generated
by combining_kernels_by_multiplication.ipynb.

Figure 18.6: Examples of 1d structures obtained by summing elementary kernels. Top row shows $K(x, x' = 1)$.
Bottom row shows some functions sampled from $\mathcal{GP}(f|0, K)$. Adapted from Figure 2.2 of [Duv14]. Generated
by combining_kernels_by_summation.ipynb.


18.2.5 Mercer’s theorem 


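Recall that any positive definite matrix $\mathbf{K}$ can be represented using an eigendecomposition of the
form $\mathbf{K} = \mathbf{U}^\top \Lambda \mathbf{U}$, where $\Lambda$ is a diagonal matrix of eigenvalues $\lambda_i > 0$, and $\mathbf{U}$ is a matrix containing
the eigenvectors. Now consider element $(i, j)$ of $\mathbf{K}$:

$$k_{ij} = (\Lambda^{\frac{1}{2}} \mathbf{U}_{:,i})^\top (\Lambda^{\frac{1}{2}} \mathbf{U}_{:,j}) \tag{18.37}$$

where $\mathbf{U}_{:,i}$ is the $i$th column of $\mathbf{U}$. If we define $\phi(x_i) = \Lambda^{\frac{1}{2}} \mathbf{U}_{:,i}$, then we can write

$$k_{ij} = \phi(x_i)^\top \phi(x_j) = \sum_{m=1}^M \phi_m(x_i)\, \phi_m(x_j) \tag{18.38}$$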


where M is the rank of the kernel matrix. Thus we see that the entries in the kernel matrix can be 
computed by performing an inner product of some feature vectors that are implicitly defined by the 
eigenvectors of the kernel matrix. 
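The construction in Equations (18.37) and (18.38) can be carried out directly with an eigensolver. The sketch below builds the implicit feature vectors from a Gram matrix and checks that their inner products reproduce it. NumPy's `eigh` returns eigenvectors as columns, i.e. the transpose of $\mathbf{U}$ in the notation above. The rank cut-off is our own numerical choice.

```python
import numpy as np


def eigen_features(K, rtol=1e-12):
    """Rows are phi(x_i) = Lambda^{1/2} U_{:,i}, so that K ~= Phi Phi^T.

    numpy's eigh returns K = V diag(lam) V^T with eigenvectors as the columns of V,
    i.e. V = U^T in the notation above; we keep only numerically nonzero eigenvalues.
    """
    lam, V = np.linalg.eigh(0.5 * (K + K.T))
    lam = np.clip(lam, 0.0, None)                   # clip round-off negatives
    keep = lam > rtol * lam.max()
    return V[:, keep] * np.sqrt(lam[keep])


rng = np.random.default_rng(0)
X = rng.standard_normal((10, 2))
K_lin = X @ X.T
Phi = eigen_features(K_lin)
```

For the linear kernel on 2d inputs, only two eigenvalues are nonzero, so each point gets a 2-dimensional feature vector ($M = 2$ in Equation (18.38)).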

This idea can be generalized to apply to kernel functions, not just kernel matrices, as we now show.
First, we define an eigenfunction $\phi()$ of a kernel $K$ with eigenvalue $\lambda$ wrt measure $\mu$ as a function
that satisfies

$$\int K(x, x')\, \phi(x)\, d\mu(x) = \lambda \phi(x') \tag{18.39}$$

We usually sort the eigenfunctions in order of decreasing eigenvalue, $\lambda_1 \geq \lambda_2 \geq \cdots$. The eigenfunctions
are orthonormal wrt $\mu$:

$$\int \phi_i(x)\, \phi_j(x)\, d\mu(x) = \delta_{ij} \tag{18.40}$$

where $\delta_{ij}$ is the Kronecker delta. With this definition in hand, we can state Mercer’s theorem.
Informally, it says that any positive definite kernel function can be represented as the following
infinite sum:

$$K(x, x') = \sum_{m=1}^{\infty} \lambda_m\, \phi_m(x)\, \phi_m(x') \tag{18.41}$$

where $\phi_m$ are eigenfunctions of the kernel, and $\lambda_m$ are the corresponding eigenvalues. This is the
functional analog of Equation (18.38).


A degenerate kernel has only a finite number of non-zero eigenvalues. In this case, we can
rewrite the kernel function as an inner product between two finite-length vectors. For example,
consider the quadratic kernel $K(x, x') = (x^\top x')^2$ from Equation (18.24). If we define
$\phi(x_1, x_2) = [x_1^2, \sqrt{2} x_1 x_2, x_2^2] \in \mathbb{R}^3$, then we can write this as $K(x, x') = \phi(x)^\top \phi(x')$. Thus we see that this kernel
is degenerate.
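Degeneracy shows up directly in the rank of the Gram matrix. The sketch below (our own helper names, with an RBF length scale of 0.3 chosen for good conditioning) checks the 3-dimensional feature map for the quadratic kernel and compares numerical ranks. The quadratic kernel's Gram matrix never exceeds rank 3, however many points we use, while the RBF Gram matrix has no such limit.

```python
import numpy as np


def quadratic_feature_map(X):
    """phi(x1, x2) = [x1^2, sqrt(2) x1 x2, x2^2], so (x^T x')^2 = phi(x)^T phi(x')."""
    x1, x2 = X[:, 0], X[:, 1]
    return np.stack([x1 ** 2, np.sqrt(2.0) * x1 * x2, x2 ** 2], axis=1)


def numerical_rank(K, rtol=1e-10):
    """Number of eigenvalues above rtol times the largest one."""
    lam = np.linalg.eigvalsh(0.5 * (K + K.T))
    return int(np.sum(lam > rtol * lam.max()))


rng = np.random.default_rng(0)
X = rng.standard_normal((20, 2))
K_quad = (X @ X.T) ** 2
sq = np.sum(X ** 2, axis=1)
K_rbf = np.exp(-(sq[:, None] + sq[None, :] - 2.0 * X @ X.T) / (2.0 * 0.3 ** 2))
```

This is the finite-rank counterpart of the statement that the RBF kernel's feature representation is infinite dimensional.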


Now consider the RBF kernel. In this case, the corresponding feature representation is infinite
dimensional (see Section 18.2.6 for details). However, by working with kernel functions, we can avoid
having to deal with infinite dimensional vectors.


From the anove, we see that we can replace inner product operations in an explicit (possibly infinite 


39 dimensional) feature space with a call to a kernel function, i.e., we replace (x)' d(x) with K(x, 2’). 


20 This is called the kernel trick. 




18.2.6 Approximating kernels with random features


Although the power of kernels resides in the ability to avoid working with featurized representations of the inputs, such kernelized methods can take $O(N^3)$ time, in order to invert the Gram matrix $\mathbf{K}$, as we will see in Section 18.3. This can make it difficult to use such methods on large scale data. Fortunately, we can approximate the feature map for many kernels using a randomly chosen finite set of $M$ basis functions, thus reducing the cost to $O(NM + M^3)$.

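We will show how to do this for shift-invariant kernels by returning to Bochner's theorem in Eq. (18.22). In the case of a Gaussian RBF kernel, we have seen that the spectral density is a Gaussian distribution. Hence we can easily compute a Monte Carlo approximation to this integral by sampling random Gaussian vectors. This yields the following approximation: $\mathcal{K}(\mathbf{x}, \mathbf{x}') \approx \boldsymbol{\phi}(\mathbf{x})^\top \boldsymbol{\phi}(\mathbf{x}')$, where the (real-valued) feature vector is given by

$$\boldsymbol{\phi}(\mathbf{x}) = \sqrt{\frac{1}{M}} \left[\sin(\mathbf{z}_1^\top \mathbf{x}), \ldots, \sin(\mathbf{z}_M^\top \mathbf{x}), \cos(\mathbf{z}_1^\top \mathbf{x}), \ldots, \cos(\mathbf{z}_M^\top \mathbf{x})\right] \tag{18.42}$$

$$= \sqrt{\frac{1}{M}} \left[\sin(\mathbf{Z}^\top \mathbf{x}), \cos(\mathbf{Z}^\top \mathbf{x})\right] \tag{18.43}$$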


Here $\mathbf{Z} = (1/\sigma)\mathbf{G}$, and $\mathbf{G} \in \mathbb{R}^{D \times M}$ is a random Gaussian matrix, where the entries are sampled iid from $\mathcal{N}(0, 1)$. The representation in Equation (18.43) is called random Fourier features (RFF) [RR08] or "weighted sums of random kitchen sinks" [RR09]. (One can obtain an even better approximation by ensuring that the random vectors $\mathbf{z}_m$ are orthogonal; this is called orthogonal random features [Yu+16].)
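The Monte Carlo approximation in Equations (18.42) and (18.43) takes only a few lines. The sketch below assumes the RBF kernel $\exp(-\|\mathbf{x} - \mathbf{x}'\|^2 / (2\sigma^2))$, whose spectral density is $\mathcal{N}(\mathbf{0}, \sigma^{-2}\mathbf{I})$, which is why $\mathbf{Z} = \mathbf{G}/\sigma$; the sizes and random seed are arbitrary. The approximation error should shrink roughly like $1/\sqrt{M}$.

```python
import numpy as np


def rbf_kernel(X1, X2, sigma=1.0):
    # Exact RBF kernel exp(-||x - x'||^2 / (2 sigma^2)).
    sq = np.sum((X1[:, None, :] - X2[None, :, :]) ** 2, axis=-1)
    return np.exp(-0.5 * sq / sigma**2)


def random_fourier_features(X, Z):
    # Eq. (18.43): phi(x) = sqrt(1/M) [sin(Z^T x), cos(Z^T x)], one row per input.
    M = Z.shape[1]
    proj = X @ Z                                    # (N, M) entries z_m^T x_n
    return np.hstack([np.sin(proj), np.cos(proj)]) / np.sqrt(M)


rng = np.random.default_rng(0)
D, M, sigma = 3, 5000, 1.0
G = rng.standard_normal((D, M))                     # iid N(0, 1) entries
Z = G / sigma                                       # samples from the spectral density
X = rng.normal(size=(20, D))

Phi = random_fourier_features(X, Z)                 # (20, 2M)
K_approx = Phi @ Phi.T
K_exact = rbf_kernel(X, X, sigma)
print(np.max(np.abs(K_approx - K_exact)))           # small, and shrinks as M grows
```

Because $\sin^2 + \cos^2 = 1$, the diagonal of the approximate Gram matrix equals the exact value 1 for every $M$; only the off-diagonal entries carry Monte Carlo error.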

One can create similar random feature representations for other kinds of kernels. We can then use such features for supervised learning by defining $f(\mathbf{x}; \boldsymbol{\theta}) = \mathbf{W}\varphi(\mathbf{Z}^\top \mathbf{x}) + \mathbf{b}$, where $\mathbf{Z}$ is a random Gaussian matrix, and the form of $\varphi$ depends on the chosen kernel. This is equivalent to a one layer MLP with random input-to-hidden weights; since we only optimize the hidden-to-output weights $\boldsymbol{\theta} = (\mathbf{W}, \mathbf{b})$, the model is equivalent to a linear model with fixed random features. If we use enough random features, we can approximate the performance of a kernelized prediction model, but the computational cost is now $O(N)$ rather than $O(N^3)$.
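As a minimal sketch of the random-feature model $f(\mathbf{x}; \boldsymbol{\theta}) = \mathbf{W}\varphi(\mathbf{Z}^\top \mathbf{x}) + \mathbf{b}$ with a scalar output, the code below keeps $\mathbf{Z}$ fixed and fits only $(\mathbf{W}, \mathbf{b})$ by ridge regression. The toy target $\sin(x)$, the regularizer `lam`, and the choice to absorb $\mathbf{b}$ as a (regularized) constant feature are illustrative assumptions, not part of the text.

```python
import numpy as np

rng = np.random.default_rng(1)
N, M, sigma, lam = 100, 200, 1.0, 1e-3

X = np.linspace(-3.0, 3.0, N)[:, None]
y = np.sin(X[:, 0])                                 # toy noise-free target

Z = rng.standard_normal((1, M)) / sigma             # fixed random input-to-hidden weights


def features(X):
    # Random Fourier features plus a constant column that plays the role of b.
    proj = X @ Z
    Phi = np.hstack([np.sin(proj), np.cos(proj)]) / np.sqrt(M)
    return np.hstack([Phi, np.ones((X.shape[0], 1))])


Phi = features(X)                                   # (N, 2M + 1)
# Only theta = (W, b) is learned: a linear ridge problem whose cost grows linearly in N.
A = Phi.T @ Phi + lam * np.eye(Phi.shape[1])
theta = np.linalg.solve(A, Phi.T @ y)

y_hat = features(X) @ theta
print(np.mean((y - y_hat) ** 2))                    # training MSE
```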

Unfortunately, random features can result in worse performance than using a non-degenerate 
kernel, since they don’t have enough expressive power. We discuss other ways to scale GPs to large 
datasets in Section 18.5. 


18.3 GPs with Gaussian likelihoods 


In this section, we discuss GPs for regression, using a Gaussian likelihood. In this case, all the 
computations can be performed in closed form, using standard linear algebra methods. We extend 
this framework to non-Gaussian likelihoods later in the chapter. 


18.3.1 Predictions using noise-free observations 


Suppose we observe a training set $\mathcal{D} = \{(\mathbf{x}_n, y_n) : n = 1 : N\}$, where $y_n = f(\mathbf{x}_n)$ is the noise-free observation of the function evaluated at $\mathbf{x}_n$. If we ask the GP to predict $f(\mathbf{x})$ for a value of $\mathbf{x}$ that it has already seen, we want the GP to return the answer $f(\mathbf{x})$ with no uncertainty. In other words, it should act as an interpolator of the training data. Here we assume the observed function values are noiseless. We will consider the case of noisy observations shortly.

Now we consider the case of predicting the outputs for new inputs that may not be in $\mathcal{D}$. Specifically, given a test set $\mathbf{X}_*$ of size $N_* \times D$, we want to predict the function outputs $\mathbf{f}_* = [f(\mathbf{x}^*_1), \ldots, f(\mathbf{x}^*_{N_*})]$.
Figure 18.7: Left: some functions sampled from a GP prior with RBF kernel. Middle: some samples from a GP posterior, after conditioning on 5 noise-free observations. Right: some samples from a GP posterior, after conditioning on 5 noisy observations. The shaded area represents $\mathbb{E}[f(\mathbf{x})] \pm 2\sqrt{\mathbb{V}[f(\mathbf{x})]}$. Adapted from Figure 2.2 of [RW06]. Generated by gpr_demo_noise_free.ipynb.


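By definition of the GP, the joint distribution $p(\mathbf{f}_X, \mathbf{f}_* | \mathbf{X}, \mathbf{X}_*)$ has the following form

$$\begin{pmatrix} \mathbf{f}_X \\ \mathbf{f}_* \end{pmatrix} \sim \mathcal{N}\left( \begin{pmatrix} \boldsymbol{\mu}_X \\ \boldsymbol{\mu}_* \end{pmatrix}, \begin{pmatrix} \mathbf{K}_{X,X} & \mathbf{K}_{X,*} \\ \mathbf{K}_{X,*}^\top & \mathbf{K}_{*,*} \end{pmatrix} \right) \tag{18.44}$$

where $\boldsymbol{\mu}_X = (m(\mathbf{x}_1), \ldots, m(\mathbf{x}_N))$, $\boldsymbol{\mu}_* = (m(\mathbf{x}^*_1), \ldots, m(\mathbf{x}^*_{N_*}))$, $\mathbf{K}_{X,X} = \mathcal{K}(\mathbf{X}, \mathbf{X})$ is $N \times N$, $\mathbf{K}_{X,*} = \mathcal{K}(\mathbf{X}, \mathbf{X}_*)$ is $N \times N_*$, and $\mathbf{K}_{*,*} = \mathcal{K}(\mathbf{X}_*, \mathbf{X}_*)$ is $N_* \times N_*$. See Figure 18.7 for a static illustration, and http://www.infinitecuriosity.org/vizgp/ for an interactive visualization.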

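By the standard rules for conditioning Gaussians (Section 2.2.5.4), the posterior has the following form

$$p(\mathbf{f}_* | \mathbf{X}_*, \mathcal{D}) = \mathcal{N}(\mathbf{f}_* | \boldsymbol{\mu}_{*|X}, \boldsymbol{\Sigma}_{*|X}) \tag{18.45}$$

$$\boldsymbol{\mu}_{*|X} = \boldsymbol{\mu}_* + \mathbf{K}_{X,*}^\top \mathbf{K}_{X,X}^{-1} (\mathbf{f}_X - \boldsymbol{\mu}_X) \tag{18.46}$$

$$\boldsymbol{\Sigma}_{*|X} = \mathbf{K}_{*,*} - \mathbf{K}_{X,*}^\top \mathbf{K}_{X,X}^{-1} \mathbf{K}_{X,*} \tag{18.47}$$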


This process is illustrated in Figure 18.7. On the left we show some samples from the prior, $p(f)$, where we use an RBF kernel (Section 18.2.1.1) and a zero mean function. In the middle, we show samples from the posterior, $p(f | \mathcal{D})$. We see that the model perfectly interpolates the training data, and that the predictive uncertainty increases as we move further away from the observed data.

Note that the cost of the above method for sampling $N_*$ points is $O(N_*^3)$. This can be reduced to $O(N_*)$ time using the methods in [Ple+18; Wil+20a].
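The noise-free posterior in Equations (18.46) and (18.47) can be sketched as follows, assuming a zero mean function and an RBF kernel with unit length scale (illustrative choices). A tiny diagonal jitter is added before solving; this is a common numerical safeguard, not part of the equations. The checks confirm that the posterior interpolates the training data and reverts to the prior far away from it.

```python
import numpy as np


def rbf_kernel(X1, X2, ell=1.0):
    sq = np.sum((X1[:, None, :] - X2[None, :, :]) ** 2, axis=-1)
    return np.exp(-0.5 * sq / ell**2)


def gp_posterior_noise_free(X, f_X, X_star, ell=1.0, jitter=1e-10):
    # Eqs. (18.46)-(18.47) with m(x) = 0.
    K_XX = rbf_kernel(X, X, ell) + jitter * np.eye(len(X))
    K_Xs = rbf_kernel(X, X_star, ell)
    K_ss = rbf_kernel(X_star, X_star, ell)
    mu = K_Xs.T @ np.linalg.solve(K_XX, f_X)
    Sigma = K_ss - K_Xs.T @ np.linalg.solve(K_XX, K_Xs)
    return mu, Sigma


X = np.array([[-4.0], [-3.0], [-1.0], [0.0], [2.0]])   # 5 toy training inputs
f_X = np.sin(X[:, 0])

# At the training inputs the posterior reproduces f_X with (almost) zero variance.
mu_train, Sigma_train = gp_posterior_noise_free(X, f_X, X)
print(np.allclose(mu_train, f_X), np.max(np.abs(np.diag(Sigma_train))))

# Far from the data the posterior falls back to the prior N(0, 1).
mu_far, Sigma_far = gp_posterior_noise_free(X, f_X, np.array([[10.0]]))
print(mu_far, Sigma_far)
```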




18.3.2 Predictions using noisy observations 


In Section 18.3.1, we showed how to do GP regression when the training data was noiseless. Now let us consider the case where what we observe is a noisy version of the underlying function, $y_n = f(\mathbf{x}_n) + \epsilon_n$, where $\epsilon_n \sim \mathcal{N}(0, \sigma_y^2)$. In this case, the model is not required to interpolate the data, but it must come "close" to the observed data. The covariance of the observed noisy responses is

$$\mathrm{Cov}[y_i, y_j] = \mathrm{Cov}[f_i, f_j] + \mathrm{Cov}[\epsilon_i, \epsilon_j] = \mathcal{K}(\mathbf{x}_i, \mathbf{x}_j) + \sigma_y^2 \delta_{ij} \tag{18.48}$$

where $\delta_{ij} = \mathbb{I}(i = j)$. In other words

$$\mathrm{Cov}[\mathbf{y} | \mathbf{X}] = \mathbf{K}_{X,X} + \sigma_y^2 \mathbf{I}_N \tag{18.49}$$


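The joint density of the observed data and the latent, noise-free function on the test points is given by

$$\begin{pmatrix} \mathbf{y} \\ \mathbf{f}_* \end{pmatrix} \sim \mathcal{N}\left( \begin{pmatrix} \boldsymbol{\mu}_X \\ \boldsymbol{\mu}_* \end{pmatrix}, \begin{pmatrix} \mathbf{K}_{X,X} + \sigma_y^2 \mathbf{I} & \mathbf{K}_{X,*} \\ \mathbf{K}_{X,*}^\top & \mathbf{K}_{*,*} \end{pmatrix} \right) \tag{18.50}$$

Hence the posterior predictive density at a set of test points $\mathbf{X}_*$ is

$$p(\mathbf{f}_* | \mathcal{D}, \mathbf{X}_*) = \mathcal{N}(\mathbf{f}_* | \boldsymbol{\mu}_{*|X}, \boldsymbol{\Sigma}_{*|X}) \tag{18.51}$$

$$\boldsymbol{\mu}_{*|X} = \boldsymbol{\mu}_* + \mathbf{K}_{X,*}^\top (\mathbf{K}_{X,X} + \sigma_y^2 \mathbf{I})^{-1} (\mathbf{y} - \boldsymbol{\mu}_X) \tag{18.52}$$

$$\boldsymbol{\Sigma}_{*|X} = \mathbf{K}_{*,*} - \mathbf{K}_{X,*}^\top (\mathbf{K}_{X,X} + \sigma_y^2 \mathbf{I})^{-1} \mathbf{K}_{X,*} \tag{18.53}$$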


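In the case of a single test input, this simplifies as follows

$$p(f_* | \mathcal{D}, \mathbf{x}_*) = \mathcal{N}\left(f_* \,\middle|\, m(\mathbf{x}_*) + \mathbf{k}_*^\top (\mathbf{K}_{X,X} + \sigma_y^2 \mathbf{I})^{-1} (\mathbf{y} - \boldsymbol{\mu}_X),\; k_{**} - \mathbf{k}_*^\top (\mathbf{K}_{X,X} + \sigma_y^2 \mathbf{I})^{-1} \mathbf{k}_*\right) \tag{18.54}$$

where $\mathbf{k}_* = [\mathcal{K}(\mathbf{x}_*, \mathbf{x}_1), \ldots, \mathcal{K}(\mathbf{x}_*, \mathbf{x}_N)]$ and $k_{**} = \mathcal{K}(\mathbf{x}_*, \mathbf{x}_*)$. If the mean function is zero, we can write the posterior mean as follows:

$$\mu_{*|X} = \mathbf{k}_*^\top \mathbf{K}_\sigma^{-1} \mathbf{y} = \sum_{n=1}^{N} \mathcal{K}(\mathbf{x}_*, \mathbf{x}_n) \alpha_n \tag{18.55}$$

where

$$\mathbf{K}_\sigma = \mathbf{K}_{X,X} + \sigma_y^2 \mathbf{I}_N \tag{18.56}$$

$$\boldsymbol{\alpha} = \mathbf{K}_\sigma^{-1} \mathbf{y} \tag{18.57}$$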


Fitting this model amounts to computing $\boldsymbol{\alpha}$ in Equation (18.57). This is usually done by computing the Cholesky decomposition of $\mathbf{K}_\sigma$, as described in Section 18.3.6. Once we have computed $\boldsymbol{\alpha}$, we can compute predictions for each test point in $O(N)$ time for the mean, and $O(N^2)$ time for the variance.
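A minimal sketch of noisy GP regression with a zero mean function: $\boldsymbol{\alpha}$ is computed once (Equation (18.57)), after which the predictive mean (Equation (18.55)) and variance (the diagonal of Equation (18.53)) follow. The RBF kernel, noise level and toy data are assumptions for illustration. For clarity this version calls a generic linear solver; the Cholesky route of Section 18.3.6 is the more robust choice in practice.

```python
import numpy as np


def rbf_kernel(X1, X2, ell=1.0):
    sq = np.sum((X1[:, None, :] - X2[None, :, :]) ** 2, axis=-1)
    return np.exp(-0.5 * sq / ell**2)


rng = np.random.default_rng(0)
N, sigma_y = 20, 0.1
X = rng.uniform(-5.0, 5.0, size=(N, 1))
y = np.sin(X[:, 0]) + sigma_y * rng.standard_normal(N)
X_star = np.linspace(-5.0, 5.0, 7)[:, None]

K_sigma = rbf_kernel(X, X) + sigma_y**2 * np.eye(N)    # Eq. (18.56)
alpha = np.linalg.solve(K_sigma, y)                     # Eq. (18.57), done once

K_Xs = rbf_kernel(X, X_star)                            # column j is k_* for test point j
mu_star = K_Xs.T @ alpha                                # Eq. (18.55), O(N) per test point
# Diagonal of Eq. (18.53); k_** = 1 for this kernel.
var_star = 1.0 - np.sum(K_Xs * np.linalg.solve(K_sigma, K_Xs), axis=0)
print(np.round(mu_star, 2))
print(np.round(var_star, 4))
```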


18.3.3 Weight space vs function space 


In this section, we show how Bayesian linear regression is a special case of a GP. 

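Consider the linear regression model $y = f(\mathbf{x}) + \epsilon$, where $f(\mathbf{x}) = \mathbf{w}^\top \boldsymbol{\phi}(\mathbf{x})$ and $\epsilon \sim \mathcal{N}(0, \sigma_y^2)$. If we use a Gaussian prior $p(\mathbf{w}) = \mathcal{N}(\mathbf{w} | \mathbf{0}, \boldsymbol{\Sigma}_w)$, then the posterior is as follows (see Section 15.2.1 for the derivation):

$$p(\mathbf{w} | \mathcal{D}) = \mathcal{N}\left(\mathbf{w} \,\middle|\, \frac{1}{\sigma_y^2} \mathbf{A}^{-1} \boldsymbol{\Phi}^\top \mathbf{y}, \mathbf{A}^{-1}\right) \tag{18.58}$$

where $\boldsymbol{\Phi}$ is the $N \times D$ design matrix, and

$$\mathbf{A} = \sigma_y^{-2} \boldsymbol{\Phi}^\top \boldsymbol{\Phi} + \boldsymbol{\Sigma}_w^{-1} \tag{18.59}$$

The posterior predictive distribution for $f_* = f(\mathbf{x}_*)$ is therefore

$$p(f_* | \mathcal{D}, \mathbf{x}_*) = \mathcal{N}\left(f_* \,\middle|\, \frac{1}{\sigma_y^2} \boldsymbol{\phi}_*^\top \mathbf{A}^{-1} \boldsymbol{\Phi}^\top \mathbf{y},\; \boldsymbol{\phi}_*^\top \mathbf{A}^{-1} \boldsymbol{\phi}_*\right) \tag{18.60}$$

where $\boldsymbol{\phi}_* = \boldsymbol{\phi}(\mathbf{x}_*)$. This views the problem of inference and prediction in weight space.

We now show that this is equivalent to the predictions made by a GP using a kernel of the form $\mathcal{K}(\mathbf{x}, \mathbf{x}') = \boldsymbol{\phi}(\mathbf{x})^\top \boldsymbol{\Sigma}_w \boldsymbol{\phi}(\mathbf{x}')$. To see this, let $\mathbf{K} = \boldsymbol{\Phi} \boldsymbol{\Sigma}_w \boldsymbol{\Phi}^\top$, $\mathbf{k}_* = \boldsymbol{\Phi} \boldsymbol{\Sigma}_w \boldsymbol{\phi}_*$, and $k_{**} = \boldsymbol{\phi}_*^\top \boldsymbol{\Sigma}_w \boldsymbol{\phi}_*$. Using this notation, and the matrix inversion lemma, we can rewrite Equation (18.60) as follows

$$\mu_{*|X} = \boldsymbol{\phi}_*^\top \boldsymbol{\Sigma}_w \boldsymbol{\Phi}^\top (\mathbf{K} + \sigma_y^2 \mathbf{I})^{-1} \mathbf{y} = \mathbf{k}_*^\top (\mathbf{K}_{X,X} + \sigma_y^2 \mathbf{I})^{-1} \mathbf{y} \tag{18.62}$$

$$\Sigma_{*|X} = \boldsymbol{\phi}_*^\top \boldsymbol{\Sigma}_w \boldsymbol{\phi}_* - \boldsymbol{\phi}_*^\top \boldsymbol{\Sigma}_w \boldsymbol{\Phi}^\top (\mathbf{K} + \sigma_y^2 \mathbf{I})^{-1} \boldsymbol{\Phi} \boldsymbol{\Sigma}_w \boldsymbol{\phi}_* = k_{**} - \mathbf{k}_*^\top (\mathbf{K}_{X,X} + \sigma_y^2 \mathbf{I})^{-1} \mathbf{k}_* \tag{18.63}$$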

which matches the results in Equation (18.54), assuming $m(\mathbf{x}) = 0$. A non-zero mean can be captured by adding a constant feature with value 1 to $\boldsymbol{\phi}(\mathbf{x})$.

Thus we can derive a GP from Bayesian linear regression. Note, however, that linear regression assumes $\boldsymbol{\phi}(\mathbf{x})$ is a finite length vector, whereas a GP allows us to work directly in terms of kernels, which may correspond to infinite length feature vectors (see Section 18.2.5). That is, a GP works in function space.
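The equivalence can be confirmed numerically. The sketch below computes the predictive mean and variance once in weight space, Equations (18.59) and (18.60), and once in function space with the kernel $\boldsymbol{\phi}(\mathbf{x})^\top \boldsymbol{\Sigma}_w \boldsymbol{\phi}(\mathbf{x}')$, Equations (18.62) and (18.63). The random design matrix, prior covariance and noise level are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D, sigma_y = 15, 4, 0.3
Phi = rng.normal(size=(N, D))                 # design matrix, row n is phi(x_n)^T
y = rng.normal(size=N)
phi_star = rng.normal(size=D)                 # phi(x_*)
Sigma_w = np.diag([1.0, 0.5, 2.0, 1.5])       # prior covariance of w

# Weight space, Eqs. (18.59)-(18.60).
A = Phi.T @ Phi / sigma_y**2 + np.linalg.inv(Sigma_w)
mean_w = phi_star @ np.linalg.solve(A, Phi.T @ y) / sigma_y**2
var_w = phi_star @ np.linalg.solve(A, phi_star)

# Function space with K(x, x') = phi(x)^T Sigma_w phi(x'), Eqs. (18.62)-(18.63).
K = Phi @ Sigma_w @ Phi.T
k_star = Phi @ Sigma_w @ phi_star
k_ss = phi_star @ Sigma_w @ phi_star
K_sigma = K + sigma_y**2 * np.eye(N)
mean_f = k_star @ np.linalg.solve(K_sigma, y)
var_f = k_ss - k_star @ np.linalg.solve(K_sigma, k_star)

print(np.isclose(mean_w, mean_f), np.isclose(var_w, var_f))
```

The weight-space route solves a $D \times D$ system and the function-space route an $N \times N$ one, so the cheaper view depends on whether $D < N$.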


18.3.4 Semi-parametric GPs 


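So far, we have mostly assumed the mean of the GP is 0, and have relied on its interpolation abilities to model the mean function. Sometimes it is useful to fit a global linear model for the mean, and use the GP to model the residual errors, as follows:

$$g(\mathbf{x}) = f(\mathbf{x}) + \boldsymbol{\beta}^\top \boldsymbol{\phi}(\mathbf{x}) \tag{18.64}$$

where $f(\mathbf{x}) \sim \mathcal{GP}(0, \mathcal{K}(\mathbf{x}, \mathbf{x}'))$, and $\boldsymbol{\phi}(\cdot)$ are some fixed basis functions. This combines a parametric and a non-parametric model, and is known as a semi-parametric model.

If we assume $\boldsymbol{\beta} \sim \mathcal{N}(\mathbf{b}, \mathbf{B})$, we can integrate these parameters out to get a new GP [O'H78]:

$$g(\mathbf{x}) \sim \mathcal{GP}\left(\boldsymbol{\phi}(\mathbf{x})^\top \mathbf{b},\; \mathcal{K}(\mathbf{x}, \mathbf{x}') + \boldsymbol{\phi}(\mathbf{x})^\top \mathbf{B} \boldsymbol{\phi}(\mathbf{x}')\right) \tag{18.65}$$

Let $\mathbf{H}_X = \boldsymbol{\phi}(\mathbf{X})^\top$ be the $D \times N$ matrix of training examples, and $\mathbf{H}_* = \boldsymbol{\phi}(\mathbf{X}_*)^\top$ be the $D \times N_*$ matrix of test examples. The corresponding predictive distribution for test inputs $\mathbf{X}_*$ has the following form [RW06, p28]: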

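$$\mathbb{E}[g(\mathbf{X}_*) | \mathcal{D}] = \mathbf{H}_*^\top \bar{\boldsymbol{\beta}} + \mathbf{K}_{X,*}^\top \mathbf{K}_\sigma^{-1} (\mathbf{y} - \mathbf{H}_X^\top \bar{\boldsymbol{\beta}}) = \mathbb{E}[f(\mathbf{X}_*) | \mathcal{D}] + \mathbf{R}^\top \bar{\boldsymbol{\beta}} \tag{18.66}$$

$$\mathrm{Cov}[g(\mathbf{X}_*) | \mathcal{D}] = \mathrm{Cov}[f(\mathbf{X}_*) | \mathcal{D}] + \mathbf{R}^\top (\mathbf{B}^{-1} + \mathbf{H}_X \mathbf{K}_\sigma^{-1} \mathbf{H}_X^\top)^{-1} \mathbf{R} \tag{18.67}$$

where

$$\bar{\boldsymbol{\beta}} = (\mathbf{B}^{-1} + \mathbf{H}_X \mathbf{K}_\sigma^{-1} \mathbf{H}_X^\top)^{-1} (\mathbf{H}_X \mathbf{K}_\sigma^{-1} \mathbf{y} + \mathbf{B}^{-1} \mathbf{b}) \tag{18.68}$$

$$\mathbf{R} = \mathbf{H}_* - \mathbf{H}_X \mathbf{K}_\sigma^{-1} \mathbf{K}_{X,*} \tag{18.69}$$

These results can be interpreted as follows: the mean is the usual mean from the GP, plus a global offset from the linear model, using $\bar{\boldsymbol{\beta}}$; and the covariance is the usual covariance from the GP, plus an additional positive term due to the uncertainty in $\boldsymbol{\beta}$.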

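In the limit of an uninformative prior for the regression parameters, as $\mathbf{B} \to \infty \mathbf{I}$ (so that $\mathbf{B}^{-1} \to \mathbf{0}$), this simplifies to

$$\mathbb{E}[g(\mathbf{X}_*) | \mathcal{D}] = \mathbb{E}[f(\mathbf{X}_*) | \mathcal{D}] + \mathbf{R}^\top (\mathbf{H}_X \mathbf{K}_\sigma^{-1} \mathbf{H}_X^\top)^{-1} \mathbf{H}_X \mathbf{K}_\sigma^{-1} \mathbf{y} \tag{18.70}$$

$$\mathrm{Cov}[g(\mathbf{X}_*) | \mathcal{D}] = \mathrm{Cov}[f(\mathbf{X}_*) | \mathcal{D}] + \mathbf{R}^\top (\mathbf{H}_X \mathbf{K}_\sigma^{-1} \mathbf{H}_X^\top)^{-1} \mathbf{R} \tag{18.71}$$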
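One way to check Equations (18.66) to (18.69) is to compare them against plain GP conditioning under the integrated-out prior of Equation (18.65), in which $\boldsymbol{\beta}$ has been absorbed into the mean and the kernel. The basis $\boldsymbol{\phi}(x) = [1, x]$, the prior $(\mathbf{b}, \mathbf{B})$ and the toy data below are illustrative assumptions; `mean_f` and `cov_f` are the zero-mean GP posterior for the latent $f$.

```python
import numpy as np


def rbf_kernel(x1, x2, ell=1.0):
    return np.exp(-0.5 * (x1[:, None] - x2[None, :]) ** 2 / ell**2)


def basis(x):
    # Fixed basis phi(x) = [1, x]; returns H = phi(x)^T with shape (2, n).
    return np.vstack([np.ones_like(x), x])


rng = np.random.default_rng(0)
N, N_star, sigma_y = 12, 5, 0.2
x = rng.uniform(-3.0, 3.0, size=N)
x_star = rng.uniform(-3.0, 3.0, size=N_star)
y = 0.5 + 2.0 * x + np.sin(x) + sigma_y * rng.standard_normal(N)

H_X, H_s = basis(x), basis(x_star)
b, B = np.array([0.3, -0.2]), 4.0 * np.eye(2)       # prior beta ~ N(b, B)

K_sigma = rbf_kernel(x, x) + sigma_y**2 * np.eye(N)
K_Xs, K_ss = rbf_kernel(x, x_star), rbf_kernel(x_star, x_star)

# Zero-mean GP posterior for the latent f.
mean_f = K_Xs.T @ np.linalg.solve(K_sigma, y)
cov_f = K_ss - K_Xs.T @ np.linalg.solve(K_sigma, K_Xs)

# Eqs. (18.66)-(18.69).
B_inv = np.linalg.inv(B)
P = B_inv + H_X @ np.linalg.solve(K_sigma, H_X.T)
beta_bar = np.linalg.solve(P, H_X @ np.linalg.solve(K_sigma, y) + B_inv @ b)
R = H_s - H_X @ np.linalg.solve(K_sigma, K_Xs)
mean_g = mean_f + R.T @ beta_bar
cov_g = cov_f + R.T @ np.linalg.solve(P, R)

# Reference: condition directly on the GP of Eq. (18.65), plus observation noise.
Kg_XX = K_sigma + H_X.T @ B @ H_X
Kg_Xs = K_Xs + H_X.T @ B @ H_s
Kg_ss = K_ss + H_s.T @ B @ H_s
mean_ref = H_s.T @ b + Kg_Xs.T @ np.linalg.solve(Kg_XX, y - H_X.T @ b)
cov_ref = Kg_ss - Kg_Xs.T @ np.linalg.solve(Kg_XX, Kg_Xs)

print(np.allclose(mean_g, mean_ref), np.allclose(cov_g, cov_ref))
```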


18.3.5 Marginal likelihood 


Most kernels have some free parameters. For example, the RBF-ARD kernel (Section 18.2.1.2) has the form

$$\mathcal{K}(\mathbf{x}, \mathbf{x}') = \exp\left(-\frac{1}{2} \sum_{d=1}^{D} \frac{1}{\ell_d^2} (x_d - x'_d)^2\right) = \prod_{d=1}^{D} \mathcal{K}_{\ell_d}(x_d, x'_d) \tag{18.72}$$

where each $\ell_d$ is a length scale for feature dimension $d$. Let these (and the observation noise variance $\sigma_y^2$, if present) be denoted by $\boldsymbol{\theta}$. We can compute the likelihood of these parameters as follows:


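$$p(\mathbf{y} | \mathbf{X}, \boldsymbol{\theta}) = p(\mathcal{D} | \boldsymbol{\theta}) = \int p(\mathbf{y} | \mathbf{f}_X, \boldsymbol{\theta}) \, p(\mathbf{f}_X | \mathbf{X}, \boldsymbol{\theta}) \, d\mathbf{f}_X \tag{18.73}$$

Since we are integrating out the function $f$, we often call $\boldsymbol{\theta}$ hyperparameters, and the quantity $p(\mathcal{D} | \boldsymbol{\theta})$ the marginal likelihood.

Since $f$ is a GP, we can compute the above integral using the marginal likelihood for the corresponding Gaussian. This gives

$$\log p(\mathcal{D} | \boldsymbol{\theta}) = -\frac{1}{2} (\mathbf{y} - \boldsymbol{\mu}_X)^\top \mathbf{K}_\sigma^{-1} (\mathbf{y} - \boldsymbol{\mu}_X) - \frac{1}{2} \log |\mathbf{K}_\sigma| - \frac{N}{2} \log(2\pi) \tag{18.74}$$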


The first term is minus one half times the squared Mahalanobis distance between the observations and the predicted values: better fits will have smaller distance. The second term is minus one half times the log determinant of the covariance matrix, which measures model complexity: smoother functions will have smaller determinants, so $-\log |\mathbf{K}_\sigma|$ will be larger (less negative) for simpler functions. The marginal likelihood measures the tradeoff between fit and complexity.
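The two terms can be computed separately to see the tradeoff. This is a minimal NumPy sketch of Equation (18.74) with a zero mean and the RBF-ARD kernel of Equation (18.72); the data, noise level and length scales are made up for illustration. One would expect longer length scales (smoother functions) to give a smaller $\log |\mathbf{K}_\sigma|$, and hence a larger complexity term.

```python
import numpy as np


def rbf_ard_kernel(X1, X2, ell):
    # Eq. (18.72): one length scale ell[d] per input dimension.
    diff = (X1[:, None, :] - X2[None, :, :]) / ell
    return np.exp(-0.5 * np.sum(diff**2, axis=-1))


def log_marginal_likelihood_terms(X, y, ell, sigma_y):
    # Eq. (18.74) with mu_X = 0, split into data fit, complexity and constant.
    N = len(y)
    K_sigma = rbf_ard_kernel(X, X, ell) + sigma_y**2 * np.eye(N)
    data_fit = -0.5 * y @ np.linalg.solve(K_sigma, y)
    _, logdet = np.linalg.slogdet(K_sigma)       # safer than np.log(np.linalg.det(.))
    complexity = -0.5 * logdet
    const = -0.5 * N * np.log(2.0 * np.pi)
    return data_fit, complexity, const


rng = np.random.default_rng(0)
X = rng.uniform(-3.0, 3.0, size=(30, 2))
y = np.sin(X[:, 0]) + 0.1 * rng.standard_normal(30)   # only the first input matters

short_ls = log_marginal_likelihood_terms(X, y, np.array([0.3, 0.3]), sigma_y=0.1)
long_ls = log_marginal_likelihood_terms(X, y, np.array([1.0, 10.0]), sigma_y=0.1)
for name, terms in [("short", short_ls), ("long", long_ls)]:
    print(name, [round(float(t), 2) for t in terms], round(float(sum(terms)), 2))
```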

In Section 18.6.1, we discuss how to learn the kernel parameters from data by maximizing the marginal likelihood wrt $\boldsymbol{\theta}$.


18.3.6 Computational and numerical issues 


In this section, we discuss computational and numerical issues which arise when implementing the 
above equations. For notational simplicity, we assume the prior mean is zero, m(x) = 0. 

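The posterior predictive mean is given by $\mu_* = \mathbf{k}_*^\top \mathbf{K}_\sigma^{-1} \mathbf{y}$. For reasons of numerical stability, it is unwise to directly invert $\mathbf{K}_\sigma$. A more robust alternative is to compute a Cholesky decomposition, $\mathbf{K}_\sigma = \mathbf{L}\mathbf{L}^\top$, which takes $O(N^3)$ time. Given this, we can compute

$$\mu_* = \mathbf{k}_*^\top \mathbf{K}_\sigma^{-1} \mathbf{y} = \mathbf{k}_*^\top \mathbf{L}^{-\top} (\mathbf{L}^{-1} \mathbf{y}) = \mathbf{k}_*^\top \boldsymbol{\alpha} \tag{18.75}$$

Here $\boldsymbol{\alpha} = \mathbf{L}^\top \backslash (\mathbf{L} \backslash \mathbf{y})$, where we have used the backslash operator to represent backsubstitution. We can compute the variance in $O(N^2)$ time for each test case using

$$\sigma_*^2 = k_{**} - \mathbf{k}_*^\top \mathbf{L}^{-\top} \mathbf{L}^{-1} \mathbf{k}_* = k_{**} - \mathbf{v}^\top \mathbf{v} \tag{18.76}$$

where $\mathbf{v} = \mathbf{L} \backslash \mathbf{k}_*$.

Finally, the log marginal likelihood (needed for kernel learning, Section 18.6) can be computed using

$$\log p(\mathbf{y} | \mathbf{X}) = -\frac{1}{2} \mathbf{y}^\top \boldsymbol{\alpha} - \sum_{n=1}^{N} \log L_{nn} - \frac{N}{2} \log(2\pi) \tag{18.77}$$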


We see that the overall cost is dominated by $O(N^3)$. We discuss faster, but approximate, methods in Section 18.5.
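The Cholesky-based recipe of Equations (18.75) to (18.77) is sketched below. The forward and back substitution helpers are written out only to make the backslash operator explicit; in practice one would call a library triangular solver such as `scipy.linalg.solve_triangular`. The kernel (RBF, unit length scale), noise level and toy data are illustrative assumptions.

```python
import numpy as np


def forward_sub(L, b):
    # Solve L x = b for lower-triangular L (the text's L \ b), in O(N^2).
    x = np.zeros(len(b))
    for i in range(len(b)):
        x[i] = (b[i] - L[i, :i] @ x[:i]) / L[i, i]
    return x


def back_sub(U, b):
    # Solve U x = b for upper-triangular U (here U = L^T), in O(N^2).
    n = len(b)
    x = np.zeros(n)
    for i in range(n - 1, -1, -1):
        x[i] = (b[i] - U[i, i + 1:] @ x[i + 1:]) / U[i, i]
    return x


def gp_predict_cholesky(K, y, k_star, k_ss, sigma_y):
    N = len(y)
    L = np.linalg.cholesky(K + sigma_y**2 * np.eye(N))    # K_sigma = L L^T, O(N^3)
    alpha = back_sub(L.T, forward_sub(L, y))              # alpha = L^T \ (L \ y)
    mu = k_star @ alpha                                   # Eq. (18.75)
    v = forward_sub(L, k_star)                            # v = L \ k_*
    var = k_ss - v @ v                                    # Eq. (18.76)
    log_ml = (-0.5 * y @ alpha - np.sum(np.log(np.diag(L)))
              - 0.5 * N * np.log(2.0 * np.pi))            # Eq. (18.77)
    return mu, var, log_ml


rng = np.random.default_rng(0)
X = rng.uniform(-3.0, 3.0, size=12)
y = np.cos(X) + 0.1 * rng.standard_normal(12)
x_star, sigma_y = 0.5, 0.1
K = np.exp(-0.5 * (X[:, None] - X[None, :]) ** 2)        # RBF kernel, unit length scale
k_star = np.exp(-0.5 * (X - x_star) ** 2)
k_ss = 1.0
mu, var, log_ml = gp_predict_cholesky(K, y, k_star, k_ss, sigma_y)
print(mu, var, log_ml)
```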


18.3.7 Kernel ridge regression 
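The term ridge regression refers to linear regression with an $\ell_2$ penalty on the regression weights:

$$\mathbf{w}^* = \operatorname*{argmin}_{\mathbf{w}} \sum_{n=1}^{N} (y_n - f(\mathbf{x}_n; \mathbf{w}))^2 + \lambda \|\mathbf{w}\|_2^2 \tag{18.78}$$

where $f(\mathbf{x}; \mathbf{w}) = \mathbf{w}^\top \mathbf{x}$. The solution for this is

$$\mathbf{w}^* = (\mathbf{X}^\top \mathbf{X} + \lambda \mathbf{I})^{-1} \mathbf{X}^\top \mathbf{y} = \left(\sum_{n=1}^{N} \mathbf{x}_n \mathbf{x}_n^\top + \lambda \mathbf{I}\right)^{-1} \left(\sum_{n=1}^{N} \mathbf{x}_n y_n\right) \tag{18.79}$$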
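Both forms of Equation (18.79) give the same vector, and at that vector the gradient of the objective in Equation (18.78) vanishes. A minimal NumPy check, using a made-up dataset and $\lambda = 0.5$:

```python
import numpy as np

rng = np.random.default_rng(0)
N, D, lam = 50, 3, 0.5
X = rng.normal(size=(N, D))
y = X @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.standard_normal(N)

# Matrix form of Eq. (18.79).
w_matrix = np.linalg.solve(X.T @ X + lam * np.eye(D), X.T @ y)

# Sum-over-examples form of Eq. (18.79).
S = sum(np.outer(x_n, x_n) for x_n in X) + lam * np.eye(D)
r = sum(x_n * y_n for x_n, y_n in zip(X, y))
w_sum = np.linalg.solve(S, r)

# Gradient of the objective in Eq. (18.78): -2 X^T (y - X w) + 2 lam w.
grad = -2.0 * X.T @ (y - X @ w_matrix) + 2.0 * lam * w_matrix
print(w_matrix, np.abs(grad).max())
```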
In this section, we consider a function space version of this:

$$f^* = \operatorname*{argmin}_{f \in \mathcal{F}} \sum_{n=1}^{N} (y_n - f(\mathbf{x}_n))^2 + \lambda \|f\|^2 \tag{18.80}$$

For this to make sense, we have to define the function space $\mathcal{F}$ and the norm $\|f\|$. If we use a function space derived from a positive definite kernel function $\mathcal{K}$, the resulting method is called kernel ridge regression (KRR). We will see that the resulting estimate $f^*(\mathbf{x}_*)$ is equivalent to the posterior mean of a GP. We give the details below.


18.3.7.1 Reproducing kernel Hilbert spaces

In this section, we briefly introduce the relevant mathematical "machinery" needed to explain KRR.

Let $\mathcal{F} = \{f : \mathcal{X} \to \mathbb{R}\}$ be a space of real-valued functions. Elements of this space (i.e., functions) can be added and scalar multiplied as if they were vectors. That is, if $f \in \mathcal{F}$ and $g \in \mathcal{F}$, then $\alpha f + \beta g \in \mathcal{F}$ for $\alpha, \beta \in \mathbb{R}$. We can also define an inner product for $\mathcal{F}$, which is a mapping $\langle f, g \rangle \in \mathbb{R}$ which satisfies the following:


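$$\langle \alpha f_1 + \beta f_2, g \rangle = \alpha \langle f_1, g \rangle + \beta \langle f_2, g \rangle \tag{18.81}$$

$$\langle f, g \rangle = \langle g, f \rangle \tag{18.82}$$

$$\langle f, f \rangle \geq 0 \tag{18.83}$$

$$\langle f, f \rangle = 0 \text{ iff } f(\mathbf{x}) = 0 \text{ for all } \mathbf{x} \in \mathcal{X} \tag{18.84}$$

We define the norm of a function using

$$\|f\| \triangleq \sqrt{\langle f, f \rangle} \tag{18.85}$$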


A function space $\mathcal{H}$ with an inner product operator is called a Hilbert space. (We also require that the function space be complete, which means that every Cauchy sequence of functions $f_i \in \mathcal{H}$ has a limit that is also in $\mathcal{H}$.)

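The most common Hilbert space is the space known as $L^2$. To define this, we need to specify a measure $\mu$ on the input space $\mathcal{X}$; this is a function that assigns any (suitable) subset $A$ of $\mathcal{X}$ to a non-negative number, such as its volume. This can be defined in terms of the density function $w : \mathcal{X} \to \mathbb{R}$, as follows:

$$\mu(A) = \int_A w(\mathbf{x}) \, d\mathbf{x} \tag{18.86}$$

Thus we have $\mu(d\mathbf{x}) = w(\mathbf{x}) \, d\mathbf{x}$. We can now define $L^2(\mathcal{X}, \mu)$ to be the space of functions $f : \mathcal{X} \to \mathbb{R}$ that satisfy

$$\int_{\mathcal{X}} f(\mathbf{x})^2 w(\mathbf{x}) \, d\mathbf{x} < \infty \tag{18.87}$$

This is known as the set of square-integrable functions. This space has an inner product defined by

$$\langle f, g \rangle = \int_{\mathcal{X}} f(\mathbf{x}) g(\mathbf{x}) w(\mathbf{x}) \, d\mathbf{x} \tag{18.88}$$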


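We define a Reproducing Kernel Hilbert Space or RKHS as follows. Let $\mathcal{H}$ be a Hilbert space of functions $f : \mathcal{X} \to \mathbb{R}$. We say that $\mathcal{H}$ is an RKHS endowed with inner product $\langle \cdot, \cdot \rangle_{\mathcal{H}}$ if there exists a (symmetric) kernel function $\mathcal{K} : \mathcal{X} \times \mathcal{X} \to \mathbb{R}$ with the following properties:

• For every $\mathbf{x} \in \mathcal{X}$, $\mathcal{K}(\mathbf{x}, \cdot) \in \mathcal{H}$.
• $\mathcal{K}$ satisfies the reproducing property: for all $f \in \mathcal{H}$ and $\mathbf{x}' \in \mathcal{X}$,

$$\langle f(\cdot), \mathcal{K}(\cdot, \mathbf{x}') \rangle_{\mathcal{H}} = f(\mathbf{x}') \tag{18.89}$$

The reason for the term "reproducing property" is as follows. Let $f(\cdot) = \mathcal{K}(\mathbf{x}, \cdot)$. Then we have that

$$\langle \mathcal{K}(\mathbf{x}, \cdot), \mathcal{K}(\cdot, \mathbf{x}') \rangle_{\mathcal{H}} = \mathcal{K}(\mathbf{x}, \mathbf{x}') \tag{18.90}$$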
18.3.7.2 Complexity of a function in an RKHS


The main utility of RKHS from the point of view of machine learning is that it allows us to define a 
notion of a function’s “smoothness” or “complexity” in terms of its norm, as we now discuss. 

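Suppose we have a positive definite kernel function $\mathcal{K}$. From Mercer's theorem we have $\mathcal{K}(\mathbf{x}, \mathbf{x}') = \sum_{i=1}^{\infty} \lambda_i \phi_i(\mathbf{x}) \phi_i(\mathbf{x}')$. Now consider a Hilbert space $\mathcal{H}$ defined by functions of the form $f(\mathbf{x}) = \sum_{i=1}^{\infty} f_i \phi_i(\mathbf{x})$, with $\sum_{i=1}^{\infty} f_i^2 / \lambda_i < \infty$. The inner product of two functions in this space is

$$\langle f, g \rangle_{\mathcal{H}} = \sum_{i=1}^{\infty} \frac{f_i g_i}{\lambda_i} \tag{18.91}$$

Hence the (squared) norm is given by

$$\|f\|_{\mathcal{H}}^2 = \langle f, f \rangle_{\mathcal{H}} = \sum_{i=1}^{\infty} \frac{f_i^2}{\lambda_i} \tag{18.92}$$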


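This is analogous to the quadratic form $\mathbf{f}^\top \mathbf{K}^{-1} \mathbf{f}$ which occurs in some GP objectives (see Equation (18.101)). Since each coefficient $f_i$ is divided by $\lambda_i$, components along eigenfunctions with small eigenvalues (which, for kernels such as the RBF, are the more rapidly varying ones) are heavily penalized. Thus the smoothness of the function is controlled by the properties of the corresponding kernel.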
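For a function in the span of the kernel at the training points, $f = \sum_n \alpha_n \mathcal{K}(\cdot, \mathbf{x}_n)$, the reproducing property gives $\|f\|_{\mathcal{H}}^2 = \boldsymbol{\alpha}^\top \mathbf{K} \boldsymbol{\alpha}$, which equals the quadratic form $\mathbf{f}^\top \mathbf{K}^{-1} \mathbf{f}$ when $\mathbf{f} = \mathbf{K}\boldsymbol{\alpha}$. The sketch below (RBF kernel, toy inputs and a small jitter, all assumptions) uses this to compare the norm of an interpolant of a slowly varying signal with that of a rapidly varying one.

```python
import numpy as np


def rbf_kernel(x1, x2, ell=1.0):
    # 1d RBF kernel matrix.
    return np.exp(-0.5 * (x1[:, None] - x2[None, :]) ** 2 / ell**2)


x = np.linspace(-3.0, 3.0, 20)
K = rbf_kernel(x, x) + 1e-6 * np.eye(len(x))    # jitter keeps the solve stable


def rkhs_norm_sq(f_vals):
    # Norm of the interpolant f = sum_n alpha_n K(., x_n) with K alpha = f_vals.
    alpha = np.linalg.solve(K, f_vals)
    return alpha @ K @ alpha, f_vals @ alpha       # alpha^T K alpha and f^T K^{-1} f


smooth_a, smooth_b = rkhs_norm_sq(np.sin(x))
wiggly_a, wiggly_b = rkhs_norm_sq(np.sin(5.0 * x))
print(smooth_a, wiggly_a)                          # the wiggly interpolant has a much larger norm
```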


18.3.7.3 Representer theorem 
In this section, we consider the problem of (regularized) empirical risk minimization in function space. 
In particular, consider the following problem: 


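$$f^* = \operatorname*{argmin}_{f \in \mathcal{H}_K} \sum_{n=1}^{N} \mathcal{L}(y_n, f(\mathbf{x}_n)) + \frac{\lambda}{2} \|f\|_{\mathcal{H}}^2 \tag{18.93}$$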


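where $\mathcal{H}_K$ is an RKHS with kernel $\mathcal{K}$ and $\mathcal{L}(y, \hat{y}) \in \mathbb{R}$ is a loss function. Then one can show [KW70; SHS01] the following result:

$$f^*(\mathbf{x}) = \sum_{n=1}^{N} \alpha_n \mathcal{K}(\mathbf{x}, \mathbf{x}_n) \tag{18.94}$$

where $\alpha_n \in \mathbb{R}$ are some coefficients that depend on the training data. This is called the representer theorem.

Now consider the special case where the loss function is squared loss, $\mathcal{L}(y, \hat{y}) = \frac{1}{2}(y - \hat{y})^2$, and $\lambda = \sigma_y^2$. Dividing the objective by $\lambda$, we want to minimize

$$\mathcal{L}(f) = \frac{1}{2\sigma_y^2} \sum_{n=1}^{N} (y_n - f(\mathbf{x}_n))^2 + \frac{1}{2} \|f\|_{\mathcal{H}}^2 \tag{18.95}$$

Substituting in Equation (18.94), and using the fact that $\langle \mathcal{K}(\cdot, \mathbf{x}_i), \mathcal{K}(\cdot, \mathbf{x}_j) \rangle = \mathcal{K}(\mathbf{x}_i, \mathbf{x}_j)$, we obtain

$$\mathcal{L}(f) = \frac{1}{2\sigma_y^2} \|\mathbf{y} - \mathbf{K}\boldsymbol{\alpha}\|_2^2 + \frac{1}{2} \boldsymbol{\alpha}^\top \mathbf{K} \boldsymbol{\alpha} \tag{18.96}$$

$$= \frac{1}{2} \boldsymbol{\alpha}^\top \left(\frac{1}{\sigma_y^2} \mathbf{K}^\top \mathbf{K} + \mathbf{K}\right) \boldsymbol{\alpha} - \frac{1}{\sigma_y^2} \mathbf{y}^\top \mathbf{K} \boldsymbol{\alpha} + \frac{1}{2\sigma_y^2} \mathbf{y}^\top \mathbf{y} \tag{18.97}$$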


Figure 18.8: Kernel ridge regression (KRR) compared to Gaussian process regression (GPR) using the same 
kernel. Generated by krr_vs_gpr.ipynb. 


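Minimizing this wrt $\boldsymbol{\alpha}$ gives $\hat{\boldsymbol{\alpha}} = (\mathbf{K} + \sigma_y^2 \mathbf{I})^{-1} \mathbf{y}$, which is the same as Equation (18.57). Furthermore, the prediction for a test point is

$$\hat{f}(\mathbf{x}_*) = \mathbf{k}_*^\top \hat{\boldsymbol{\alpha}} = \mathbf{k}_*^\top (\mathbf{K} + \sigma_y^2 \mathbf{I})^{-1} \mathbf{y} \tag{18.98}$$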


This is known as kernel ridge regression [Vov13]. We see that the result matches the posterior 
predictive mean of a GP in Equation (18.55). 
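A minimal NumPy check of this equivalence, under assumed toy data, an RBF kernel and $\sigma_y = 0.2$. The vector $\hat{\boldsymbol{\alpha}}$ makes the gradient of the KRR objective in Equation (18.97) vanish, and the KRR prediction of Equation (18.98) coincides with the GP posterior mean computed separately through a Cholesky factorization.

```python
import numpy as np


def rbf_kernel(x1, x2, ell=1.0):
    return np.exp(-0.5 * (x1[:, None] - x2[None, :]) ** 2 / ell**2)


rng = np.random.default_rng(0)
N, sigma_y = 25, 0.2
x = rng.uniform(0.0, 6.0, size=N)
y = np.sin(x) + sigma_y * rng.standard_normal(N)
x_star = np.linspace(0.0, 6.0, 5)

K = rbf_kernel(x, x)
K_star = rbf_kernel(x_star, x)                   # row j is k_*^T for test point j

# KRR with lambda = sigma_y^2: alpha_hat = (K + sigma_y^2 I)^{-1} y.
alpha_hat = np.linalg.solve(K + sigma_y**2 * np.eye(N), y)
f_krr = K_star @ alpha_hat                       # Eq. (18.98)

# Gradient of Eq. (18.97) at alpha_hat: (K^T K / s2 + K) alpha - K^T y / s2.
grad = (K.T @ K / sigma_y**2 + K) @ alpha_hat - K.T @ y / sigma_y**2

# GP posterior mean via Cholesky (Section 18.3.6), computed independently.
L = np.linalg.cholesky(K + sigma_y**2 * np.eye(N))
alpha_gp = np.linalg.solve(L.T, np.linalg.solve(L, y))
f_gp = K_star @ alpha_gp
print(np.allclose(f_krr, f_gp))
```

The two differ only in how the hyper-parameters are chosen in practice (cross validation for KRR, marginal likelihood for the GP) and in the fact that the GP also supplies a predictive variance.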


18.3.7.4 Example of KRR vs GPR 


In this section, we compare KRR with GP regression on a simple 1d problem. Since the underlying 
function is believed to be periodic, we use the periodic kernel from Equation (18.18). To capture the 
fact that the observations are noisy, we add to this a white noise kernel 


$$\mathcal{K}(\mathbf{x}, \mathbf{x}') = \sigma_y^2 \delta(\mathbf{x} - \mathbf{x}') \tag{18.99}$$

as in Equation (18.48). Thus there are 3 GP hyper-parameters: the kernel length scale $\ell$, the kernel periodicity $p$, and the noise level $\sigma_y^2$. We can optimize these by maximizing the marginal likelihood using gradient descent (see Section 18.6.1). For KRR, we also have 3 hyper-parameters ($\ell$, $p$ and $\lambda = \sigma_y^2$); we optimize these using grid search combined with cross validation (which in general is slower than gradient based optimization). The resulting model fits are shown in Figure 18.8, and are very similar, as is to be expected.


18.4 GPs with non-Gaussian likelihoods 


So far, we have focused on GPs for regression using Gaussian likelihoods. In this case, the posterior is also a GP, and all computation can be performed analytically. However, if the likelihood is non-Gaussian, we can no longer compute the posterior exactly. We can create a variety of different "classical" models by changing the form of the likelihood, as we show in Table 18.1. In the sections below, we briefly discuss some approximate inference methods. (For more details, see e.g., [WSS21].)
Model                        Likelihood                                         Section
Regression                   $\mathcal{N}(f_i, \sigma_y^2)$                     Section 18.3.2
Robust regression            $\mathcal{T}_\nu(f_i, \sigma_y^2)$                 Section 18.4.4
Binary classification        $\mathrm{Ber}(\sigma(f_i))$                        Section 18.4.1
Multiclass classification    $\mathrm{Cat}(\mathrm{softmax}(\mathbf{f}_i))$     Section 18.4.2
Poisson regression           $\mathrm{Poi}(\exp(f_i))$                          Section 18.4.3

Table 18.1: Summary of GP models with a variety of likelihoods.


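$\log p(y_i|f_i)$         $\frac{\partial}{\partial f_i} \log p(y_i|f_i)$       $\frac{\partial^2}{\partial f_i^2} \log p(y_i|f_i)$
$\log \sigma(y_i f_i)$    $t_i - \pi_i$                                          $-\pi_i(1 - \pi_i)$
$\log \Phi(y_i f_i)$      $\frac{y_i \phi(f_i)}{\Phi(y_i f_i)}$                  $-\frac{\phi(f_i)^2}{\Phi(y_i f_i)^2} - \frac{y_i f_i \phi(f_i)}{\Phi(y_i f_i)}$

Table 18.2: Likelihood, gradient and Hessian for binary logistic/probit GP regression. We assume $y_i \in \{-1, +1\}$ and define $t_i = (y_i + 1)/2 \in \{0, 1\}$ and $\pi_i = \sigma(f_i)$ for logistic regression, and $\pi_i = \Phi(f_i)$ for probit regression. Also, $\phi$ and $\Phi$ are the pdf and cdf of $\mathcal{N}(0, 1)$. From [RW06, p48].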


18.4.1 Binary classification 


In this section, we consider binary classification using GPs. If we use the sigmoid link function, we have $p(y_n = 1 | \mathbf{x}_n) = \sigma(f(\mathbf{x}_n))$. If we assume $y_n \in \{-1, +1\}$, then we have $p(y_n | \mathbf{x}_n) = \sigma(y_n f_n)$, since $\sigma(-z) = 1 - \sigma(z)$. If we use the probit link, we have $p(y_n = 1 | \mathbf{x}_n) = \Phi(f(\mathbf{x}_n))$, where $\Phi(z)$ is the cdf of the standard normal. More generally, let $p(y_n | \mathbf{x}_n) = \mathrm{Ber}(y_n | \varphi(f_n))$ for some inverse link function $\varphi$. The overall log joint has the form


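$$\mathcal{L}(\mathbf{f}_X) = \log p(\mathbf{y} | \mathbf{f}_X) + \log p(\mathbf{f}_X | \mathbf{X}) \tag{18.100}$$

$$= \log p(\mathbf{y} | \mathbf{f}_X) - \frac{1}{2} \mathbf{f}_X^\top \mathbf{K}_{X,X}^{-1} \mathbf{f}_X - \frac{1}{2} \log |\mathbf{K}_{X,X}| - \frac{N}{2} \log 2\pi \tag{18.101}$$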


The simplest approach to approximate inference is to use a Laplace approximation (Section 7.4.3). 


The gradient and Hessian of the log joint are given by

$$\nabla \mathcal{L} = \nabla \log p(\mathbf{y} | \mathbf{f}_X) - \mathbf{K}_{X,X}^{-1} \mathbf{f}_X \tag{18.102}$$

$$\nabla^2 \mathcal{L} = \nabla^2 \log p(\mathbf{y} | \mathbf{f}_X) - \mathbf{K}_{X,X}^{-1} = -\boldsymbol{\Lambda} - \mathbf{K}_{X,X}^{-1} \tag{18.103}$$

where $\boldsymbol{\Lambda} = -\nabla^2 \log p(\mathbf{y} | \mathbf{f}_X)$ is a diagonal matrix, since the likelihood factorizes across examples. Expressions for the gradient and Hessian of the log likelihood for the logit and probit case are shown in Table 18.2. At convergence, the Laplace approximation of the posterior takes the following form:


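$$p(\mathbf{f}_X | \mathcal{D}) \approx q(\mathbf{f}_X) = \mathcal{N}(\hat{\mathbf{f}}, (\mathbf{K}_{X,X}^{-1} + \boldsymbol{\Lambda})^{-1}) \tag{18.104}$$

where $\hat{\mathbf{f}}$ is the MAP estimate. See [RW06, Sec 3.4] for further details.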
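The mode $\hat{\mathbf{f}}$ can be found with Newton's method on Equation (18.101), using the gradient and Hessian in Equations (18.102) and (18.103) together with the logistic row of Table 18.2. The sketch below is a plain, not numerically hardened, version assuming labels in $\{-1, +1\}$, an RBF kernel and toy data; [RW06, Sec 3.4] gives a more stable formulation. It uses the identity $(\mathbf{K}^{-1} + \boldsymbol{\Lambda})^{-1} = \mathbf{K}(\mathbf{I} + \boldsymbol{\Lambda}\mathbf{K})^{-1}$ to avoid inverting $\mathbf{K}$.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def laplace_binary_gp(K, y, n_iter=25):
    # y in {-1, +1}; t = (y + 1) / 2 in {0, 1} as in Table 18.2.
    N = len(y)
    t = (y + 1) / 2
    f = np.zeros(N)
    for _ in range(n_iter):
        pi = sigmoid(f)
        grad = t - pi                              # gradient of log p(y | f)
        Lam = np.diag(pi * (1 - pi))               # Lambda = -Hessian of log p(y | f)
        # Newton step: f <- (K^{-1} + Lam)^{-1} (Lam f + grad) = K (I + Lam K)^{-1} (Lam f + grad)
        f = K @ np.linalg.solve(np.eye(N) + Lam @ K, Lam @ f + grad)
    pi = sigmoid(f)
    Lam = np.diag(pi * (1 - pi))
    cov = K @ np.linalg.inv(np.eye(N) + Lam @ K)   # (K^{-1} + Lam)^{-1}, Eq. (18.104)
    return f, cov


x = np.linspace(-3.0, 3.0, 20)
y = np.where(x > 0, 1.0, -1.0)
y[[3, 15]] *= -1                                   # flip two labels so the classes overlap
K = 2.0 * np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2) + 1e-8 * np.eye(20)
f_hat, cov = laplace_binary_gp(K, y)
print(np.round(f_hat, 2))
```

At the mode, Equation (18.102) gives $\hat{\mathbf{f}} = \mathbf{K} \nabla \log p(\mathbf{y} | \hat{\mathbf{f}})$, which is a convenient convergence check.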


For improved accuracy, we can use variational inference, in which we assume $q(\mathbf{f}_X) = \mathcal{N}(\mathbf{f}_X | \mathbf{m}, \mathbf{S})$; we then optimize $\mathbf{m}$ and $\mathbf{S}$ using (stochastic) gradient descent, rather than assuming $\mathbf{S}$ is the Hessian at the mode. See Section 18.5.4 for the details.


Figure 18.9: Contours of the posterior predictive probability for a binary classifier generated by a GP with an SE kernel. (a) Manual kernel parameters: short length scale, $\ell = 0.5$, variance $3.16^2 \approx 9.98$. (b) Learned kernel parameters: long length scale, $\ell = 1.19$, variance $4.79^2 \approx 22.9$. Generated by gpc_demo_2d.ipynb.


Once we have a Gaussian posterior $q(\mathbf{f}_X | \mathcal{D})$, we can then use standard GP prediction to compute $q(f_* | \mathbf{x}_*, \mathcal{D})$. Finally, we can approximate the posterior predictive distribution over binary labels using

$$\pi_* = p(y_* = 1 | \mathbf{x}_*, \mathcal{D}) \approx \int p(y_* = 1 | f_*) \, q(f_* | \mathbf{x}_*, \mathcal{D}) \, df_* \tag{18.105}$$

This 1d integral can be approximated using the probit approximation from Section 15.3.5. In this case we have $\pi_* \approx \sigma(\kappa(v) \mathbb{E}[f_*])$, where $v = \mathbb{V}[f_*]$ and $\kappa^2(v) = (1 + \pi v / 8)^{-1}$.
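The probit approximation is a one-liner. The sketch below compares it with brute-force numerical integration of Equation (18.105) on a grid, for a few made-up values of $\mathbb{E}[f_*]$ and $\mathbb{V}[f_*]$; the grid width and resolution are arbitrary choices.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def probit_approx(mean, var):
    # pi_* ~= sigmoid(kappa(v) E[f_*]) with kappa(v) = (1 + pi v / 8)^{-1/2}.
    kappa = 1.0 / np.sqrt(1.0 + np.pi * var / 8.0)
    return sigmoid(kappa * mean)


def predictive_by_quadrature(mean, var, n_grid=4001):
    # Reference value: integrate sigmoid(f) N(f | mean, var) with a Riemann sum.
    z = np.linspace(-10.0, 10.0, n_grid)
    dz = z[1] - z[0]
    pdf = np.exp(-0.5 * z**2) / np.sqrt(2.0 * np.pi)
    return np.sum(sigmoid(mean + np.sqrt(var) * z) * pdf) * dz


for mean, var in [(1.0, 2.0), (-0.5, 5.0), (2.0, 0.1)]:
    print(mean, var, round(probit_approx(mean, var), 4),
          round(predictive_by_quadrature(mean, var), 4))
```

Note how a larger variance pulls the prediction towards 0.5, even for a fixed posterior mean; plugging in $\mathbb{E}[f_*]$ alone would ignore this.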

In Figure 18.9, we show a synthetic binary classification problem in 2d. We use an SE kernel. 
On the left, we show predictions using hyper-parameters set by hand; we use a short length scale, 
hence the very sharp turns in the decision boundary. On the right, we show the predictions using the 
learned hyper-parameters; the model favors a more parsimonious explanation of the data.


18.4.2 Multi-class classification 


The multi-class case is somewhat harder, since the function now needs to return a vector of $C$ logits to get $p(y_n | \mathbf{x}_n) = \mathrm{Cat}(y_n | \mathrm{softmax}(\mathbf{f}_n))$, where $\mathbf{f}_n = (f_n^1, \ldots, f_n^C)$. It is standard to assume that $f^c \sim \mathcal{GP}(0, \mathcal{K}_c)$. Thus we have one latent function per class, which are a priori independent, and which may use different kernels.

We can derive a Laplace approximation for this model as discussed in [RW06, Sec 3.5]. Alternatively, we can use a variational approach, using the local variational bound to the multinomial softmax in [Cha12]. An alternative variational method, based on data augmentation with auxiliary variables, is described in [Wen+19b; Liu+19a; GFWO20].


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 




Figure 18.10: Poisson regression with a GP. (a) Observed data (black dots) and true log rate function (yellow line). (b) Posterior predictive distribution (shading shows 1 and 2 $\sigma$ bands) from MCMC. (c) Posterior predictive distribution from SVI. Generated by gp_poisson_1d.ipynb.


18.4.3 GPs for Poisson regression (Cox process) 


In this section, we illustrate Poisson regression where the underlying log rate function is modeled by a GP. This is known as a Cox process. We can perform approximate posterior inference in this model using Laplace, MCMC or SVI (stochastic variational inference). In Figure 18.10 we give a 1d example, where we use a Matern 3 kernel. We apply MCMC and SVI. In the VI case, we additionally have to specify the form of the posterior; we use a Gaussian approximation for the variational GP posterior $p(f | \mathbf{X}, \mathbf{y})$, and a point estimate for the kernel parameters.

An interesting application of this is to spatial disease mapping. For example, [VPV10] discuss the problem of modeling the relative risk of heart attack in different regions in Finland. The data consists of the heart attacks in Finland from 1996-2000 aggregated into 20km x 20km lattice cells. The likelihood has the following form: $y_n \sim \mathrm{Poi}(e_n r_n)$, where $e_n$ is the known expected number of deaths (related to the population of cell $n$ and the overall death rate), and $r_n$ is the relative risk of cell $n$ which we want to infer. Since the data counts are small, we regularize the problem by sharing information with spatial neighbors. Hence we assume $f \triangleq \log(r) \sim \mathcal{GP}(0, \mathcal{K})$. We use a Matern kernel (Section 18.2.1.3) with $\nu = 3/2$, and a length scale and magnitude that are estimated from data.

Figure 18.11 gives an example of this method in action (using Laplace approximation). On the left we plot the posterior mean relative risk (RR), and on the right, the posterior variance. We see that the RR is higher in Eastern Finland, which is consistent with other studies. We also see that the variance in the North is higher, since there are fewer people living there.


18.4.4 Other likelihoods


Many other likelihoods are possible. For example, [VJV09] uses a Student-t likelihood in order to perform robust regression. A general method for performing approximate variational inference in GPs with such non-conjugate likelihoods is discussed in [WSS21].
Figure 18.11: We show the relative risk of heart disease in Finland using a Poisson GP fit to 911 data points. Left: posterior mean. Right: posterior variance. Generated by gp_spatial_demo.ipynb.


Method         Cost                                                   Section
Cholesky       $O(N^3)$                                               Section 18.3.6
Conj. Grad.    $O(CN^2)$                                              Section 18.5.5
Inducing       $O(NM^2 + M^3 + DNM)$                                  Section 18.5.3
Variational    $O(NM^2 + M^3 + DNM)$                                  Section 18.5.4
SVGP           $O(BM^2 + M^3 + DNM)$                                  Section 18.5.4.3
KISS-GP        $O(CN + CDM^D \log M)$                                 Section 18.5.5.3
SKIP           $O(DLN + DLM \log M + L^3 N \log D + CL^2 N)$          Section 18.5.5.3

Table 18.3: Summary of time to compute the log marginal likelihood of a GP regression model. Notation: $N$ is number of training examples, $M$ is number of inducing points, $B$ is size of minibatch, $D$ is dimensionality of input vectors (assuming $\mathcal{X} = \mathbb{R}^D$), $C$ is number of conjugate gradient iterations, and $L$ is number of Lanczos iterations. Based on Table 2 of [Gar+18a].


18.5 Scaling GP inference to large datasets 


In Section 18.3.6, we saw that the best way to perform GP inference and training is to compute a Cholesky decomposition of the $N \times N$ Gram matrix. Unfortunately, this takes $O(N^3)$ time. In this section, we discuss methods to scale up GPs to handle large $N$. See Table 18.3 for a summary, and [Liu+20c] for more details.¹


18.5.1 Subset of data 


The simplest approach to speeding up GP inference is to throw away some of the data. Suppose we 
keep a subset of $M$ examples. In this case, exact inference will take $O(M^3)$ time. This is called the subset-of-data approach.

¹ We focus on efficient methods for evaluating the marginal likelihood and the posterior predictive distribution. For an efficient method for sampling a function from the posterior, see [Wil+20a].

The key question is: how should we choose the subset? The simplest approach is to pick random 
examples (this method was recently analysed in [HIY19]). However, intuitively it makes more sense 
to try to pick a subset that in some sense “covers” the original data, so it contains approximately 
the same information (up to some tolerance) without the redundancy. Clustering algorithms are 
one heuristic approach, but we can also use coreset methods, which can provably find such an 
information-preserving subset (see e.g., [Hug+19] for an application of this idea to GPs). 


18.5.1.1 Informative vector machine 


Clustering and coreset methods are unsupervised, in that they only look at the features $\mathbf{x}_i$ and not the labels $y_i$, which can be suboptimal. The informative vector machine [HLS03] uses a greedy strategy to iteratively add the labeled example $(\mathbf{x}_j, y_j)$ that maximally reduces the entropy of the function's posterior, $\Delta_j = \mathbb{H}(p(f_j)) - \mathbb{H}(p^{\text{new}}(f_j))$, where $p^{\text{new}}(f_j)$ is the posterior of $f$ at $\mathbf{x}_j$ after conditioning on $y_j$. (This is very similar to active learning.) To compute $\Delta_j$, let $p(f_j) = \mathcal{N}(\mu_j, v_j)$, and $p^{\text{new}}(f_j) \propto p(f_j) \mathcal{N}(y_j | f_j, \sigma^2) = \mathcal{N}(f_j | \mu_j^{\text{new}}, v_j^{\text{new}})$, where $(v_j^{\text{new}})^{-1} = v_j^{-1} + \sigma^{-2}$. Since $\mathbb{H}(\mathcal{N}(\mu, v)) = \log(2\pi e v)/2$, we have $\Delta_j = 0.5 \log(1 + v_j / \sigma^2)$. Since this is a monotonic function of $v_j$, we can maximize it by choosing the site with the largest variance. (In fact, entropy is a submodular function, so we can use submodular optimization algorithms to improve on the IVM, as shown in [Kra+08].)
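The greedy rule can be sketched for a Gaussian likelihood, where conditioning on one site is a rank-one update of the posterior covariance that does not depend on the observed value $y_j$. This is an illustrative simplification that only covers the Gaussian-likelihood case; the kernel, inputs and noise variance are assumptions.

```python
import numpy as np


def greedy_entropy_selection(K, sigma2, M):
    # Pick M sites greedily by Delta_j = 0.5 log(1 + v_j / sigma2), Gaussian likelihood only.
    Sigma = K.astype(float).copy()
    chosen, gains = [], []
    for _ in range(M):
        v = np.diag(Sigma).copy()
        if chosen:
            v[chosen] = -np.inf                      # never pick a site twice
        j = int(np.argmax(v))                        # largest variance = largest Delta_j
        gains.append(0.5 * np.log(1.0 + Sigma[j, j] / sigma2))
        s = Sigma[:, j].copy()
        Sigma = Sigma - np.outer(s, s) / (Sigma[j, j] + sigma2)   # condition on y_j
        chosen.append(j)
    return chosen, gains, Sigma


rng = np.random.default_rng(0)
x = np.sort(rng.uniform(-5.0, 5.0, size=40))
K = np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2)
chosen, gains, Sigma_post = greedy_entropy_selection(K, sigma2=0.1, M=8)
print(np.round(x[chosen], 2))
print(np.round(gains, 3))
```

Because variances can only shrink after conditioning, the selected gains are non-increasing, and the chosen inputs tend to spread out over the data.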


18.5.1.2 Discussion 


The main problem with the subset of data approach is that it ignores some of the data, which can 
reduce predictive accuracy and increase uncertainty about the true function. Fortunately there 
are other scalable methods that avoid this problem, essentially by approximately representing (or 
compressing) the training data, as we discuss below. 


18.5.2 Nyström approximation


Suppose we had a rank $M$ approximation to the $N \times N$ Gram matrix of the following form:

$$\mathbf{K}_{X,X} \approx \mathbf{U} \boldsymbol{\Lambda} \mathbf{U}^\top \tag{18.106}$$

where $\boldsymbol{\Lambda}$ is a diagonal matrix of the $M$ leading eigenvalues, and $\mathbf{U}$ is the matrix of the corresponding $M$ eigenvectors, each of size $N$. In this case, we can use the matrix inversion lemma to write


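$$\mathbf{K}_\sigma^{-1} = (\mathbf{K}_{X,X} + \sigma^2 \mathbf{I}_N)^{-1} \approx \sigma^{-2} \mathbf{I}_N - \sigma^{-2} \mathbf{U} (\sigma^2 \boldsymbol{\Lambda}^{-1} + \mathbf{U}^\top \mathbf{U})^{-1} \mathbf{U}^\top \tag{18.107}$$

which takes $O(NM^2)$ time. Similarly, one can show (using the Sylvester determinant lemma) that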


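$$|\mathbf{K}_\sigma| \approx \sigma^{2(N-M)} |\boldsymbol{\Lambda}| \, |\sigma^2 \boldsymbol{\Lambda}^{-1} + \mathbf{U}^\top \mathbf{U}| \tag{18.108}$$

which also takes $O(NM^2)$ time.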
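Equations (18.107) and (18.108) can be checked against dense $O(N^3)$ computations. The sketch below builds an exactly low-rank-plus-diagonal matrix from a random $\mathbf{U}$ and positive $\boldsymbol{\Lambda}$ (arbitrary choices), solves a linear system through the $M \times M$ inner matrix only, and evaluates the log determinant in the form $\log |\mathbf{K}_\sigma| = (N - M)\log \sigma^2 + \log |\boldsymbol{\Lambda}| + \log |\sigma^2 \boldsymbol{\Lambda}^{-1} + \mathbf{U}^\top \mathbf{U}|$.

```python
import numpy as np

rng = np.random.default_rng(0)
N, M, sigma2 = 200, 10, 0.1
U = rng.normal(size=(N, M))
lam = rng.uniform(0.5, 2.0, size=M)
y = rng.normal(size=N)

K_sigma = U @ np.diag(lam) @ U.T + sigma2 * np.eye(N)    # dense matrix, for checking only

# Eq. (18.107) applied to a vector: only an M x M system is solved, O(N M^2) overall.
inner = sigma2 * np.diag(1.0 / lam) + U.T @ U
x_lowrank = (y - U @ np.linalg.solve(inner, U.T @ y)) / sigma2

# Eq. (18.108) in log form.
logdet_lowrank = ((N - M) * np.log(sigma2) + np.sum(np.log(lam))
                  + np.linalg.slogdet(inner)[1])

print(np.allclose(x_lowrank, np.linalg.solve(K_sigma, y)))
print(np.isclose(logdet_lowrank, np.linalg.slogdet(K_sigma)[1]))
```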
Unfortunately, directly computing such an eigendecomposition takes $O(N^3)$ time, which does not help. However, suppose we pick a subset $Z$ of $M < N$ points. We can partition the Gram matrix as follows (where we assume the chosen points come first, and then the remaining points):

$$\mathbf{K}_{X,X} = \begin{pmatrix} \mathbf{K}_{Z,Z} & \mathbf{K}_{Z,X-Z} \\ \mathbf{K}_{X-Z,Z} & \mathbf{K}_{X-Z,X-Z} \end{pmatrix} \tag{18.109}$$


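Let $\mathbf{K}_{Z,X}$ denote the top $M \times N$ block, and $\mathbf{K}_{X,Z}$ denote its transpose. We now compute an eigendecomposition of $\mathbf{K}_{Z,Z}$ to get the eigenvalues $\{\lambda_i\}_{i=1}^M$ and eigenvectors $\{\mathbf{u}_i\}_{i=1}^M$. We now use these to approximate the full matrix as shown below, where the scaling constants are chosen so that $\|\tilde{\mathbf{u}}_i\| \approx 1$:

$$\tilde{\lambda}_i = \frac{N}{M} \lambda_i \tag{18.110}$$

$$\tilde{\mathbf{u}}_i = \sqrt{\frac{M}{N}} \frac{1}{\lambda_i} \mathbf{K}_{X,Z} \mathbf{u}_i \tag{18.111}$$

$$\mathbf{K}_{X,X} \approx \sum_{i=1}^{M} \tilde{\lambda}_i \tilde{\mathbf{u}}_i \tilde{\mathbf{u}}_i^\top \tag{18.112}$$

$$= \sum_{i=1}^{M} \frac{N}{M} \lambda_i \left(\sqrt{\frac{M}{N}} \frac{1}{\lambda_i} \mathbf{K}_{X,Z} \mathbf{u}_i\right) \left(\sqrt{\frac{M}{N}} \frac{1}{\lambda_i} \mathbf{u}_i^\top \mathbf{K}_{Z,X}\right) \tag{18.113}$$

$$= \mathbf{K}_{X,Z} \left(\sum_{i=1}^{M} \frac{1}{\lambda_i} \mathbf{u}_i \mathbf{u}_i^\top\right) \mathbf{K}_{Z,X} \tag{18.114}$$

$$= \mathbf{K}_{X,Z} \mathbf{K}_{Z,Z}^{-1} \mathbf{K}_{Z,X} \tag{18.115}$$

This is known as the Nyström approximation [WS01]. If we define

$$\mathbf{Q}_{A,B} \triangleq \mathbf{K}_{A,Z} \mathbf{K}_{Z,Z}^{-1} \mathbf{K}_{Z,B} \tag{18.116}$$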


then we can write the approximate Gram matrix as Qx,x. We can then replace K, with Q XxX= 
Qx,x + 07Iy. Computing the eigendecomposition takes O(M*) time, and computing Qx'x takes 
O(N M7?) time. Thus complexity is now linear in N instead of cubic. 
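
The sketch below applies Equations (18.110)-(18.115) to a 1d toy problem and confirms that the rescaled eigenpairs of $K_{Z,Z}$ reproduce $K_{X,Z} K_{Z,Z}^{-1} K_{Z,X}$. It assumes a unit-variance RBF kernel, a uniformly random subset $Z$ of the training inputs, and a small diagonal jitter on $K_{Z,Z}$ for numerical stability; all of these are illustrative choices rather than part of the method.

```python
import numpy as np

def rbf(A, B, ell=1.0):
    """Unit-variance RBF kernel matrix between the rows of A and B."""
    d2 = np.sum(A**2, 1)[:, None] + np.sum(B**2, 1)[None, :] - 2 * A @ B.T
    return np.exp(-0.5 * np.maximum(d2, 0.0) / ell**2)

rng = np.random.default_rng(0)
N, M = 300, 20
X = rng.uniform(-3, 3, size=(N, 1))
Z = X[rng.choice(N, size=M, replace=False)]       # random subset of the data (one possible choice)
K_ZZ = rbf(Z, Z) + 1e-6 * np.eye(M)               # jitter keeps the eigenvalues away from 0
K_XZ = rbf(X, Z)

lam, U_m = np.linalg.eigh(K_ZZ)                   # O(M^3) eigendecomposition of the small block
lam_tilde = (N / M) * lam                         # Eq. (18.110)
U_tilde = np.sqrt(M / N) * (K_XZ @ U_m) / lam     # Eq. (18.111); column i is divided by lam_i
Q_eig = (U_tilde * lam_tilde) @ U_tilde.T         # Eq. (18.112)
Q_nys = K_XZ @ np.linalg.solve(K_ZZ, K_XZ.T)      # Eq. (18.115)

print(np.allclose(Q_eig, Q_nys, atol=1e-6))       # the two routes agree
print(np.max(np.abs(rbf(X, X) - Q_nys)))          # worst-case entrywise Nyström error
```

The $N \times N$ matrices are formed here only for checking; in practice $\hat{Q}_{X,X}^{-1} y$ would be applied through the matrix inversion lemma without ever materializing them.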

If we approximate only $K_{X,X}$ by $Q_{X,X}$ in $\mu_{*|X}$ in Equation (18.52) and $\Sigma_{*|X}$ in Equation (18.53), then this is inconsistent with the other un-approximated kernel function evaluations in these formulae, and can result in the predictive variance being negative. One solution to this is to use the same $Q$ approximation for all terms.


18.5.3 Inducing point methods 


In this section, we discuss an approximation method based on inducing points, also called pseudo 
inputs, which are like a learned summary of the training data that we can condition on, rather than 
conditioning on all of it. 

Let $X$ be the observed inputs, and $f_X = f(X)$ be the unknown vector of function values (for which we have noisy observations $y$). Let $f_*$ be the unknown function values at one or more test points $X_*$. Finally, let us assume we have $M$ additional inputs, $Z$, with unknown function values $f_Z$ (often denoted by $u$).

Figure 18.12: Illustration of the graphical model for a GP on $n$ observations, $f_{1:n}$, and one test case, $f_*$, with inducing variables $u$. The thick lines indicate that all variables are fully interconnected. The observations $y_i$ (not shown) are locally connected to each $f_i$. (a) No approximations are made. (b) We assume $f_*$ is conditionally independent of $f_X$ given $u$. From Figure 1 of [QCR05]. Used with kind permission of Joaquin Quiñonero-Candela.

The exact joint prior has the form


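$$p(f_X, f_*) = \int p(f_*, f_X, f_Z)\, df_Z = \int p(f_*, f_X | f_Z)\, p(f_Z)\, df_Z = \mathcal{N}\left(0, \begin{pmatrix} K_{X,X} & K_{X,*} \\ K_{*,X} & K_{*,*} \end{pmatrix}\right) \tag{18.117}$$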


(We write $p(f_X, f_*)$ instead of $p(f_X, f_* | X, X_*)$, since the inputs can be thought of as just indices into the random function $f$.)

We will choose $f_Z$ in such a way that it acts as a sufficient statistic for the data, so that we can predict $f_*$ just using $f_Z$ instead of $f_X$, i.e., we assume $f_* \perp f_X | f_Z$. Thus we approximate the prior as follows:


$$p(f_*, f_X, f_Z) = p(f_* | f_X, f_Z)\, p(f_X | f_Z)\, p(f_Z) \approx p(f_* | f_Z)\, p(f_X | f_Z)\, p(f_Z) \tag{18.118}$$


See Figure 18.12 for an illustration of this assumption, and Section 18.5.3.4 for details on how to choose the inducing set $Z$. (Note that this method is often called a “sparse GP”, because it makes predictions for $f_*$ using a subset of the training data, namely $f_Z$, instead of all of it, $f_X$.)

From this, we can derive the following train and test conditionals:

$$p(f_X | f_Z) = \mathcal{N}(f_X | K_{X,Z} K_{Z,Z}^{-1} f_Z,\; K_{X,X} - Q_{X,X}) \tag{18.119}$$
$$p(f_* | f_Z) = \mathcal{N}(f_* | K_{*,Z} K_{Z,Z}^{-1} f_Z,\; K_{*,*} - Q_{*,*}) \tag{18.120}$$


The above equations can be seen as exact inference on noise-free observations $f_Z$. To gain computational speedups, we will make further approximations of the form $\tilde{Q}_{X,X} \approx K_{X,X} - Q_{X,X}$ and $\tilde{Q}_{*,*} \approx K_{*,*} - Q_{*,*}$. We consider various choices for these values below. All of these choices result in an initial training cost of $O(M^3 + NM^2)$, and then take $O(M)$ time for the predictive mean for each test case, and $O(M^2)$ time for the predictive variance. (Compare this to $O(N^3)$ training time and $O(N)$ and $O(N^2)$ testing time for exact inference.)


18.5.3.1 SOR/DIC


Suppose we assume $\tilde{Q}_{X,X} = 0$ and $\tilde{Q}_{*,*} = 0$, so the conditionals are deterministic. This is called the deterministic inducing conditional (DIC) approximation [QCR05], or the subset of regressors (SOR) approximation [Sil85; SB01]. The corresponding joint prior has the form


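$$q_{\text{SOR}}(f_X, f_*) = \mathcal{N}\left(0, \begin{pmatrix} Q_{X,X} & Q_{X,*} \\ Q_{*,X} & Q_{*,*} \end{pmatrix}\right) \tag{18.121}$$

Consequently the predictive distribution is

$$q_{\text{SOR}}(f_* | y) = \mathcal{N}(f_* | Q_{*,X} \hat{Q}_{X,X}^{-1} y,\; Q_{*,*} - Q_{*,X} \hat{Q}_{X,X}^{-1} Q_{X,*}) \tag{18.122}$$
$$= \mathcal{N}(f_* | \sigma^{-2} K_{*,Z} \Sigma K_{Z,X} y,\; K_{*,Z} \Sigma K_{Z,*}) \tag{18.123}$$

where we have defined $\hat{Q}_{X,X} = Q_{X,X} + \sigma^2 I_N$, and $\Sigma = (\sigma^{-2} K_{Z,X} K_{X,Z} + K_{Z,Z})^{-1}$.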
This predictive distribution is equivalent to the usual one for GPs except we have replaced $K_{X,X}$ by $Q_{X,X}$. This is equivalent to performing GP inference with the following kernel function:

$$K_{\text{SOR}}(x_i, x_j) = K(x_i, Z) K_{Z,Z}^{-1} K(Z, x_j) \tag{18.124}$$


The kernel matrix has rank $M$, so the GP is degenerate. Furthermore, the kernel will be near 0 when $x_i$ or $x_j$ is far from all of the chosen points $Z$, which can result in an underestimate of the predictive variance.
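
To see this pathology concretely, the sketch below computes the SOR predictive distribution in the $O(NM^2)$ form (18.123), checks it against the $N \times N$ form (18.122), and evaluates it far from the inducing points. The unit-variance RBF kernel, the toy data, the evenly spaced $Z$, and the jitter term are illustrative assumptions.

```python
import numpy as np

def rbf(A, B, ell=1.0):
    """Unit-variance RBF kernel matrix between the rows of A and B."""
    d2 = np.sum(A**2, 1)[:, None] + np.sum(B**2, 1)[None, :] - 2 * A @ B.T
    return np.exp(-0.5 * np.maximum(d2, 0.0) / ell**2)

def sor_predict(X, y, Z, Xs, sigma2, jitter=1e-6):
    """SOR/DIC predictive mean and variance via Eq. (18.123), O(N M^2)."""
    K_ZZ = rbf(Z, Z) + jitter * np.eye(len(Z))
    K_ZX, K_sZ = rbf(Z, X), rbf(Xs, Z)
    Sigma = np.linalg.inv(K_ZX @ K_ZX.T / sigma2 + K_ZZ)
    mean = K_sZ @ (Sigma @ (K_ZX @ y)) / sigma2
    var = np.einsum("ij,jk,ik->i", K_sZ, Sigma, K_sZ)      # diag(K_{*,Z} Sigma K_{Z,*})
    return mean, var

def sor_predict_dense(X, y, Z, Xs, sigma2, jitter=1e-6):
    """Same quantities via Eq. (18.122); forms N x N matrices, for checking only."""
    K_ZZ = rbf(Z, Z) + jitter * np.eye(len(Z))
    K_ZX, K_sZ = rbf(Z, X), rbf(Xs, Z)
    Q_sX = K_sZ @ np.linalg.solve(K_ZZ, K_ZX)
    Q_hat = K_ZX.T @ np.linalg.solve(K_ZZ, K_ZX) + sigma2 * np.eye(len(X))
    Q_ss = K_sZ @ np.linalg.solve(K_ZZ, K_sZ.T)
    mean = Q_sX @ np.linalg.solve(Q_hat, y)
    var = np.diag(Q_ss - Q_sX @ np.linalg.solve(Q_hat, Q_sX.T))
    return mean, var

rng = np.random.default_rng(0)
X = rng.uniform(-3, 3, size=(100, 1))
y = np.sin(2 * X[:, 0]) + 0.1 * rng.standard_normal(100)
Z = np.linspace(-3, 3, 8)[:, None]
Xs = np.array([[0.0], [10.0]])                 # one test point inside the data, one far away
mean, var = sor_predict(X, y, Z, Xs, sigma2=0.01)
mean_d, var_d = sor_predict_dense(X, y, Z, Xs, sigma2=0.01)
print(var)   # the entry for x = 10 is essentially 0, although the prior variance there is 1
```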


18.5.3.2 DTC 


One way to overcome the overconfidence of DIC is to only assume $\tilde{Q}_{X,X} = 0$, but let $\tilde{Q}_{*,*} = K_{*,*} - Q_{*,*}$ be exact. This is called the deterministic training conditional or DTC method [SWL03].

The corresponding joint prior has the form 


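$$q_{\text{DTC}}(f_X, f_*) = \mathcal{N}\left(0, \begin{pmatrix} Q_{X,X} & Q_{X,*} \\ Q_{*,X} & K_{*,*} \end{pmatrix}\right) \tag{18.125}$$

Hence the predictive distribution becomes

$$q_{\text{DTC}}(f_* | y) = \mathcal{N}(f_* | Q_{*,X} \hat{Q}_{X,X}^{-1} y,\; K_{*,*} - Q_{*,X} \hat{Q}_{X,X}^{-1} Q_{X,*}) \tag{18.126}$$
$$= \mathcal{N}(f_* | \sigma^{-2} K_{*,Z} \Sigma K_{Z,X} y,\; K_{*,*} - Q_{*,*} + K_{*,Z} \Sigma K_{Z,*}) \tag{18.127}$$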


The predictive mean is the same as in SOR, but the variance is at least as large (since $K_{*,*} - Q_{*,*}$ is positive semidefinite), reflecting the uncertainty of $f_*$ given $f_Z$.

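A minimal sketch of the DTC predictive distribution (18.127), assuming a unit-variance RBF kernel (so $K(x, x) = 1$), toy data, and a jitter on $K_{Z,Z}$. Far from the inducing points, the DTC variance reverts to the prior variance, whereas the SOR term $K_{*,Z} \Sigma K_{Z,*}$ alone collapses to zero.

```python
import numpy as np

def rbf(A, B, ell=1.0):
    """Unit-variance RBF kernel matrix between the rows of A and B."""
    d2 = np.sum(A**2, 1)[:, None] + np.sum(B**2, 1)[None, :] - 2 * A @ B.T
    return np.exp(-0.5 * np.maximum(d2, 0.0) / ell**2)

def dtc_predict(X, y, Z, Xs, sigma2, jitter=1e-6):
    """DTC predictive mean and variance, Eq. (18.127); also returns the SOR variance."""
    K_ZZ = rbf(Z, Z) + jitter * np.eye(len(Z))
    K_ZX, K_sZ = rbf(Z, X), rbf(Xs, Z)
    Sigma = np.linalg.inv(K_ZX @ K_ZX.T / sigma2 + K_ZZ)
    mean = K_sZ @ (Sigma @ (K_ZX @ y)) / sigma2                          # same mean as SOR
    var_sor = np.einsum("ij,jk,ik->i", K_sZ, Sigma, K_sZ)                # K_{*,Z} Sigma K_{Z,*}
    q_ss = np.einsum("ij,ji->i", K_sZ, np.linalg.solve(K_ZZ, K_sZ.T))   # diag(Q_{*,*})
    k_ss = np.ones(len(Xs))                                              # diag(K_{*,*}) = 1 here
    return mean, k_ss - q_ss + var_sor, var_sor

rng = np.random.default_rng(0)
X = rng.uniform(-3, 3, size=(100, 1))
y = np.sin(2 * X[:, 0]) + 0.1 * rng.standard_normal(100)
Z = np.linspace(-3, 3, 8)[:, None]
Xs = np.array([[0.0], [10.0]])
mean, var_dtc, var_sor = dtc_predict(X, y, Z, Xs, sigma2=0.01)
print(var_sor, var_dtc)   # at x = 10: SOR variance ~ 0, DTC variance ~ 1 (the prior variance)
```
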
18.5.3.3 FITC 


A widely used approximation assumes $q(f_X | f_Z)$ is fully factorized, i.e.,

$$q(f_X | f_Z) = \prod_{n=1}^N p(f_n | f_Z) = \mathcal{N}(f_X | K_{X,Z} K_{Z,Z}^{-1} f_Z,\; \mathrm{diag}(K_{X,X} - Q_{X,X})) \tag{18.128}$$


This is called the fully independent training conditional or FITC assumption, and was first proposed in [SG06a]. This throws away less uncertainty than the SOR and DTC methods, since it does not make any deterministic assumptions about the relationship between $f_X$ and $f_Z$.

The joint prior has the form

$$q_{\text{FITC}}(f_X, f_*) = \mathcal{N}\left(0, \begin{pmatrix} Q_{X,X} - \mathrm{diag}(Q_{X,X} - K_{X,X}) & Q_{X,*} \\ Q_{*,X} & K_{*,*} \end{pmatrix}\right) \tag{18.129}$$


The predictive distribution for a single test case is given by

$$q_{\text{FITC}}(f_* | y) = \mathcal{N}(f_* | k_{*,Z} \Sigma K_{Z,X} \Lambda^{-1} y,\; k_{*,*} - q_{*,*} + k_{*,Z} \Sigma k_{Z,*}) \tag{18.130}$$

where $\Lambda \triangleq \mathrm{diag}(K_{X,X} - Q_{X,X} + \sigma^2 I_N)$, and $\Sigma \triangleq (K_{Z,Z} + K_{Z,X} \Lambda^{-1} K_{X,Z})^{-1}$. If we have a batch of test cases, we can assume they are conditionally independent (an approach known as fully independent conditional or FIC), and multiply the above equation across the test cases.

The computational cost is the same as for SOR and DTC, but the approach avoids some of the pathologies due to a degenerate kernel. In particular, one can show that the FIC method is equivalent to exact GP inference with the following non-degenerate kernel:


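$$K_{\text{FIC}}(x_i, x_j) = \begin{cases} K(x_i, x_j) & \text{if } i = j \\ K_{\text{SOR}}(x_i, x_j) & \text{if } i \neq j \end{cases} \tag{18.131}$$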
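
The sketch below implements the FITC predictive distribution (18.130) for a batch of test points treated independently (i.e., FIC), using only diagonals of the $N \times N$ matrices, and checks it against exact GP prediction under the joint prior (18.129). The unit-variance RBF kernel, toy data, and jitter are illustrative assumptions.

```python
import numpy as np

def rbf(A, B, ell=1.0):
    """Unit-variance RBF kernel matrix between the rows of A and B."""
    d2 = np.sum(A**2, 1)[:, None] + np.sum(B**2, 1)[None, :] - 2 * A @ B.T
    return np.exp(-0.5 * np.maximum(d2, 0.0) / ell**2)

def fitc_predict(X, y, Z, Xs, sigma2, jitter=1e-6):
    """FITC/FIC predictive mean and variance, Eq. (18.130), in O(N M^2)."""
    K_ZZ = rbf(Z, Z) + jitter * np.eye(len(Z))
    K_ZX, K_sZ = rbf(Z, X), rbf(Xs, Z)
    q_xx = np.sum(K_ZX * np.linalg.solve(K_ZZ, K_ZX), axis=0)      # diag(Q_{X,X})
    lam = 1.0 - q_xx + sigma2                                       # Lambda, using K(x, x) = 1
    Sigma = np.linalg.inv(K_ZZ + (K_ZX / lam) @ K_ZX.T)             # (K_ZZ + K_ZX Lambda^{-1} K_XZ)^{-1}
    mean = K_sZ @ (Sigma @ (K_ZX @ (y / lam)))
    q_ss = np.sum(K_sZ.T * np.linalg.solve(K_ZZ, K_sZ.T), axis=0)  # diag(Q_{*,*})
    var = 1.0 - q_ss + np.einsum("ij,jk,ik->i", K_sZ, Sigma, K_sZ)
    return mean, var

rng = np.random.default_rng(0)
X = rng.uniform(-3, 3, size=(100, 1))
y = np.sin(2 * X[:, 0]) + 0.1 * rng.standard_normal(100)
Z = np.linspace(-3, 3, 8)[:, None]
Xs = np.array([[0.0], [10.0]])
mean, var = fitc_predict(X, y, Z, Xs, sigma2=0.01)

# Check against exact GP prediction under the joint prior (18.129), using N x N matrices.
K_ZZ = rbf(Z, Z) + 1e-6 * np.eye(len(Z))
Q_XX = rbf(X, Z) @ np.linalg.solve(K_ZZ, rbf(Z, X))
C = Q_XX + np.diag(1.0 - np.diag(Q_XX) + 0.01)           # Q_XX - diag(Q_XX - K_XX) + sigma2 I
Q_sX = rbf(Xs, Z) @ np.linalg.solve(K_ZZ, rbf(Z, X))
mean_exact = Q_sX @ np.linalg.solve(C, y)
var_exact = 1.0 - np.sum(Q_sX.T * np.linalg.solve(C, Q_sX.T), axis=0)
print(mean, mean_exact)
print(var, var_exact)
```

Because $\Lambda$ is diagonal, the only dense solves involve $M \times M$ matrices, which is why the cost matches SOR and DTC.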


18.5.3.4 Learning the inducing points 


So far, we have not specified how to choose the inducing points or pseudo inputs Z. We can treat 
these like kernel hyperparameters, and choose them so as to maximize the log marginal likelihood, 
given by 


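$$\log q(y | X, Z) = \log \int\!\!\int p(y | f_X)\, q(f_X | f_Z)\, p(f_Z | Z)\, df_X\, df_Z \tag{18.132}$$
$$= \log \int p(y | f_X)\, q(f_X | X, Z)\, df_X \tag{18.133}$$
$$= -\frac{1}{2} \log |Q_{X,X} + \Lambda| - \frac{1}{2} y^\top (Q_{X,X} + \Lambda)^{-1} y - \frac{N}{2} \log(2\pi) \tag{18.134}$$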


where the definition of $\Lambda$ depends on the method, namely $\Lambda_{\text{SOR}} = \Lambda_{\text{DTC}} = \sigma^2 I_N$, and $\Lambda_{\text{FITC}} = \mathrm{diag}(K_{X,X} - Q_{X,X}) + \sigma^2 I_N$.
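
The sketch below evaluates Equation (18.134) in $O(NM^2)$ time for the SOR/DTC and FITC choices of $\Lambda$, using the matrix inversion lemma and the Sylvester determinant lemma instead of forming $Q_{X,X} + \Lambda$, and compares the result with a dense computation. It assumes a unit-variance RBF kernel with a fixed length scale, so $\mathrm{diag}(K_{X,X}) = 1$. To learn $Z$ by gradient methods, one would differentiate this objective (for example with an automatic differentiation library), which we do not show.

```python
import numpy as np

def rbf(A, B, ell=1.0):
    """Unit-variance RBF kernel matrix between the rows of A and B."""
    d2 = np.sum(A**2, 1)[:, None] + np.sum(B**2, 1)[None, :] - 2 * A @ B.T
    return np.exp(-0.5 * np.maximum(d2, 0.0) / ell**2)

def approx_log_marglik(X, y, Z, sigma2, method="fitc", jitter=1e-6):
    """log N(y | 0, Q_XX + Lambda), Eq. (18.134), without forming any N x N matrix."""
    N, M = len(X), len(Z)
    L = np.linalg.cholesky(rbf(Z, Z) + jitter * np.eye(M))
    V = np.linalg.solve(L, rbf(Z, X))                  # Q_XX = V^T V
    if method == "fitc":
        lam = 1.0 - np.sum(V**2, axis=0) + sigma2      # diag(K_XX - Q_XX) + sigma2
    else:                                              # "sor" and "dtc" share Lambda = sigma2 I
        lam = np.full(N, sigma2)
    LB = np.linalg.cholesky(np.eye(M) + (V / lam) @ V.T)   # I + V Lambda^{-1} V^T
    c = np.linalg.solve(LB, V @ (y / lam))
    quad = np.sum(y**2 / lam) - c @ c                  # y^T (Q_XX + Lambda)^{-1} y (Woodbury)
    logdet = np.sum(np.log(lam)) + 2 * np.sum(np.log(np.diag(LB)))   # Sylvester lemma
    return -0.5 * logdet - 0.5 * quad - 0.5 * N * np.log(2 * np.pi)

def dense_log_marglik(X, y, Z, sigma2, method="fitc", jitter=1e-6):
    """Same quantity with an explicit N x N covariance, for checking only."""
    K_ZZ = rbf(Z, Z) + jitter * np.eye(len(Z))
    Q = rbf(X, Z) @ np.linalg.solve(K_ZZ, rbf(Z, X))
    lam = (1.0 - np.diag(Q) if method == "fitc" else 0.0) + sigma2
    C = Q + np.diag(lam * np.ones(len(X)))
    _, logdet = np.linalg.slogdet(C)
    return -0.5 * logdet - 0.5 * y @ np.linalg.solve(C, y) - 0.5 * len(X) * np.log(2 * np.pi)

rng = np.random.default_rng(0)
X = rng.uniform(-3, 3, size=(200, 1))
y = np.sin(2 * X[:, 0]) + 0.1 * rng.standard_normal(200)
Z = np.linspace(-3, 3, 8)[:, None]
for method in ("dtc", "fitc"):
    print(method, approx_log_marglik(X, y, Z, 0.01, method), dense_log_marglik(X, y, Z, 0.01, method))
```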

If the input domain is $\mathbb{R}^D$, we can optimize $Z \in \mathbb{R}^{M \times D}$ using gradient methods. However, one of the appeals of kernel methods is that they can handle structured inputs, such as strings and graphs (see Section 18.2.3). In this case, we cannot use gradient methods to select the inducing points. A simple approach is to select the inducing points from the training set, as in the subset of data approach in Section 18.5.1, or using the efficient selection mechanism in [Cao+15]. However, we can also use discrete optimization methods, such as simulated annealing (Section 12.9.1), as discussed in [For+18a]. See Figure 18.13 for an illustration.


18.5.4 Sparse variational methods


In this section, we discuss a variational approach to GP inference that is similar to the inducing point methods in Section 18.5.3, but which generalizes them to also handle non-conjugate likelihoods. This is called the sparse variational GP or SVGP approximation. For more details, see e.g., [BWR16; Lei+20]. (See also [WKS21] for connections between SVGP and the Nyström method.)

[Figure 18.13 diagram: a kernel $k(\cdot,\cdot)$ and inducing points $Z$ define a $\mathcal{GP}(0, k(\cdot,\cdot))$; maximize $\log p(y | Z)$ w.r.t. $Z$.]

Figure 18.13: Illustration of how to choose inducing points from a discrete input domain (here DNA sequences of length 4) to maximize the log marginal likelihood. From Figure 1 of [For+18a]. Used with kind permission of Vincent Fortuin.

To explain the idea behind SVGP, let us assume, for simplicity, that the function $f$ is defined over a finite set $\mathcal{X}$ of possible inputs, which we partition into three subsets: the training set $X$, a set of inducing points $Z$, and all other points (which we can think of as the test set), $X_*$. (We assume these sets are disjoint.) Let $f_X$, $f_Z$ and $f_*$ represent the corresponding unknown function values on these points, and let $f = [f_X, f_Z, f_*]$ be all the unknowns. (Here we work with a fixed-length vector $f$, but the result generalizes to Gaussian processes, as explained in [Mat+16].) We assume the function is sampled from a GP, so $p(f) = \mathcal{N}(m(\mathcal{X}), K(\mathcal{X}, \mathcal{X}))$.

The inducing point methods in Section 18.5.3 approximate the GP prior by assuming $p(f_*, f_X, f_Z) \approx p(f_* | f_Z)\, p(f_X | f_Z)\, p(f_Z)$. The inducing points $Z$ are chosen to maximize the likelihood of the observed data. We then perform exact inference in this approximate model. By contrast, in this section, we will keep the model unchanged, but we will instead approximate the posterior $p(f | y)$ using variational inference. This approach is known as the variational free energy (VFE) method for GPs [Tit09; Mat+16]; it is also called SVGP.

In the VFE view, the inducing points $Z$ and inducing variables $f_Z$ are variational parameters, rather than model parameters, which avoids the risk of overfitting. Furthermore, one can show that as the number of inducing points $M$ increases, the quality of the posterior consistently improves, eventually recovering exact inference. By contrast, in the classical inducing point method, increasing $M$ does not always result in better performance [BWR16].

In more detail, the VFE approach tries to find an approximate posterior $q(f)$ to minimize $D_{\mathbb{KL}}(q(f) \,\|\, p(f | y))$. The key assumption is that $q(f) = q(f_*, f_X, f_Z) = p(f_*, f_X | f_Z)\, q(f_Z)$, where $p(f_*, f_X | f_Z)$ is computed exactly using the GP prior, and $q(f_Z)$ is learned, by minimizing $\mathcal{K}(q) = D_{\mathbb{KL}}(q(f) \,\|\, p(f | y))$.² Intuitively, $q(f_Z)$ acts as a “bottleneck” which “absorbs” all the observations from $y$; posterior predictions for elements of $f_X$ or $f_*$ are then made via their dependence on $f_Z$, rather than their dependence on each other.


2. One can show that $D_{\mathbb{KL}}(q(f) \,\|\, p(f | y)) = D_{\mathbb{KL}}(q(f_X, f_Z) \,\|\, p(f_X, f_Z | y))$, which is the original objective from [Tit09].

We can derive the form of the loss, which is used to compute the posterior $q(f_Z)$, as follows:


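$$\mathcal{K}(q) = D_{\mathbb{KL}}(q(f_*, f_X, f_Z) \,\|\, p(f_*, f_X, f_Z | y)) \tag{18.135}$$
$$= \int q(f_*, f_X, f_Z) \log \frac{q(f_*, f_X, f_Z)}{p(f_*, f_X, f_Z | y)}\, df_*\, df_X\, df_Z \tag{18.136}$$
$$= \int p(f_*, f_X | f_Z)\, q(f_Z) \log \frac{p(f_*, f_X | f_Z)\, q(f_Z)\, p(y)}{p(f_*, f_X | f_Z)\, p(f_Z)\, p(y | f_X)}\, df_*\, df_X\, df_Z \tag{18.137}$$
$$= \int p(f_*, f_X | f_Z)\, q(f_Z) \log \frac{q(f_Z)\, p(y)}{p(f_Z)\, p(y | f_X)}\, df_*\, df_X\, df_Z \tag{18.138}$$
$$= \int q(f_Z) \log \frac{q(f_Z)}{p(f_Z)}\, df_Z - \int p(f_X | f_Z)\, q(f_Z) \log p(y | f_X)\, df_X\, df_Z + C \tag{18.139}$$
$$= D_{\mathbb{KL}}(q(f_Z) \,\|\, p(f_Z)) - \mathbb{E}_{q(f_X)}[\log p(y | f_X)] + C \tag{18.140}$$

where $C = \log p(y)$ is an irrelevant constant.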


We can alternatively write the objective as an evidence lower bound that we want to maximize:

$$\log p(y) = \mathcal{K}(q) + \mathbb{E}_{q(f_X)}[\log p(y | f_X)] - D_{\mathbb{KL}}(q(f_Z) \,\|\, p(f_Z)) \tag{18.141}$$
$$\geq \mathbb{E}_{q(f_X)}[\log p(y | f_X)] - D_{\mathbb{KL}}(q(f_Z) \,\|\, p(f_Z)) = \mathcal{L}(q) \tag{18.142}$$


Now suppose we choose a Gaussian posterior approximation, $q(f_Z) = \mathcal{N}(f_Z | m, S)$. Since $p(f_Z) = \mathcal{N}(f_Z | 0, K(Z, Z))$, we can compute the KL term in closed form using the formula for KL divergence between Gaussians (Equation (5.80)).

As for the expected log-likelihood term, we need to compute $q(f_X)$. Since $q(f_Z)$ is Gaussian, this can be done in closed form as follows:


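$$q(f_X | m, S) = \int p(f_X | f_Z, X, Z)\, q(f_Z | m, S)\, df_Z = \mathcal{N}(f_X | \tilde{\mu}, \tilde{\Sigma}) \tag{18.143}$$
$$\tilde{\mu}_i = m(x_i) + \alpha(x_i)^\top (m - m(Z)) \tag{18.144}$$
$$\tilde{\Sigma}_{ij} = K(x_i, x_j) - \alpha(x_i)^\top (K(Z, Z) - S)\, \alpha(x_j) \tag{18.145}$$
$$\alpha(x_i) = K(Z, Z)^{-1} K(Z, x_i) \tag{18.146}$$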


Hence $q(f_n) = \mathcal{N}(f_n | \tilde{\mu}_n, \tilde{\Sigma}_{nn})$, which we can use to compute the expected log likelihood:


$$\mathbb{E}_{q(f_X)}[\log p(y | f_X)] = \sum_{n=1}^N \mathbb{E}_{q(f_n)}[\log p(y_n | f_n)] \tag{18.147}$$
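
The sketch below computes the marginals $q(f_n) = \mathcal{N}(\tilde{\mu}_n, \tilde{\Sigma}_{nn})$ from Equations (18.144)-(18.146) with a zero mean function, and then evaluates each term of (18.147) by Gauss-Hermite quadrature, which works for any scalar likelihood. The RBF kernel, the arbitrary $m$ and $S$, and the use of quadrature (rather than, say, Monte Carlo) are our assumptions. For a Gaussian likelihood with precision $\beta$, the result can be checked against the closed form $\log \mathcal{N}(y_n | \tilde{\mu}_n, \beta^{-1}) - \frac{\beta}{2} \tilde{\Sigma}_{nn}$.

```python
import numpy as np

def rbf(A, B, ell=1.0):
    """Unit-variance RBF kernel matrix between the rows of A and B."""
    d2 = np.sum(A**2, 1)[:, None] + np.sum(B**2, 1)[None, :] - 2 * A @ B.T
    return np.exp(-0.5 * np.maximum(d2, 0.0) / ell**2)

def q_marginals(X, Z, m, S, jitter=1e-6):
    """Means and variances of q(f_n), Eqs. (18.144)-(18.146), with m(.) = 0 and K(x, x) = 1."""
    K_ZZ = rbf(Z, Z) + jitter * np.eye(len(Z))
    A = np.linalg.solve(K_ZZ, rbf(Z, X))                  # column n is alpha(x_n)
    mu = A.T @ m
    var = 1.0 - np.einsum("an,ab,bn->n", A, K_ZZ - S, A)  # diagonal of Eq. (18.145)
    return mu, var

def expected_loglik(y, mu, var, log_lik, deg=20):
    """sum_n E_{N(f_n | mu_n, var_n)}[log p(y_n | f_n)] by Gauss-Hermite quadrature."""
    x, w = np.polynomial.hermite_e.hermegauss(deg)       # nodes/weights for weight exp(-x^2 / 2)
    f = mu[:, None] + np.sqrt(var)[:, None] * x[None, :]
    return np.sum(log_lik(y[:, None], f) @ w) / np.sqrt(2 * np.pi)

rng = np.random.default_rng(0)
X = rng.uniform(-3, 3, size=(50, 1))
y = np.sin(2 * X[:, 0]) + 0.1 * rng.standard_normal(50)
Z = np.linspace(-3, 3, 8)[:, None]
m, S = rng.standard_normal(8), 0.1 * np.eye(8)           # an arbitrary Gaussian q(f_Z)
mu, var = q_marginals(X, Z, m, S)

beta = 100.0
gauss_loglik = lambda y, f: -0.5 * np.log(2 * np.pi / beta) - 0.5 * beta * (y - f) ** 2
quad = expected_loglik(y, mu, var, gauss_loglik)
closed = np.sum(-0.5 * np.log(2 * np.pi / beta) - 0.5 * beta * ((y - mu) ** 2 + var))
print(np.isclose(quad, closed))   # True
```

For a binary classifier with labels $y_n \in \{0, 1\}$, one could instead pass a logistic log likelihood such as `lambda y, f: -np.logaddexp(0.0, -(2 * y - 1) * f)`.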


We discuss how to compute the expected log likelihood below.


18.5.4.1 Gaussian likelihood


If we have a Gaussian observation model, we can compute the expected log likelihood in closed form. In particular, if we assume $m(x) = 0$, we have

$$\mathbb{E}_{q(f_n)}[\log \mathcal{N}(y_n | f_n, \beta^{-1})] = \log \mathcal{N}(y_n | k_n^\top K_{Z,Z}^{-1} m, \beta^{-1}) - \frac{1}{2} \beta \tilde{k}_{nn} - \frac{1}{2} \mathrm{tr}(S \Lambda_n) \tag{18.148}$$

where $\tilde{k}_{nn} = k_{nn} - k_n^\top K_{Z,Z}^{-1} k_n$, $k_n$ is the $n$'th column of $K_{Z,X}$, and $\Lambda_n = \beta K_{Z,Z}^{-1} k_n k_n^\top K_{Z,Z}^{-1}$.
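Hence the overall ELBO has the form

$$\mathcal{L}(q) = \log \mathcal{N}(y | K_{X,Z} K_{Z,Z}^{-1} m, \beta^{-1} I) - \frac{1}{2} \beta\, \mathrm{tr}(K_{X,Z} K_{Z,Z}^{-1} S K_{Z,Z}^{-1} K_{Z,X}) \tag{18.149}$$
$$\quad - \frac{1}{2} \beta\, \mathrm{tr}(K_{X,X} - Q_{X,X}) - D_{\mathbb{KL}}(q(f_Z) \,\|\, p(f_Z)) \tag{18.150}$$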


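To compute the gradients of this, we leverage the following result [OA09]:

$$\frac{\partial}{\partial \mu} \mathbb{E}_{\mathcal{N}(x | \mu, \sigma^2)}[h(x)] = \mathbb{E}_{\mathcal{N}(x | \mu, \sigma^2)}\left[\frac{\partial h(x)}{\partial x}\right] \tag{18.151}$$
$$\frac{\partial}{\partial \sigma^2} \mathbb{E}_{\mathcal{N}(x | \mu, \sigma^2)}[h(x)] = \frac{1}{2} \mathbb{E}_{\mathcal{N}(x | \mu, \sigma^2)}\left[\frac{\partial^2 h(x)}{\partial x^2}\right] \tag{18.152}$$

We then substitute $h(x)$ with $\log p(y_n | f_n)$. Using this, one can show

$$\nabla_m \mathcal{L}(q) = \beta K_{Z,Z}^{-1} K_{Z,X} y - \Lambda m \tag{18.153}$$
$$\nabla_S \mathcal{L}(q) = \frac{1}{2} S^{-1} - \frac{1}{2} \Lambda \tag{18.154}$$

Setting the derivatives to zero gives the optimal solution:

$$S = \Lambda^{-1} \tag{18.155}$$
$$\Lambda = \beta K_{Z,Z}^{-1} K_{Z,X} K_{X,Z} K_{Z,Z}^{-1} + K_{Z,Z}^{-1} \tag{18.156}$$
$$m = \beta \Lambda^{-1} K_{Z,Z}^{-1} K_{Z,X} y \tag{18.157}$$

This is called sparse GP regression or SGPR [Tit09].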
With these parameters, the lower bound on the log marginal likelihood is given by

$$\log p(y) \geq \log \mathcal{N}(y | 0, K_{X,Z} K_{Z,Z}^{-1} K_{Z,X} + \beta^{-1} I) - \frac{1}{2} \beta\, \mathrm{tr}(K_{X,X} - Q_{X,X}) \tag{18.158}$$

where $Q_{X,X} = K_{X,Z} K_{Z,Z}^{-1} K_{Z,X}$. (This is called the “collapsed” lower bound, since we have marginalized out $f_Z$.) If $Z = X$, then $K_{Z,Z} = K_{Z,X} = K_{X,X}$, so the bound becomes tight, and we have $\log p(y) = \log \mathcal{N}(y | 0, K_{X,X} + \beta^{-1} I)$.


Equation (18.158) is almost the same as the log marginal likelihood for the DTC model in Equation (18.134), except for the trace term; it is this latter term that prevents overfitting, due to the fact that we treat $f_Z$ as variational parameters of the posterior rather than model parameters of the prior.
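
The sketch below evaluates the collapsed bound (18.158) with dense matrices (so it is $O(N^3)$ and meant only to illustrate the bound, not to be fast) and checks two claims from the text: the bound does not exceed the exact log marginal likelihood, and it is numerically tight when $Z = X$. The unit-variance RBF kernel, toy data, and jitter are illustrative assumptions; $\beta$ is the noise precision.

```python
import numpy as np

def rbf(A, B, ell=1.0):
    """Unit-variance RBF kernel matrix between the rows of A and B."""
    d2 = np.sum(A**2, 1)[:, None] + np.sum(B**2, 1)[None, :] - 2 * A @ B.T
    return np.exp(-0.5 * np.maximum(d2, 0.0) / ell**2)

def gauss_logpdf0(y, C):
    """log N(y | 0, C) via a dense solve and log-determinant."""
    _, logdet = np.linalg.slogdet(C)
    return -0.5 * (y @ np.linalg.solve(C, y) + logdet + len(y) * np.log(2 * np.pi))

def sgpr_bound(X, y, Z, beta, jitter=1e-8):
    """Collapsed bound, Eq. (18.158); tr(K_XX) = N because K(x, x) = 1 for this kernel."""
    K_ZZ = rbf(Z, Z) + jitter * np.eye(len(Z))
    Q_XX = rbf(X, Z) @ np.linalg.solve(K_ZZ, rbf(Z, X))
    return gauss_logpdf0(y, Q_XX + np.eye(len(X)) / beta) - 0.5 * beta * (len(X) - np.trace(Q_XX))

rng = np.random.default_rng(0)
X = rng.uniform(-3, 3, size=(40, 1))
y = np.sin(2 * X[:, 0]) + 0.1 * rng.standard_normal(40)
beta = 100.0                                      # noise precision, i.e. sigma^2 = 0.01
exact = gauss_logpdf0(y, rbf(X, X) + np.eye(40) / beta)
for M in (4, 8, 16):
    print(M, sgpr_bound(X, y, np.linspace(-3, 3, M)[:, None], beta), exact)
print(sgpr_bound(X, y, X, beta) - exact)          # close to 0: the bound is tight at Z = X
```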


18.5.4.2 Non-Gaussian likelihood 


In this section, we briefly consider the case of non-Gaussian likelihoods, which arise when using GPs for classification or for count data (see Section 18.4). We can compute the gradients of the expected log likelihood by defining $h(f_n) = \log p(y_n | f_n)$ and then using a Monte Carlo approximation to Equation (18.151) and Equation (18.152). In the case of a binary classifier, we can use the results in Table 18.2 to compute the inner $\frac{\partial}{\partial f_n} h(f_n)$ and $\frac{\partial^2}{\partial f_n^2} h(f_n)$ terms. For a more general treatment, see [WSS21].


18.5.4.3 Minibatch SVI


Computing the optimal variational solution in Section 18.5.4.1 requires solving a batch optimization problem, which takes $O(M^3 + NM^2)$ time. This may still be too slow if $N$ is large, unless $M$ is small, which compromises accuracy.

An alternative approach is to perform stochastic optimization of the VFE objective, instead of 
batch optimization. This is known as stochastic variational inference (see Section 10.3.1). The 
key observation is that the log likelihood in Equation (18.147) is a sum of N terms, which we can 
approximate with minibatch sampling to compute noisy estimates of the gradient, as proposed in 
[HFL13]. 

In more detail, the objective becomes

$$\mathcal{L}(q) \approx B \sum_{n \in \mathcal{B}_b} \mathbb{E}_{q(f_n)}[\log p(y_n | f_n)] - D_{\mathbb{KL}}(q(f_Z) \,\|\, p(f_Z)) \tag{18.159}$$


where $\mathcal{B}_b$ is the $b$'th batch, and $B$ is the number of batches. Since the GP model (with Gaussian likelihoods) is in the exponential family, we can efficiently compute the natural gradient (Section 6.4) of Equation (18.159) with respect to the canonical parameters of $q(f_Z)$; this converges much faster than following the standard gradient. See [HFL13] for details.
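
As a minimal illustration of Equation (18.159), the sketch below scales one batch's expected log likelihood by the number of batches $B$ and subtracts the closed-form KL divergence between Gaussians. The per-point marginals are random placeholders standing in for Equations (18.144)-(18.145), and a Gaussian likelihood with precision $\beta$ is assumed. With equal-sized batches that partition the data, averaging the estimate over all batches recovers the full objective exactly, which is why a uniformly sampled batch gives an unbiased estimate.

```python
import numpy as np

def gauss_kl(m, S, K):
    """KL( N(m, S) || N(0, K) ) for M-dimensional Gaussians."""
    M = len(m)
    _, logdet_K = np.linalg.slogdet(K)
    _, logdet_S = np.linalg.slogdet(S)
    return 0.5 * (np.trace(np.linalg.solve(K, S)) + m @ np.linalg.solve(K, m) - M + logdet_K - logdet_S)

def expected_loglik_terms(y, mu, var, beta):
    """E_{N(f_n | mu_n, var_n)}[log N(y_n | f_n, 1 / beta)] for every n."""
    return -0.5 * np.log(2 * np.pi / beta) - 0.5 * beta * ((y - mu) ** 2 + var)

def minibatch_objective(terms, batch, num_batches, kl):
    """Stochastic estimate of Eq. (18.159) from one batch of indices."""
    return num_batches * np.sum(terms[batch]) - kl

rng = np.random.default_rng(0)
N, M, B, beta = 1000, 10, 10, 100.0
y = rng.standard_normal(N)
mu, var = y + 0.05 * rng.standard_normal(N), np.full(N, 0.01)   # placeholder marginals of q(f_n)
A = rng.standard_normal((M, M))
K_ZZ, S, m = A @ A.T + M * np.eye(M), 0.1 * np.eye(M), rng.standard_normal(M)

terms = expected_loglik_terms(y, mu, var, beta)
kl = gauss_kl(m, S, K_ZZ)
full = np.sum(terms) - kl
batches = np.array_split(rng.permutation(N), B)                  # B equal-sized batches
estimates = [minibatch_objective(terms, b, B, kl) for b in batches]
print(np.isclose(np.mean(estimates), full))   # True
```

In an actual SVI loop, one would differentiate a single batch's estimate with respect to $(m, S)$ (or, as in [HFL13], take natural-gradient steps) instead of evaluating all batches.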


18.5.5 Exploiting parallelization and structure via kernel matrix multiplies 


It takes $O(N^3)$ time to compute the Cholesky decomposition of $K_{X,X}$, which is needed to solve the linear system $K_\sigma \alpha = y$ and to compute $|K_{X,X}|$. An alternative to Cholesky decomposition is to use linear algebra methods, often called Krylov subspace methods, based just on matrix vector multiplication or MVM. These approaches are often much faster.


In short, if the kernel matrix $K_\sigma$ has special algebraic structure, which is often the case through either the choice of kernel or the structure of the inputs, then it is typically easier to exploit this structure in performing fast matrix multiplies. Moreover, even if the kernel matrix does not have special structure, matrix multiplies are trivial to parallelize, and can thus be greatly accelerated by GPUs, unlike Cholesky based methods which are largely sequential. Algorithms based on matrix multiplies are in harmony with modern hardware advances, which enable significant parallelization.


18.5.5.1 Using conjugate gradient and Lanczos methods


We can solve the linear system $K_\sigma \alpha = y$ using conjugate gradients (CG). The key computational step in CG is the ability to perform MVMs. Let $\tau(K_\sigma)$ be the time complexity of a single MVM with $K_\sigma$. For a dense $n \times n$ matrix, we have $\tau(K_\sigma) = n^2$; however, we can speed this up if $K_\sigma$ is sparse or structured, as we discuss below.


Even if $K_\sigma$ is dense, we may still be able to save time by solving the linear system approximately. In particular, if we perform $C$ iterations, CG will take $O(C \tau(K_\sigma))$ time. If we run for $C = n$ iterations, and $\tau(K_\sigma) = n^2$, it gives the exact solution in $O(n^3)$ time. However, often we can use fewer iterations and still get good accuracy, depending on the condition number of $K_\sigma$.
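
Here is a minimal conjugate gradient solver that touches $K_\sigma$ only through a user-supplied MVM function, which is the property that lets structured or GPU-parallel multiplies be swapped in. The dense NumPy MVM, the toy RBF kernel, and the stopping tolerance are illustrative assumptions; practical implementations (such as the batched CG used in BBMM) differ in many details.

```python
import numpy as np

def conjugate_gradient(mvm, b, max_iters=None, tol=1e-10):
    """Solve A x = b for symmetric positive definite A, using only v -> A v products."""
    x = np.zeros_like(b)
    r = b.copy()                      # residual b - A x for the initial guess x = 0
    p = r.copy()
    rs = r @ r
    b_norm = np.linalg.norm(b)
    for _ in range(len(b) if max_iters is None else max_iters):
        if np.sqrt(rs) <= tol * b_norm:
            break
        Ap = mvm(p)                   # the only access to A: one MVM per iteration
        step = rs / (p @ Ap)
        x += step * p
        r -= step * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p     # next A-conjugate search direction
        rs = rs_new
    return x

rng = np.random.default_rng(0)
X = np.sort(rng.uniform(-3, 3, 500))
K_sigma = np.exp(-0.5 * (X[:, None] - X[None, :]) ** 2) + 0.1 * np.eye(500)
y = np.sin(2 * X) + 0.1 * rng.standard_normal(500)
alpha_cg = conjugate_gradient(lambda v: K_sigma @ v, y, max_iters=200)
alpha_exact = np.linalg.solve(K_sigma, y)
print(np.linalg.norm(alpha_cg - alpha_exact) / np.linalg.norm(alpha_exact))   # relative error
```

Because the RBF Gram matrix has only a handful of large eigenvalues and the rest sit near $\sigma^2 = 0.1$, CG here typically stops well before $n = 500$ iterations.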


We can compute the log determinant of a matrix using the MVM primitive with a similar iterative method known as stochastic Lanczos quadrature [UCS17; Don+17a]. This takes $O(L \tau(K_\sigma))$ time for $L$ iterations.
[Figure 18.14 plot: test RMSE vs. training set size on KEGGU (n=40708), 3DRoad (n=278319) and Song (n=329820); curves: Subsampled (Exact GP), Full Dataset (SGPR m=512), Full Dataset (SVGP m=1024), Full Dataset (Exact GP).]

Figure 18.14: RMSE on test set as a function of training set size using a GP with Matern 3/2 kernel with shared lengthscale across all dimensions. Solid lines: exact inference. Dashed blue: SGPR method (closed-form batch solution to the Gaussian variational approximation) of Section 18.5.4.1 with M = 512 inducing points. Dashed orange: SVGP method (SGD on Gaussian variational approximation) of Section 18.5.4.3 with M = 1024 inducing points. Number of input dimensions: KEGGU D = 27, 3DRoad D = 3, Song D = 90. From Figure 4 of [Wan+19a]. Used with kind permission of Andrew Wilson.


These methods have been used in the blackbox matrix-matrix multiplication (BBMM) inference procedure of [Gar+18a], which formulates a batch approach to CG that can be effectively parallelized on GPUs. Using 8 GPUs, this enabled the authors of [Wan+19a] to perform exact inference for a GP regression model on $N \sim 10^4$ datapoints in seconds, $N \sim 10^5$ datapoints in minutes, and $N \sim 10^6$ datapoints in hours.

Interestingly, Figure 18.14 shows that exact GP inference on a subset of the data can often 
outperform approximate inference on the full data. We also see that performance of exact GPs 
continues to significantly improve as we increase the size of the data, suggesting that GPs are not only 
useful in the small-sample setting. In particular, the BBMM is an exact method, and so will preserve 
the non-parametric representation of a GP with a non-degenerate kernel. By contrast, standard 
scalable approximations typically operate by replacing the exact kernel with an approximation that 
corresponds to a parametric model. The non-parametric GPs are able to grow their capacity with 
more data, benefiting more significantly from the structure present in large datasets. 


18.5.5.2 Kernels with compact support 


Suppose we use a kernel with compact support, where $K(x, x') = 0$ if $\|x - x'\| > \epsilon$ for some threshold $\epsilon$ (see e.g., [MR09]). Then $K_\sigma$ will be sparse, so $\tau(K_\sigma)$ will be $O(N)$ (provided each point has a bounded number of neighbors within distance $\epsilon$). We can also induce sparsity and structure in other ways, as we discuss in Section 18.5.5.3.


18.5.5.3 KISS 


One way to ensure that MVMs are fast is to force the kernel matrix to have structure. The structured kernel interpolation (SKI) method of [WN15] does this as follows. First it assumes we have a set of inducing points, with Gram matrix $K_{Z,Z}$. It then interpolates these values to predict the entries of the full kernel matrix using

$$K_{X,X} \approx W_X K_{Z,Z} W_X^\top \tag{18.160}$$

where $W_X$ is a sparse matrix containing interpolation weights. If we use cubic interpolation, each row only has 4 nonzeros. Thus we can compute $(W_X K_{Z,Z} W_X^\top) v$ for any vector $v$ in $O(N + M^2)$ time.
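
The sketch below illustrates the SKI matrix-vector product $(W_X K_{Z,Z} W_X^\top) v$ in 1d. For brevity it uses linear rather than cubic interpolation (an assumption, so each row of $W_X$ has 2 nonzeros instead of 4) and a dense product with $K_{Z,Z}$; the grid, kernel, and helper names are illustrative.

```python
import numpy as np

def rbf_1d(a, b, ell=1.0):
    """Unit-variance RBF kernel matrix between 1d point sets a and b."""
    return np.exp(-0.5 * (a[:, None] - b[None, :]) ** 2 / ell**2)

def linear_interp_weights(x, grid):
    """Indices of the two neighbouring grid points for each x_n, and their linear weights."""
    h = grid[1] - grid[0]
    left = np.clip(np.floor((x - grid[0]) / h).astype(int), 0, len(grid) - 2)
    t = (x - grid[left]) / h                                # in [0, 1] inside the grid
    return np.stack([left, left + 1], axis=1), np.stack([1 - t, t], axis=1)

def ski_mvm(v, idx, w, K_ZZ):
    """(W K_ZZ W^T) v using only the nonzeros of W, stored as (idx, w)."""
    Wt_v = np.zeros(K_ZZ.shape[0])
    np.add.at(Wt_v, idx, w * v[:, None])                    # W^T v: scatter-add, O(N)
    u = K_ZZ @ Wt_v                                         # O(M^2) here; faster with Toeplitz structure
    return np.sum(w * u[idx], axis=1)                       # W u: gather, O(N)

rng = np.random.default_rng(0)
x = rng.uniform(-3, 3, 500)
grid = np.linspace(-3.5, 3.5, 100)                          # M = 100 regularly spaced inducing points
idx, w = linear_interp_weights(x, grid)
K_ZZ = rbf_1d(grid, grid)
v = rng.standard_normal(len(x))
fast = ski_mvm(v, idx, w, K_ZZ)

W = np.zeros((len(x), len(grid)))
W[np.arange(len(x))[:, None], idx] = w                      # dense W, for checking only
print(np.allclose(fast, W @ (K_ZZ @ (W.T @ v))))            # True
```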

Note that the SKI approach generalizes all inducing point methods. For example, we can recover the subset of regressors (SOR) method by setting the interpolation weights to $W = K_{X,Z} K_{Z,Z}^{-1}$. We can identify this procedure as performing a global Gaussian process interpolation strategy on the user-specified kernel. See [WN15] and [WDN15] for more details.

In 1d, we can further reduce the running time by choosing the inducing points to be on a regular grid, so that (for a stationary kernel) $K_{Z,Z}$ is a Toeplitz matrix. In higher dimensions, we need to use a multidimensional grid of points, resulting in $K_{Z,Z}$ being a Kronecker product of Toeplitz matrices (for a product kernel). This enables matrix vector multiplication in $O(N + M \log M)$ time and $O(N + M)$ space. The resulting method is called KISS-GP [WN15], which stands for “kernel interpolation for scalable, structured GPs”.
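
For a stationary kernel on a regular 1d grid, $K_{Z,Z}$ is symmetric Toeplitz and is determined by its first column. The sketch below multiplies such a matrix by a vector in $O(M \log M)$ time by embedding it in a circulant matrix of size $2M - 2$ and using the FFT; this is a standard technique, and the function name and RBF example are our own.

```python
import numpy as np

def toeplitz_mvm(first_col, v):
    """Multiply the symmetric Toeplitz matrix with the given first column by v, via the FFT."""
    M = len(first_col)
    c = np.concatenate([first_col, first_col[-2:0:-1]])     # first column of a (2M-2) circulant
    v_pad = np.concatenate([v, np.zeros(M - 2)])
    prod = np.fft.irfft(np.fft.rfft(c) * np.fft.rfft(v_pad), n=len(c))   # circulant MVM
    return prod[:M]                                         # top-left M x M block is the Toeplitz matrix

grid = np.linspace(-3.0, 3.0, 512)
first_col = np.exp(-0.5 * (grid - grid[0]) ** 2)            # RBF kernel, unit length scale
v = np.random.default_rng(0).standard_normal(512)
K_ZZ = np.exp(-0.5 * (grid[:, None] - grid[None, :]) ** 2)  # dense, for checking only
print(np.allclose(toeplitz_mvm(first_col, v), K_ZZ @ v))    # True
```

Only the first column is stored, which is where the $O(M)$ memory footprint comes from.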

Unfortunately, the KISS method can take exponential time in the input dimensions $D$ when exploiting Kronecker structure in $K_{Z,Z}$, due to the need to create a fully connected multidimensional lattice. In [Gar+18b], they propose a method called SKIP, which stands for “SKI for products”. The idea is to leverage the fact that many kernels (including ARD) can be written as a product of 1d kernels: $K(x, x') = \prod_{d=1}^D K_d(x_d, x'_d)$. This can be combined with the 1d SKI method to enable fast MVMs. The overall running time to compute the log marginal likelihood (which is the bottleneck for kernel learning) using $C$ iterations of CG and a Lanczos decomposition of rank $L$ becomes $O(DL(N + M \log M) + L^3 N \log D + C L^2 N)$. Typical values are $L \sim 10^1$ and $C \sim 10^2$.


18.5.5.4 Tensor train methods 


Consider the Gaussian VFE approach in Section 18.5.4. We have to estimate the covariance S and 
the mean m. We can represent S efficiently using Kronecker structure, as used by KISS. Additionally, 
we can represent m efficiently using the tensor train decomposition [Ose11] in combination with 
SKI [WN15]. The resulting TT-GP method can scale efficiently to billions of inducing points, as 
explained in [INK18]. 


18.5.6 Converting a GP to a SSM


Consider a function defined on a 1d scalar input, such as a time index. For many kernels, the corresponding GP can be modeled using a stochastic differential equation. This induces a block tridiagonal precision matrix for the posterior $p(f_{1:T} | y_{1:T})$. We can therefore convert this to a linear-Gaussian state space model (Section 29.1), and perform exact inference in $O(T)$ time using Kalman smoothing, as explained in [SSH13; Ada+20]. This conversion can be done exactly for Matern kernels (with half-integer smoothness) and approximately for Gaussian (RBF) kernels (see [SS19, Ch. 12]). In [SGF21], they describe how to reduce the linear dependence on $T$ to $O(\log T)$ time using a parallel prefix scan operator, which can be run efficiently on GPUs.
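
The simplest instance of this conversion is the Matern-1/2 (Ornstein-Uhlenbeck) kernel $K(t, t') = \sigma_f^2 \exp(-|t - t'|/\ell)$, whose GP is a scalar linear-Gaussian state space model. The sketch below computes the exact log marginal likelihood with a Kalman filter in $O(T)$ and compares it with the $O(T^3)$ dense computation; posterior marginals would additionally require a Kalman smoothing pass. The kernel choice and toy data are illustrative assumptions.

```python
import numpy as np

def ou_kalman_loglik(t, y, sf2, ell, sy2):
    """log p(y) for a GP with kernel sf2 * exp(-|t - t'| / ell) plus noise sy2; t must be sorted."""
    mean, var, loglik, t_prev = 0.0, sf2, 0.0, t[0]
    for tk, yk in zip(t, y):
        a = np.exp(-(tk - t_prev) / ell)                      # state transition over the time gap
        mean, var = a * mean, a**2 * var + sf2 * (1 - a**2)   # predict (stationary variance sf2)
        s = var + sy2                                         # predictive variance of y_k
        loglik += -0.5 * (np.log(2 * np.pi * s) + (yk - mean) ** 2 / s)
        gain = var / s                                        # Kalman gain
        mean, var = mean + gain * (yk - mean), (1 - gain) * var   # update
        t_prev = tk
    return loglik

def dense_loglik(t, y, sf2, ell, sy2):
    """Same quantity with the dense T x T covariance, for checking only."""
    K = sf2 * np.exp(-np.abs(t[:, None] - t[None, :]) / ell) + sy2 * np.eye(len(t))
    _, logdet = np.linalg.slogdet(K)
    return -0.5 * (y @ np.linalg.solve(K, y) + logdet + len(t) * np.log(2 * np.pi))

rng = np.random.default_rng(0)
t = np.sort(rng.uniform(0, 10, 200))
y = np.sin(t) + 0.1 * rng.standard_normal(200)
print(np.isclose(ou_kalman_loglik(t, y, 1.0, 2.0, 0.01), dense_loglik(t, y, 1.0, 2.0, 0.01)))  # True
```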


18.6 Learning the kernel


In [Mac98], David MacKay asked: “How can Gaussian processes replace neural networks? Have we thrown the baby out with the bathwater?” This remark was made in the late 1990s, at the end of the second wave of neural networks. Researchers and practitioners had grown weary of the design decisions associated with neural networks (such as activation functions, optimization procedures, and architecture design) and the lack of a principled framework to make these decisions. Gaussian processes, by contrast, were perceived as flexible and principled probabilistic models, which naturally followed from Radford Neal’s results on infinite neural networks [Nea96], which we discuss in more depth in Section 18.7.

Figure 18.15: Some 1d GPs with RBF kernels but different hyper-parameters fit to 20 noisy observations. The hyper-parameters $(\ell, \sigma_f, \sigma_y)$ are as follows: (a) (1, 1, 0.1) (b) (3.0, 1.16, 0.89). Adapted from Figure 2.5 of [RW06]. Generated by gpr_demo_change_hparams.ipynb.

However, MacKay [Mac98] noted that neural networks could discover rich representations of data 
through adaptive hidden basis functions, while Gaussian processes with standard kernel functions, 
such as the RBF kernel, are essentially just smoothing devices. Indeed, the generalization properties 
of Gaussian processes hinge on the suitability of the kernel function. Learning the kernel is how 
we do representation learning with Gaussian processes, and in many cases will be crucial for good 
performance — especially when we wish to perform extrapolation, making predictions far away from 
the data [WA13; Wil+14]. 

As we will see, learning a kernel is in many ways analogous to training a neural network. Moreover, 
neural networks and Gaussian processes can be synergistically combined through approaches such as 
deep kernel learning (see Section 18.6.6) and NN-GPs (Section 18.7.2). 


18.6.1 Empirical Bayes for the kernel parameters 


Suppose, as in Section 18.3.2, we are performing 1d regression using a GP with an RBF kernel. Since the data has observation noise, the kernel has the following form:

$$K_y(x_p, x_q) = \sigma_f^2 \exp\left(-\frac{1}{2\ell^2} (x_p - x_q)^2\right) + \sigma_y^2 \delta_{pq} \tag{18.161}$$


Here $\ell$ is the horizontal scale over which the function changes, $\sigma_f^2$ controls the vertical scale of the function, and $\sigma_y^2$ is the noise variance. Figure 18.15 illustrates the effects of changing these parameters. We sampled 20 noisy data points from the SE kernel using $(\ell, \sigma_f, \sigma_y) = (1, 1, 0.1)$, and then made predictions using various parameters, conditional on the data. In Figure 18.15(a), we use $(\ell, \sigma_f, \sigma_y) = (1, 1, 0.1)$, and the result is a good fit. In Figure 18.15(b), we increase the length scale to $\ell = 3$; now the function looks smoother, but we are arguably underfitting.

To estimate the kernel parameters $\theta$ (sometimes called hyperparameters), we could use exhaustive search over a discrete grid of values, with validation loss as an objective, but this can be quite slow. (This is the approach used by nonprobabilistic methods, such as SVMs, to tune kernels.) Here we consider an empirical Bayes approach, which will allow us to use continuous optimization methods, which are much faster. In particular, we will maximize the marginal likelihood


$$p(y | X, \theta) = \int p(y | f, X)\, p(f | X, \theta)\, df \tag{18.162}$$


(The reason it is called the marginal likelihood, rather than just likelihood, is because we have marginalized out the latent Gaussian vector $f$.) Since $p(f | X) = \mathcal{N}(f | 0, K)$, and $p(y | f) = \prod_{n=1}^N \mathcal{N}(y_n | f_n, \sigma_y^2)$, the marginal likelihood is given by


$$\log p(y | X, \theta) = \log \mathcal{N}(y | 0, K_\sigma) = -\frac{1}{2} y^\top K_\sigma^{-1} y - \frac{1}{2} \log |K_\sigma| - \frac{N}{2} \log(2\pi) \tag{18.163}$$


where the dependence of $K_\sigma$ on $\theta$ is implicit. The first term is a data fit term, the second term is a model complexity term, and the third term is just a constant. To understand the tradeoff between the first two terms, consider a SE kernel in 1D, as we vary the length scale $\ell$ and hold $\sigma_y^2$ fixed. Let $J(\ell) = -\log p(y | X, \ell)$. For short length scales, the fit will be good, so $y^\top K_\sigma^{-1} y$ will be small. However, the model complexity will be high: $K$ will be almost diagonal, since most points will not be considered “near” any others, so the $\log |K_\sigma|$ will be large. For long length scales, the fit will be poor but the model complexity will be low: $K$ will be almost all 1’s, so $\log |K_\sigma|$ will be small.

We now discuss how to maximize the marginal likelihood. One can show that


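$$\frac{\partial}{\partial \theta_j} \log p(y | X, \theta) = \frac{1}{2} y^\top K_\sigma^{-1} \frac{\partial K_\sigma}{\partial \theta_j} K_\sigma^{-1} y - \frac{1}{2} \mathrm{tr}\left(K_\sigma^{-1} \frac{\partial K_\sigma}{\partial \theta_j}\right) \tag{18.164}$$
$$= \frac{1}{2} \mathrm{tr}\left((\alpha \alpha^\top - K_\sigma^{-1}) \frac{\partial K_\sigma}{\partial \theta_j}\right) \tag{18.165}$$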


where $\alpha = K_\sigma^{-1} y$. It takes $O(N^3)$ time to compute $K_\sigma^{-1}$, and then $O(N^2)$ time per hyper-parameter to compute the gradient.

The form of $\frac{\partial K_\sigma}{\partial \theta_j}$ depends on the form of the kernel, and which parameter we are taking derivatives with respect to. Often we have constraints on the hyper-parameters, such as $\sigma_y^2 \geq 0$. In this case, we can define $\theta = \log(\sigma_y^2)$, and then use the chain rule.


Given an expression for the log marginal likelihood and its derivative, we can estimate the kernel parameters using any standard gradient-based optimizer. However, since the objective is not convex, local minima can be a problem, as we illustrate below, so we may need to use multiple restarts.
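
The sketch below implements Equations (18.163) and (18.165) for the kernel (18.161), parameterized by $\theta = (\log \ell, \log \sigma_f^2, \log \sigma_y^2)$ so that positivity holds automatically (the chain rule is folded into the $\partial K_\sigma / \partial \theta_j$ terms), and checks the gradient by central finite differences. The toy data are illustrative; in practice one would pass this function to a gradient-based optimizer, possibly with several random restarts.

```python
import numpy as np

def lml_and_grad(theta, x, y):
    """Log marginal likelihood (18.163) and its gradient (18.165) w.r.t. theta = log(ell, sf2, sy2)."""
    ell, sf2, sy2 = np.exp(theta)
    N = len(x)
    d2 = (x[:, None] - x[None, :]) ** 2
    Kf = sf2 * np.exp(-0.5 * d2 / ell**2)
    L = np.linalg.cholesky(Kf + sy2 * np.eye(N))
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))          # K_sigma^{-1} y
    lml = -0.5 * y @ alpha - np.sum(np.log(np.diag(L))) - 0.5 * N * np.log(2 * np.pi)
    K_inv = np.linalg.solve(L.T, np.linalg.solve(L, np.eye(N)))
    W = np.outer(alpha, alpha) - K_inv                           # alpha alpha^T - K_sigma^{-1}
    dK = [Kf * d2 / ell**2, Kf, sy2 * np.eye(N)]                 # dK/dlog(ell), dK/dlog(sf2), dK/dlog(sy2)
    grad = np.array([0.5 * np.sum(W * D) for D in dK])          # 0.5 tr(W D), W and D symmetric
    return lml, grad

rng = np.random.default_rng(0)
x = np.sort(rng.uniform(-3, 3, 20))
y = np.sin(2 * x) + 0.1 * rng.standard_normal(20)
theta = np.log(np.array([1.0, 1.0, 0.01]))
lml, grad = lml_and_grad(theta, x, y)

h = 1e-5                                                         # central finite differences
fd = np.array([(lml_and_grad(theta + h * e, x, y)[0] - lml_and_grad(theta - h * e, x, y)[0]) / (2 * h)
               for e in np.eye(3)])
print(grad, fd)
```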


18.6.1.1 Example


Consider Figure 18.16. We use the SE kernel in Equation (18.161) with $\sigma_f^2 = 1$, and plot $\log p(y | X, \ell, \sigma_y^2)$ (where $X$ and $y$ are the 7 data points shown in panels b and c) as we vary $\ell$ and $\sigma_y^2$. The two local optima are indicated by + in panel (a). The bottom left optimum corresponds to a low-noise, short-length scale solution (shown in panel b). The top right optimum corresponds to a high-noise, long-length scale solution (shown in panel c). With only 7 data points, there is not enough evidence to confidently decide which is more reasonable, although the more complex model (panel b) has a marginal likelihood that is about 60% higher than the simpler model (panel c). With more data, the more complex model would become even more preferred.

Figure 18.16: Illustration of local minima in the marginal likelihood surface. (a) We plot the log marginal likelihood vs $\sigma_y^2$ and $\ell$, for fixed $\sigma_f^2 = 1$, using the 7 data points shown in panels b and c. (b) The function corresponding to the lower left local minimum, $(\ell, \sigma_y^2) \approx (1, 0.2)$. This is quite “wiggly” and has low noise. (c) The function corresponding to the top right local minimum, $(\ell, \sigma_y^2) \approx (10, 0.8)$. This is quite smooth and has high noise. The data was generated using $(\ell, \sigma_y^2) = (1, 0.1)$. Adapted from Figure 5.5 of [RW06]. Generated by gpr_demo_marglik.ipynb.

Figure 18.16 illustrates some other interesting (and typical) features. The region where $\sigma_y^2 \approx 1$ (top of panel a) corresponds to the case where the noise is very high; in this regime, the marginal likelihood is insensitive to the length scale (indicated by the horizontal contours), since all the data is explained as noise. The region where $\ell \approx 0.5$ (left hand side of panel a) corresponds to the case where the length scale is very short; in this regime, the marginal likelihood is insensitive to the noise level (indicated by the vertical contours), since the data is perfectly interpolated. Neither of these regions would be chosen by a good optimizer.


18.6.2 Bayesian inference for the kernel parameters 


When we have a small number of datapoints (e.g., when using GPs for blackbox optimization, as we discuss in Section 6.8), using a point estimate of the kernel parameters can give poor results [Bul11; WF14]. As a simple example, if the function values that have been observed so far are all very similar, then we may estimate $\hat{\sigma} \approx 0$, which will result in overly confident predictions.³

3. In [WSN00; BBV11b], they show how we can put a conjugate prior on $\sigma^2$ and integrate it out, to generate a Student version of the GP, which is more robust.

Figure 18.17: Three different approximations to the posterior over hyper-parameters: grid-based, Monte Carlo, and central composite design. From Figure 8.2 of [Van10]. Used with kind permission of Jarno Vanhatalo.

To overcome such overconfidence, we can compute a posterior over the kernel parameters. If the dimensionality of $\theta$ is small, we can compute a discrete grid of possible values, centered on the MAP estimate $\hat{\theta}$ (computed as above). We can then approximate the posterior using


$$p(f|\mathcal{D}) \approx \sum_{s=1}^{S} p(f|\mathcal{D}, \theta_s)\, p(\theta_s|\mathcal{D})\, w_s \tag{18.166}$$


where ws denotes the weight for grid point s. 
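To make Equation (18.166) concrete, here is a minimal NumPy sketch (not the book's notebook code) that places a grid over a single hyperparameter, the SE length scale $\ell$, with the noise variance held fixed. Assumptions: a flat prior over a grid that is uniform in $\log \ell$, so the weights $w_s$ are equal and cancel when we normalize; the function names are hypothetical.

```python
import numpy as np

def se_kernel(a, b, ell, sf2=1.0):
    """Squared exponential kernel between two sets of 1d inputs."""
    d = a[:, None] - b[None, :]
    return sf2 * np.exp(-0.5 * d**2 / ell**2)

def gp_predict(x, y, xs, ell, noise):
    """Exact GP regression: predictive mean at xs and log marginal likelihood log p(y | ell)."""
    K = se_kernel(x, x, ell) + noise * np.eye(len(x))
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))  # K^{-1} y
    mean = se_kernel(xs, x, ell) @ alpha
    log_ml = (-0.5 * y @ alpha - np.log(np.diag(L)).sum()
              - 0.5 * len(x) * np.log(2 * np.pi))
    return mean, log_ml

def grid_posterior_mean(x, y, xs, ells, noise=0.1):
    """Average plug-in predictive means, weighting grid point s by p(theta_s | D)."""
    results = [gp_predict(x, y, xs, ell, noise) for ell in ells]
    means = np.array([m for m, _ in results])
    log_post = np.array([lml for _, lml in results])  # flat prior: posterior proportional to marginal likelihood
    weights = np.exp(log_post - log_post.max())       # subtract max for numerical stability
    weights /= weights.sum()
    return weights @ means, weights

x = np.linspace(-3, 3, 7)   # 7 training points, as in the example above
y = np.sin(x)
xs = np.linspace(-4, 4, 50)
ells = np.exp(np.linspace(np.log(0.1), np.log(10.0), 21))
mean, weights = grid_posterior_mean(x, y, xs, ells)
```

A proper hyperprior would simply be added to `log_post`; with more than two or three hyperparameters, the number of grid points grows exponentially.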

In higher dimensions, a regular grid suffers from the curse of dimensionality. One alternative is to
place grid points at the mode, and at a distance of $\pm 1$ standard deviation from the mode along each dimension, for a
total of $2|\theta| + 1$ points. This is called a central composite design [RMC09]. See Figure 18.17 for
an illustration.
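The simplified design described above (the mode plus $\pm 1$ standard deviation along each axis) can be generated in a few lines. This is a sketch: the per-dimension standard deviations `sd` are assumed to come from elsewhere (e.g., a Laplace approximation at the mode), and the full design used in [RMC09] may differ in detail.

```python
import numpy as np

def central_composite_points(theta_mode, sd):
    """Mode plus mode +/- sd[d] along each axis d: 2D + 1 points in total."""
    theta_mode = np.asarray(theta_mode, dtype=float)
    sd = np.asarray(sd, dtype=float)
    offsets = np.vstack([np.zeros((1, theta_mode.size)),  # the mode itself
                         np.diag(sd),                      # +1 sd along each axis
                         -np.diag(sd)])                    # -1 sd along each axis
    return theta_mode + offsets

pts = central_composite_points([0.0, 1.0, -0.5], [0.2, 0.1, 0.3])
```

Each point would then be weighted by its (unnormalized) posterior density and used in Equation (18.166) in place of a regular grid.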

In higher dimensions, we can also use Monte Carlo inference for the kernel parameters when computing
Equation (18.166). For example, [MA10] shows how to use slice sampling (Section 12.4.1) for this task,
[Hen+15] shows how to use HMC (Section 12.5), and [BBV11a] shows how to use SMC (Chapter 13).


In Figure 18.18, we illustrate the difference between kernel optimization vs kernel inference. We fit 
a 1d dataset using a kernel of the form 


$$K(r) = \sigma_1^2 K_{SE}(r; \tau) K_{cos}(r; \rho_1) + \sigma_2^2 K_{3/2}(r; \rho_2) \tag{18.167}$$


where $K_{SE}(r; \tau)$ is the squared exponential kernel (Equation (18.12)), $K_{cos}(r; \rho_1)$ is the cosine kernel
(Equation (18.19)), and $K_{3/2}(r; \rho_2)$ is the Matern 3/2 kernel (Equation (18.15)). We then compute a
point estimate of the kernel parameters using empirical Bayes, and posterior samples using HMC.
We then predict the posterior mean of $f$ on a 1d test set by plugging in the MLE or averaging
over samples. We see that the latter captures more uncertainty (beyond the uncertainty captured by
the Gaussian itself).
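The composite kernel in Equation (18.167) can be evaluated directly. In this sketch we assume the parameterizations $K_{SE}(r;\tau) = \exp(-r^2/(2\tau^2))$, $K_{cos}(r;\rho) = \cos(2\pi r/\rho)$ (so $\rho$ acts as a period) and the standard Matern 3/2 form $(1 + \sqrt{3}|r|/\rho)\exp(-\sqrt{3}|r|/\rho)$; these may differ from the exact conventions of Equations (18.12), (18.15) and (18.19), and the hyperparameter defaults are arbitrary.

```python
import numpy as np

def k_se(r, tau):
    """Squared exponential: exp(-r^2 / (2 tau^2))."""
    return np.exp(-0.5 * (r / tau) ** 2)

def k_cos(r, rho):
    """Cosine kernel with period rho (assumed parameterization)."""
    return np.cos(2 * np.pi * r / rho)

def k_matern32(r, rho):
    """Matern 3/2: (1 + sqrt(3)|r|/rho) exp(-sqrt(3)|r|/rho)."""
    s = np.sqrt(3.0) * np.abs(r) / rho
    return (1.0 + s) * np.exp(-s)

def composite_kernel(x1, x2, sigma1=1.0, tau=2.0, rho1=1.0, sigma2=0.5, rho2=0.3):
    """sigma1^2 K_SE(r; tau) K_cos(r; rho1) + sigma2^2 K_32(r; rho2), with r = x - x'."""
    r = x1[:, None] - x2[None, :]
    return (sigma1**2 * k_se(r, tau) * k_cos(r, rho1)
            + sigma2**2 * k_matern32(r, rho2))

x = np.linspace(0, 5, 40)
K = composite_kernel(x, x)
```

Since products and sums of valid kernels are valid kernels, the resulting Gram matrix is positive semidefinite.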


18.6.3 Multiple kernel learning for additive kernels


A special case of kernel learning arises when the kernel is a sum of $B$ base kernels

$$K(x, x') = \sum_{b=1}^{B} w_b K_b(x, x') \tag{18.168}$$
Figure 18.18: Difference between estimation and inference for kernel hyper-parameters. (a) Empirical
Bayes approach based on optimization. We plot the posterior predicted mean given a plug-in estimate,
$\mathbb{E}[f(x)|\mathcal{D}, \hat{\theta}]$. (b) Bayesian approach based on HMC. We plot the posterior predicted mean, marginalizing
over hyper-parameters, $\mathbb{E}[f(x)|\mathcal{D}]$. Generated by gp_kernel_opt.ipynb.


Figure 18.19: Comparison of different additive model classes for a 4d function. Circles represent different 
interaction terms, ranging from first-order to fourth-order. Left: hierarchical kernel learning uses a nested 
hierarchy of terms. Right: additive GPs use a weighted sum of additive kernels of different orders. Color 
shades represent different weighting terms. Adapted from Figure 6.2 of [Duv14]. 


Optimizing the weights $w_b > 0$ using structural risk minimization is known as multiple kernel
learning; see e.g., [Rak+08] for details.

Now suppose we constrain the base kernels to depend on a subset of the variables. Furthermore, 
suppose we enforce a hierarchical inclusion property (e.g., including the kernel k123 means we must 
also include k12, k13 and k23), as illustrated in Figure 18.19(left). This is called hierarchical kernel 
learning. We can find a good subset from this model class using convex optimization [Bac09]; 
however, this requires the use of cross validation to estimate the weights. A more efficient approach 
is to use the empirical Bayes approach described in [DNR11]. 

Figure 18.20: Predictive distribution of each term in a GP-GAM model applied to a dataset with 8 continuous
inputs and 1 continuous output, representing the strength of some concrete. From Figure 2.7 of [Duv14].
Used with kind permission of David Duvenaud.

It is common to restrict attention to first order additive kernels, i.e., $K(x, x') = \sum_{d=1}^{D} K_d(x_d, x'_d)$.
The resulting function has the form

$$f(x) = f_1(x_1) + \cdots + f_D(x_D) \tag{18.169}$$

This is called a generalized additive model or GAM.


Figure 18.20 shows an example of this, where each base kernel has the form $K_d(x_d, x'_d) = \sigma_d^2 K_{SE}(x_d, x'_d | \ell_d)$.
In Figure 18.20, we see that the $\sigma_d^2$ terms for the coarse and fine features are
set to zero, indicating that these inputs have no impact on the response variable.
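A first-order additive kernel of this GP-GAM form can be sketched as follows (illustrative NumPy; `additive_se_kernel` is a hypothetical name). Setting one $\sigma_d^2$ to zero removes input $d$ from the model entirely, which is how irrelevant inputs are switched off.

```python
import numpy as np

def additive_se_kernel(X1, X2, sigma2, ell):
    """First-order additive kernel: sum_d sigma2[d] * exp(-(x_d - x'_d)^2 / (2 ell[d]^2))."""
    K = np.zeros((X1.shape[0], X2.shape[0]))
    for d in range(X1.shape[1]):
        diff = X1[:, d][:, None] - X2[:, d][None, :]
        K += sigma2[d] * np.exp(-0.5 * diff**2 / ell[d] ** 2)
    return K

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 3))
sigma2 = np.array([1.0, 0.0, 2.0])   # input 1 is switched off
ell = np.array([1.0, 1.0, 0.5])
K = additive_se_kernel(X, X, sigma2, ell)

X_shifted = X.copy()
X_shifted[:, 1] += 10.0              # changing the switched-off input...
K_shifted = additive_se_kernel(X_shifted, X_shifted, sigma2, ell)  # ...leaves K unchanged
```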
[DBW20] considers additive kernels operating on different linear projections of the inputs: 


$$K(x, x') = \sum_{b=1}^{B} w_b K(P_b x, P_b x') \tag{18.170}$$


Surprisingly, they show that these models can match or exceed the performance of kernels operating
on the original space, even when the projections are into a single dimension, and not learned. In
other words, it is possible to reduce many regression problems to a single dimension without loss
in performance. This finding is particularly promising for scalable inference, such as KISS (see
Section 18.5.5.3), and active learning, which are greatly simplified in a low dimensional setting.
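A sketch of Equation (18.170) with $B$ fixed, one-dimensional projections $P_b$ (the rows of a $B \times D$ matrix) and an SE base kernel. The random unit-norm projections, the equal weights $w_b = 1/B$ and the single shared length scale are illustrative assumptions, not the specific choices made in [DBW20].

```python
import numpy as np

def projected_additive_kernel(X1, X2, P, w, ell=1.0):
    """K(x, x') = sum_b w_b k_SE(P_b x, P_b x'), with each P_b a 1d projection (row of P)."""
    Z1, Z2 = X1 @ P.T, X2 @ P.T                 # projected inputs, shapes (N1, B) and (N2, B)
    diff = Z1[:, None, :] - Z2[None, :, :]      # pairwise differences, (N1, N2, B)
    return np.exp(-0.5 * diff**2 / ell**2) @ w  # weighted sum over the B projections

rng = np.random.default_rng(0)
D, B = 4, 8
P = rng.normal(size=(B, D))
P /= np.linalg.norm(P, axis=1, keepdims=True)   # unit-norm projection directions
w = np.full(B, 1.0 / B)
X = rng.normal(size=(6, D))
K = projected_additive_kernel(X, X, P, w)
```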


More recently, [LBH22] has proposed the orthogonal additive kernel (OAK), which imposes
an orthogonality constraint on the additive functions. This ensures an identifiable, low-dimensional
representation of the functional relationship, and results in improved performance.
Figure 18.21: Example of a search tree over kernel expressions. Adapted from Figure 3.2 of [Duv14].


18.6.4 Automatic search for compositional kernels 


Although the above methods can estimate the hyperparameters of a specified set of kernels, they do
not choose the kernels themselves (other than the special case of selecting a subset of kernels from a
set). In this section, we describe a method, based on [Duv+13], for sequentially searching through
the space of increasingly complex GP models so as to find a parsimonious description of the data.

We start with a simple kernel, such as the white noise kernel, and then consider replacing it
with a set of possible alternative kernels, such as an SE kernel, RQ kernel, etc. We use the BIC
score (Section 3.9.6.2) to evaluate each candidate model (choice of kernel) $m$. This has the form
$\text{BIC}(m) = \log p(\mathcal{D}|m) - \frac{1}{2}|m| \log N$, where $p(\mathcal{D}|m)$ is the marginal likelihood, and $|m|$ is the number of
parameters. The first term measures fit to the data, and the second term is a complexity penalty. We
can also consider replacing a kernel by the addition of two kernels, $k \rightarrow (k + k')$, or the multiplication
of two kernels, $k \rightarrow (k \times k')$. See Figure 18.21 for an illustration of the search space.
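The scoring and expansion step can be sketched in Python. This is a deliberate simplification of [Duv+13]: kernels are plain strings, the operators are applied only at the top level of the expression (the real search also rewrites subexpressions), the base-kernel list is a hypothetical choice, and `fit` stands for a user-supplied routine (not shown) that optimizes an expression's hyperparameters and returns its log marginal likelihood and number of parameters.

```python
import math

BASE_KERNELS = ["SE", "RQ", "PER", "LIN"]  # hypothetical set of base kernels

def bic(log_marglik, num_params, n):
    """BIC(m) = log p(D|m) - 0.5 |m| log N; higher is better with this sign convention."""
    return log_marglik - 0.5 * num_params * math.log(n)

def expand(expr):
    """Candidate moves from expr: replace k by k', or form k + k', or k x k'."""
    candidates = []
    for base in BASE_KERNELS:
        candidates.append(base)
        candidates.append(f"({expr} + {base})")
        candidates.append(f"({expr} * {base})")
    return candidates

def greedy_step(expr, fit, n):
    """Score every candidate with BIC and return (best_score, best_expression)."""
    scored = [(bic(*fit(cand), n), cand) for cand in expand(expr)]
    return max(scored)

# Toy stand-in for `fit`: pretend any expression containing PER fits well,
# and count one hyperparameter per base kernel.
def toy_fit(expr):
    num_kernels = 1 + expr.count("+") + expr.count("*")
    return (5.0 if "PER" in expr else 0.0), num_kernels

score, best = greedy_step("WN", toy_fit, n=100)
```

With the toy scores, `(WN + PER)` fits as well as `PER` but pays a larger complexity penalty, so the search keeps the simpler `PER`.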

Searching through this space is similar to what a human expert would do. In particular, if we 
find structure in the residuals, such as periodicity, we can propose a certain “move” through the 
space. We can also start with some structure that is assumed to hold globally, such as linearity, but 
if we find this only holds locally, we can multiply the kernel by an SE kernel. We can also add input 
dimensions incrementally, to capture higher order interactions. 

Figure 18.22 shows the output of this process applied to a dataset of monthly totals of international 
airline passengers. The input to the GP is the set of time stamps, x = 1 : t; there are no other 
features. 

The observed data lies in between the dotted vertical lines; curves outside of this region are 
extrapolations. We see that the system has discovered a fairly interpretable set of patterns in the 
data. Indeed, it is possible to devise an algorithm to automatically convert the output of this search 
process to a natural language summary, as shown in [Llo+14]. In this example, it summarizes the 
data as being generated by the addition of 4 underlying trends: a linearly increasing function; an 
approximately periodic function with a period of 1.0 years, and with linearly increasing amplitude; a 
smooth function; and uncorrelated noise with linearly increasing standard deviation. 

Recently, [Sun+18] showed how to create a DNN which learns the kernel given two input vectors.
The hidden units are defined as sums and products of elementary kernels, as in the above search-based
approach. However, the DNN can be trained in a differentiable way, so it is much faster.

Figure 18.22: Top row: airline dataset and posterior distribution of the model discovered after a search of
depth 10. Subsequent rows: predictions of the individual components. From Figure 8.5 of [Duv14], based on
[Llo+14]. Used with kind permission of David Duvenaud.


18.6.5 Spectral mixture kernel learning


Any shift-invariant (stationary) kernel can be converted via the Fourier transform to its dual form,
known as its spectral density. This means that learning the spectral density is equivalent to
learning any shift-invariant kernel. For example, if we take the Fourier transform of an RBF kernel,
we get a Gaussian spectral density centered at the origin. If we take the Fourier transform of a
Matern kernel, we get a Student-t spectral density centered at the origin. Thus standard approaches
to multiple kernel learning, which typically involve additive compositions of RBF and Matern kernels
with different length-scale parameters, amount to density estimation with a scale mixture of Gaussian
or Student-t distributions at the origin. Such models are very inflexible for density estimation, and
thus also very limited in being able to perform kernel learning.


Figure 18.23: Illustration of a GP with a spectral mixture kernel in 1d. (a) Learned vs true kernel. (b)
Predictions using learned kernel. Generated by gp_spectral_mixture.ipynb.

[Figure 18.24 plot: left panel, CO2 data vs. year (1968 to 2004); right panel, airline data vs. year (1951 to 1961); legend: Train, Test, MA, RQ, PER, SE, SM; shaded 95% CR.]

Figure 18.24: Extrapolations (point predictions and 95% credible set) on CO2 and airline datasets using
Gaussian processes with Matern, rational quadratic, periodic, RBF (SE), and spectral mixture kernels, each
with hyperparameters learned using empirical Bayes. From [Wil14].

On the other hand, a scale-location mixture of Gaussians can model any density to arbitrary precision.
Moreover, with even a small number of components these mixtures of Gaussians are highly flexible.
Thus a spectral density corresponding to a scale-location mixture of Gaussians forms an expressive
basis for all shift-invariant kernels. One can evaluate the inverse Fourier transform for a Gaussian
mixture analytically, to derive the spectral mixture kernel [WA13], which we can express for
one-dimensional inputs $x$ as:


$$K(x, x') = \sum_{i} w_i \cos\big((x - x')(2\pi\mu_i)\big) \exp\big(-2\pi^2 (x - x')^2 v_i\big) \tag{18.171}$$


The mixture weights $w_i$, as well as the means $\mu_i$ and variances $v_i$ of the Gaussians in the spectral
density, can be learned by empirical Bayes optimization (Section 18.6.1) or in a fully-Bayesian
procedure (Section 18.6.2) [Jan+17]. We illustrate the former approach in Figure 18.23.
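Equation (18.171) vectorizes naturally over the mixture components. The NumPy sketch below (hypothetical function name, arbitrary parameter values) evaluates the kernel for $Q$ components. As a sanity check, a single component with $\mu = 0$ reduces to an SE kernel with length scale $\ell = 1/(2\pi\sqrt{v})$, which is the Gaussian-spectral-density case discussed above.

```python
import numpy as np

def spectral_mixture_kernel(x1, x2, w, mu, v):
    """sum_i w_i cos(2 pi tau mu_i) exp(-2 pi^2 tau^2 v_i), with tau = x - x' (1d inputs)."""
    w, mu, v = (np.asarray(a, dtype=float) for a in (w, mu, v))
    tau = (x1[:, None] - x2[None, :])[..., None]  # (N1, N2, 1), broadcasts over the Q components
    terms = w * np.cos(2 * np.pi * tau * mu) * np.exp(-2 * np.pi**2 * tau**2 * v)
    return terms.sum(axis=-1)

x = np.linspace(0, 3, 25)
K = spectral_mixture_kernel(x, x, w=[1.0, 0.5], mu=[0.0, 1.0], v=[0.1, 0.05])

# Single component centered at the origin: should match an SE kernel with ell = 1 / (2 pi sqrt(v)).
v0 = 0.1
ell = 1.0 / (2 * np.pi * np.sqrt(v0))
K_sm = spectral_mixture_kernel(x, x, w=[1.0], mu=[0.0], v=[v0])
K_se = np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2 / ell**2)
```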

By learning the parameters of the spectral mixture kernel, we can discover representations that
enable extrapolation, i.e., making reasonable predictions far away from the data. For example, in
Section 18.8.1, compositions of kernels are carefully hand-crafted to extrapolate CO2 concentrations. But
in this instance, the human statistician is doing all of the interesting representation learning. By contrast,
Figure 18.24 shows Gaussian processes with learned spectral mixture kernels automatically
extrapolating on the CO2 and airline passenger problems.
Figure 18.25: Deep Kernel Learning: A Gaussian process with a deep kernel maps D dimensional inputs x
through L parametric hidden layers followed by a hidden layer with an infinite number of basis functions,
with base kernel hyperparameters $\theta$. Overall, a Gaussian process with a deep kernel produces a probabilistic
mapping with an infinite number of adaptive basis functions parametrized by $\gamma = \{w, \theta\}$. All parameters $\gamma$
are learned through the marginal likelihood of the Gaussian process. From Figure 1 of [Wil+16].


These kernels can also be used to extrapolate higher dimensional large-scale spatio-temporal
patterns. Large datasets can provide relatively more information for expressive kernel learning.
However, scaling an expressive kernel learning approach poses different challenges than scaling a
standard Gaussian process model. One faces additional computational constraints, and the need
to retain significant model structure for expressing the rich information available in a large dataset.
Indeed, in Figure 18.24 we can separately understand the effects of the kernel learning approach and
scalable inference procedure, in being able to discover structure necessary to extrapolate textures.
An expressive kernel model and a scalable inference approach that preserves a non-parametric
representation are needed for good performance.


Structure-exploiting inference procedures, such as Kronecker methods, as well as KISS-GP and
conjugate gradient based approaches, are appropriate for these tasks, since they generally preserve
or exploit existing structure, rather than introducing approximations that corrupt the structure.
Spectral mixture kernels combined with these scalable inference techniques have been used to great
effect for spatiotemporal extrapolation problems, including land-surface temperature forecasting,
epidemiological modeling, and policy-relevant applications.


18.6.6 Deep kernel learning 


Deep kernel learning [SH07; Wil+16] combines the structural properties of neural networks with
the non-parametric flexibility and uncertainty representation provided by Gaussian processes. For
example, we can define a “deep RBF kernel” as follows:


$$K_\theta(x, x') = \exp\left(-\frac{1}{2\sigma^2} \|h_\theta^L(x) - h_\theta^L(x')\|^2\right) \tag{18.172}$$


where $h_\theta^L(x)$ are the outputs of layer $L$ from a DNN. We can then learn the parameters $\theta$ by
maximizing the marginal likelihood of the Gaussian process.
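A minimal NumPy sketch of the deep RBF kernel in Equation (18.172) and of the objective that deep kernel learning maximizes. Assumptions: $h_\theta^L$ is a small tanh MLP with 15 and 10 hidden units, the weights are randomly initialized, the noise variance is fixed, and all function names are hypothetical. Actually training $\theta$ requires gradients of `log_marginal_likelihood` with respect to every weight, which in practice would come from an autodiff library such as JAX; that step is omitted here.

```python
import numpy as np

def init_mlp(rng, sizes=(1, 15, 10)):
    """Random weights for a small tanh MLP used as the feature extractor h_theta."""
    return [(rng.normal(size=(m, n)) / np.sqrt(m), np.zeros(n))
            for m, n in zip(sizes[:-1], sizes[1:])]

def mlp_features(X, params):
    h = X
    for W, b in params:
        h = np.tanh(h @ W + b)
    return h

def deep_rbf_kernel(X1, X2, params, sigma=1.0):
    """exp(-||h(x) - h(x')||^2 / (2 sigma^2)) computed on the MLP features h."""
    H1, H2 = mlp_features(X1, params), mlp_features(X2, params)
    sq = ((H1[:, None, :] - H2[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-0.5 * sq / sigma**2)

def log_marginal_likelihood(params, X, y, noise=0.01):
    """log N(y | 0, K_theta + noise I): the objective maximized over theta."""
    K = deep_rbf_kernel(X, X, params) + noise * np.eye(len(y))
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))
    return -0.5 * y @ alpha - np.log(np.diag(L)).sum() - 0.5 * len(y) * np.log(2 * np.pi)

rng = np.random.default_rng(0)
params = init_mlp(rng)
X = np.linspace(-1, 1, 20)[:, None]
y = np.sign(X[:, 0])   # a step function, i.e., a discontinuous target
K = deep_rbf_kernel(X, X, params)
lml = log_marginal_likelihood(params, X, y)
```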
Figure 18.26: Modeling a discontinuous function with (a) a GP with a “shallow” Matern 3/2 kernel, and (b) a
GP with a “deep” MLP + Matern kernel. Generated by gp_deep_kernel_learning.ipynb.
Figure 18.27: Left: The learned covariance matrix of a deep kernel with spectral mixture base kernel on a set
of test cases for the Olivetti faces dataset, where the test samples are ordered according to the orientations 
of the input faces. Middle: The respective covariance matrix using a deep kernel with RBF base kernel. 
Right: The respective covariance matrix using a standard RBF kernel. From Figure 5 of [Wil+16]. 


This framework is illustrated in Figure 18.25. We can understand the neural network features as 
inputs into a base kernel. The neural network can either be (i) pre-trained, (ii) learned jointly with 
the base kernel parameters, or (iii) pre-trained and then fine-tuned through the marginal likelihood. 
This approach can be viewed as a “last-layer” Bayesian model, where a Gaussian process is applied 
to the final layer of a neural network. The base kernel often provides a good measure of distance in 
feature space, desirably encouraging predictions to have high uncertainty as we move far away from 
the data. 

We can use deep kernel learning to help the GP learn discontinuous functions, as illustrated in
Figure 18.26. On the left we show the results of a GP with a standard Matern 3/2 kernel. It is clear
that the out-of-sample predictions are poor. On the right we show the results of the same model
where we first transform the input through a learned 2 layer MLP (with 15 and 10 hidden units). It
is clear that the model is working much better.

As a more complex example, we consider a regression problem where we wish to map faces (vectors
of pixel intensities) to a continuous valued orientation angle. In Figure 18.27, we evaluate the deep
kernel matrix (with RBF and spectral mixture base kernels, discussed in Section 18.6.5) on data
ordered by orientation angle. We can see that the learned deep kernels, in the left two panels, have a
pronounced diagonal band, meaning that they have discovered that faces with similar orientation
angles are correlated. On the other hand, in the right panel we see that the entries even for a learned
RBF kernel are highly diffuse. Since the RBF kernel essentially uses Euclidean distance as a metric
for similarity, it is unable to learn a representation that effectively solves this problem. In this case,
one must do highly non-Euclidean metric learning.

However, [ORW21] show that the approach to DKL based on maximizing the marginal likelihood 
can result in overfitting that is worse than standard DNN learning. They propose a fully Bayesian 
approach, in which they use SGLD (Section 12.7.1) to sample the DNN weights as well as the GP 
hyperparameters. 


18.7 GPs and DNNs 


In Section 18.6.6, we showed how we can combine the structural properties of neural networks with 
GPs. In Section 18.7.1 we show that, in the limit of infinitely wide networks, a neural network defines 
a GP with a certain kernel. These kernels are fixed, so the method is not performing representation 
learning, as a standard neural network would (see e.g., [COB18; Woo+19]). Nonetheless, these kernels 
are interesting in their own right, for example in modelling non-stationary covariance structure. 
In Section 18.7.2, we discuss the connection between SGD training of DNNs and GPs. And in 
Section 18.7.3, we discuss deep GPs, which are similar to DNNs in that they consist of many layers 
of functions which are composed together, but each layer is a nonparametric function. 


18.7.1 Kernels derived from infinitely wide DNNs (NN-GP) 


In this section, we show that an MLP with one hidden layer, whose width goes to infinity, and which
has a Gaussian prior on all the parameters, converges to a Gaussian process with a well-defined
kernel.⁴ This result was first shown in [Nea96; Wil98], and was later extended to deep MLPs in
[DFS16; Lee+18], to CNNs in [Nov+19], and to general DNNs in [Yan19]. The resulting kernel is
called the NN-GP kernel [Lee+18].

We will consider the following model: 


$$f_k(x) = b_k + \sum_{j=1}^{H} v_{jk} h_j(x), \quad h_j(x) = \varphi(u_{0j} + x^\top u_j) \tag{18.173}$$


where $H$ is the number of hidden units, and $\varphi()$ is some nonlinear activation function, such as ReLU.
We will assume Gaussian priors on the parameters:

$$b_k \sim \mathcal{N}(0, \sigma_b^2), \quad v_{jk} \sim \mathcal{N}(0, \sigma_v^2), \quad u_{0j} \sim \mathcal{N}(0, \sigma_0^2), \quad u_j \sim \mathcal{N}(0, \Sigma) \tag{18.174}$$


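Let $\theta = \{b_k, v_{jk}, u_{0j}, u_j\}$ be all the parameters. The expected output from unit $k$ when applied to
one input vector is given by

$$\mathbb{E}_\theta[f_k(x)] = \mathbb{E}_\theta\Big[b_k + \sum_{j=1}^{H} v_{jk} h_j(x)\Big] = \underbrace{\mathbb{E}_\theta[b_k]}_{=0} + \sum_{j=1}^{H} \underbrace{\mathbb{E}_\theta[v_{jk}]}_{=0}\, \mathbb{E}_u[h_j(x)] = 0 \tag{18.175}$$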


4. Our presentation is based on http://cbl.eng.cam.ac.uk/pub/Intranet/MLG/ReadingGroup/presentation_matthias.pdf.

The covariance in the output for unit $k$ when the function is applied to two different inputs is given
by the following:⁵
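$$\mathbb{E}_\theta[f_k(x) f_k(x')] = \mathbb{E}_\theta\Big[\Big(b_k + \sum_{j=1}^{H} v_{jk} h_j(x)\Big)\Big(b_k + \sum_{j=1}^{H} v_{jk} h_j(x')\Big)\Big] \tag{18.176}$$

$$= \sigma_b^2 + \sum_{j=1}^{H} \mathbb{E}_\theta[v_{jk}^2]\, \mathbb{E}_u[h_j(x) h_j(x')] = \sigma_b^2 + \sigma_v^2 H\, \mathbb{E}_u[h_j(x) h_j(x')] \tag{18.177}$$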


Now consider the limit $H \rightarrow \infty$. We scale the magnitude of the output by defining $\sigma_v^2 = \omega/H$.
Since the input to the $k$'th output unit is an infinite sum of random variables (from the hidden units
$h_j(x)$), we can use the central limit theorem to conclude that the output converges to a Gaussian
with mean and variance given by


$$\mathbb{E}[f_k(x)] = 0, \quad \mathbb{V}[f_k(x)] = \sigma_b^2 + \omega\, \mathbb{E}_u[h(x)^2] \tag{18.178}$$


Furthermore, the joint distribution over $\{f_k(x_n) : n = 1:N\}$ for any $N \geq 2$ converges to a
multivariate Gaussian with covariance given by


$$\mathbb{E}[f_k(x) f_k(x')] = \sigma_b^2 + \omega\, \mathbb{E}_u[h(x) h(x')] = K(x, x') \tag{18.179}$$
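Thus the MLP converges to a GP. To compute the kernel function, we need to evaluate

$$C(x, x') = \mathbb{E}_u\big[h(u_0 + u^\top x)\, h(u_0 + u^\top x')\big] = \mathbb{E}_{\tilde{u}}\big[h(\tilde{u}^\top \tilde{x})\, h(\tilde{u}^\top \tilde{x}')\big] \tag{18.180}$$

where we have defined $\tilde{x} = (1, x)$ and $\tilde{u} = (u_0, u)$. Let us define

$$\tilde{\Sigma} = \begin{pmatrix} \sigma_0^2 & 0 \\ 0 & \Sigma \end{pmatrix} \tag{18.181}$$

Then we have

$$C(x, x') = \int h(\tilde{u}^\top \tilde{x})\, h(\tilde{u}^\top \tilde{x}')\, \mathcal{N}(\tilde{u}|0, \tilde{\Sigma})\, d\tilde{u} \tag{18.182}$$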


This can be computed in closed form for certain activation functions, as shown in Table 18.4. 

This is sometimes called the neural net kernel. Note that this is a non-stationary kernel, and 
sample paths from it are nearly discontinuous and tend to constant values for large positive or 
negative inputs, as illustrated in Figure 18.28. 
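The closed-form kernels for erf and ReLU activations can be sanity-checked by Monte Carlo: sample $\tilde{u} \sim \mathcal{N}(0, \tilde{\Sigma})$, average $h(\tilde{u}^\top\tilde{x})\,h(\tilde{u}^\top\tilde{x}')$ as in Equation (18.182), and compare with the standard formulas of [Wil98] (erf) and the arc-cosine kernel of [CS09] (ReLU). This NumPy sketch uses the standard library `math.erf` (vectorized with `np.vectorize`) to avoid a SciPy dependency; the inputs and $\tilde{\Sigma}$ are arbitrary illustrative values.

```python
import math
import numpy as np

def augment(x):
    """x_tilde = (1, x)."""
    return np.concatenate([[1.0], np.asarray(x, dtype=float)])

def erf_kernel(x, xp, S):
    """(2/pi) arcsin(2 a'Sb / sqrt((1 + 2 a'Sa)(1 + 2 b'Sb))) for augmented inputs a, b."""
    a, b = augment(x), augment(xp)
    num = 2 * a @ S @ b
    den = np.sqrt((1 + 2 * a @ S @ a) * (1 + 2 * b @ S @ b))
    return 2 / np.pi * np.arcsin(num / den)

def relu_kernel(x, xp, S):
    """(|a|_S |b|_S / (2 pi)) [sin(t) + (pi - t) cos(t)], where t is the angle between a and b under S."""
    a, b = augment(x), augment(xp)
    na, nb = np.sqrt(a @ S @ a), np.sqrt(b @ S @ b)
    t = np.arccos(np.clip(a @ S @ b / (na * nb), -1.0, 1.0))
    return na * nb / (2 * np.pi) * (np.sin(t) + (np.pi - t) * np.cos(t))

def mc_kernel(h, x, xp, S, num_samples=100_000, seed=0):
    """Monte Carlo estimate of E[h(u'a) h(u'b)] with u ~ N(0, S)."""
    rng = np.random.default_rng(seed)
    U = rng.multivariate_normal(np.zeros(len(S)), S, size=num_samples)
    a, b = augment(x), augment(xp)
    return np.mean(h(U @ a) * h(U @ b))

erf = np.vectorize(math.erf)

def relu(z):
    return np.maximum(z, 0.0)

S = np.eye(3)                     # Sigma_tilde for 2d inputs (sigma_0^2 = 1, Sigma = I)
x, xp = [0.5, -0.3], [0.2, 0.8]
erf_gap = abs(mc_kernel(erf, x, xp, S) - erf_kernel(x, xp, S))
relu_gap = abs(mc_kernel(relu, x, xp, S) - relu_kernel(x, xp, S))
```

The gaps should shrink roughly like $1/\sqrt{\text{num\_samples}}$.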


18.7.2 Neural tangent kernel (NTK) 


In Section 18.7.1 we derived the NN-GP kernel, under the assumption that all the weights are random.
A natural question is: can we derive a kernel from a DNN after it has been trained, or more generally,
while it is being trained? It turns out that this can be done, as we show below.


5. We are using the fact that $u \sim \mathcal{N}(0, \sigma^2)$ implies $\mathbb{E}[u^2] = \mathbb{V}[u] = \sigma^2$.
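$h(\tilde{u}^\top \tilde{x})$ | $C(x, x')$
$\text{erf}(\tilde{u}^\top \tilde{x})$ | $\frac{2}{\pi} \arcsin(f_1(\tilde{x}, \tilde{x}'))$
$\text{ReLU}(\tilde{u}^\top \tilde{x})$ | $f_2(\tilde{x}, \tilde{x}') \big[\sin(\theta(\tilde{x}, \tilde{x}')) + (\pi - \theta(\tilde{x}, \tilde{x}')) \cos(\theta(\tilde{x}, \tilde{x}'))\big]$

Table 18.4: Some neural net GP kernels. Here we define $f_1(\tilde{x}, \tilde{x}') = \frac{2 \tilde{x}^\top \tilde{\Sigma} \tilde{x}'}{\sqrt{(1 + 2 \tilde{x}^\top \tilde{\Sigma} \tilde{x})(1 + 2 \tilde{x}'^\top \tilde{\Sigma} \tilde{x}')}}$, $f_2(\tilde{x}, \tilde{x}') = \frac{1}{2\pi} \|\tilde{\Sigma}^{1/2} \tilde{x}\|\, \|\tilde{\Sigma}^{1/2} \tilde{x}'\|$, $f_3(\tilde{x}, \tilde{x}') = \frac{\tilde{x}^\top \tilde{\Sigma} \tilde{x}'}{\sqrt{(\tilde{x}^\top \tilde{\Sigma} \tilde{x})(\tilde{x}'^\top \tilde{\Sigma} \tilde{x}')}}$, and $\theta(\tilde{x}, \tilde{x}') = \arccos(f_3(\tilde{x}, \tilde{x}'))$. Results are derived in
[Wil98; CS09].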


Figure 18.28: Sample output from a GP with an NNGP kernel derived from an infinitely wide one layer MLP
with activation function of the form $h(x) = \text{erf}(x \cdot u + u_0)$ where $u \sim \mathcal{N}(0, \sigma)$ and $u_0 \sim \mathcal{N}(0, \sigma_0)$. Generated
by nngp_1d.ipynb. Used with kind permission of Matthias Bauer.


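Let $f = [f(x_n; \theta)]_{n=1}^{N}$ be the $N \times 1$ prediction vector, let $\nabla_f \mathcal{L} = [\frac{\partial \mathcal{L}}{\partial f(x_n)}]_{n=1}^{N}$ be the $N \times 1$ loss
gradient vector, let $\theta = [\theta_p]_{p=1}^{P}$ be the $P \times 1$ vector of parameters, and let $\nabla_\theta f = [\frac{\partial f(x_n)}{\partial \theta_p}]$ be the
$P \times N$ matrix of partials. Suppose we perform continuous time gradient descent with fixed learning
rate $\eta$. The parameters evolve over time as follows: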


$$\dot{\theta}_t = -\eta \nabla_\theta \mathcal{L}(f_t) = -\eta \nabla_\theta f_t \cdot \nabla_f \mathcal{L}(f_t) \tag{18.183}$$


Thus the function evolves over time as follows:

$$\dot{f}_t = \nabla_\theta f_t^\top \dot{\theta}_t = -\eta \nabla_\theta f_t^\top \nabla_\theta f_t \cdot \nabla_f \mathcal{L}(f_t) = -\eta\, T_t \cdot \nabla_f \mathcal{L}(f_t) \tag{18.184}$$


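where $T_t$ is the $N \times N$ kernel matrix

$$T_t(x, x') = \nabla_\theta f_t(x) \cdot \nabla_\theta f_t(x') = \sum_{p=1}^{P} \frac{\partial f(x; \theta)}{\partial \theta_p} \frac{\partial f(x'; \theta)}{\partial \theta_p} \tag{18.185}$$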


If we let the learning rate $\eta$ become infinitesimally small, and the widths go to infinity, one can show
that this kernel converges to a constant matrix; this is known as the neural tangent kernel or
NTK [JGH18]:

$$T(x, x') \triangleq \nabla_\theta f(x; \theta_0) \cdot \nabla_\theta f(x'; \theta_0) \tag{18.186}$$
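The finite-width kernel of Equation (18.185) can be computed explicitly for a small network. This sketch uses a one-hidden-layer tanh network with scalar input and output, $f(x) = \sum_j v_j \tanh(u_j x + b_j)$, writes its parameter Jacobian by hand, and forms the Gram matrix $J J^\top$ (with $J = \nabla_\theta f^\top$ in the notation above); a finite-difference check confirms one column of the Jacobian. The architecture and the $1/\sqrt{H}$ scaling of the output weights at initialization are illustrative assumptions, and all names are hypothetical.

```python
import numpy as np

def init_params(H, rng):
    """Gaussian weights; output weights scaled by 1/sqrt(H) (an assumed initialization)."""
    return {"u": rng.normal(size=H), "b": rng.normal(size=H),
            "v": rng.normal(size=H) / np.sqrt(H)}

def f(x, p):
    """f(x) = sum_j v_j tanh(u_j x + b_j) for a vector of scalar inputs x."""
    return np.tanh(np.outer(x, p["u"]) + p["b"]) @ p["v"]

def jacobian(x, p):
    """N x P matrix of partials df(x_n)/dtheta_p, parameters ordered (u, b, v)."""
    a = np.tanh(np.outer(x, p["u"]) + p["b"])   # hidden activations, (N, H)
    g = (1.0 - a**2) * p["v"]                   # df / d(pre-activation), (N, H)
    return np.hstack([g * x[:, None], g, a])    # df/du, df/db, df/dv

def empirical_ntk(x1, x2, p):
    """T(x, x') = grad_theta f(x) . grad_theta f(x')."""
    return jacobian(x1, p) @ jacobian(x2, p).T

rng = np.random.default_rng(0)
p = init_params(50, rng)
x = np.linspace(-2, 2, 7)
T = empirical_ntk(x, x, p)

# Finite-difference check of the df/du_0 column.
eps = 1e-6
p_shift = {k: val.copy() for k, val in p.items()}
p_shift["u"][0] += eps
fd_err = np.max(np.abs((f(x, p_shift) - f(x, p)) / eps - jacobian(x, p)[:, 0]))
```

For larger networks one would compute the Jacobian with automatic differentiation rather than by hand.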


Details on how to compute this kernel for various models, such as CNNs, graph neural nets, and
general neural nets, can be found in [Aro+19; Du+19; Yan19]. A software library to compute the
NN-GP kernel and NTK is available in [Ano19].

The assumptions behind the NTK result in the parameters barely changing from their initial
values (which is why a linear approximation around the starting parameters is valid). This can
still lead to a change in the final predictions (and zero final training error), because the final layer
weights can learn to use the random features just like in kernel regression. However, this phenomenon,
which has been called “lazy training” [COB18], is not representative of DNN behavior in
practice [Woo+19], where parameters often change a lot. Fortunately it is possible to use a different
parameterization which does result in feature learning in the infinite width limit [YH21].


18.7.3 Deep GPs 


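A deep Gaussian process or DGP is a composition of GPs [DL13]. More formally, a DGP of $L$
layers is a hierarchical model of the form

$$\text{DGP}(x) = f_L \circ \cdots \circ f_1(x), \quad f_\ell(\cdot) = \big(f_\ell^{(1)}(\cdot), \ldots, f_\ell^{(H_\ell)}(\cdot)\big), \quad f_\ell^{(j)} \sim \mathcal{GP}(0, K_\ell(\cdot, \cdot)) \tag{18.187}$$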


This is similar to a deep neural network, except the hidden nodes are now hidden functions. 
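Sampling from a DGP prior only requires sampling each layer at the outputs of the previous one, because we only need $f_\ell$ evaluated at the $N$ points $f_{\ell-1} \circ \cdots \circ f_1(x_n)$. This sketch assumes one hidden function per layer ($H_\ell = 1$), an SE kernel with the same length scale in every layer, and a small jitter for numerical stability; all names are hypothetical.

```python
import numpy as np

def se_gram(z, ell=1.0):
    d = z[:, None] - z[None, :]
    return np.exp(-0.5 * d**2 / ell**2)

def sample_gp_at(z, rng, ell=1.0, jitter=1e-6):
    """Draw f(z) jointly at the points z, with f ~ GP(0, K_SE)."""
    K = se_gram(z, ell) + jitter * np.eye(len(z))
    return np.linalg.cholesky(K) @ rng.standard_normal(len(z))

def sample_dgp(x, num_layers, rng, ell=1.0):
    """One draw of f_L o ... o f_1 evaluated at x (a scalar GP per layer)."""
    h = x
    for _ in range(num_layers):
        h = sample_gp_at(h, rng, ell)
    return h

x = np.linspace(-2, 2, 100)
shallow = sample_dgp(x, 1, np.random.default_rng(0))
deep = sample_dgp(x, 3, np.random.default_rng(0))
```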

A natural question is: what is gained by this approach compared to a standard GP? Although 
conventional single-layer GPs are nonparametric, and can model any function (assuming the use of a 
non-degenerate kernel) with enough data, in practice their performance is limited by the choice of 
kernel. It is tempting to think that deep kernel learning (Section 18.6.6) can solve this problem, but 
in theory a GP on top of a DNN is still just a GP. However, one can show that a composition of GPs 
is strictly more general. Unfortunately, inference in deep GPs is rather complicated, so we leave the 
details to Supplementary Section 18.1. See also [Jak21] for a recent survey on this topic. 


18.8 Gaussian processes for timeseries forecasting 


It is possible to use Gaussian processes to perform timeseries forecasting (see e.g., [Rob+13]). The
basic idea is to model the unknown output as a function of time, $f(t)$, and to represent a prior about
the form of $f$ as a GP; we then update this prior given the observed evidence, and forecast into the
future. Naively this would take $O(T^3)$ time. However, for certain stationary kernels, it is possible to
reformulate the problem as a linear-Gaussian state space model, and then use the Kalman smoother
to perform inference in $O(T)$ time, as explained in [SSH13; SS19; Ada+20]. This conversion can be
done exactly for Matern kernels and approximately for Gaussian (RBF) kernels (see [SS19, Ch. 12]).
In [SGF21], they describe how to reduce the linear dependence on $T$ to $\log(T)$ time using a parallel
prefix scan operator, that can be run efficiently on GPUs (see Section 8.3.3.3).
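For the simplest Matern kernel, Matern 1/2 (the exponential or Ornstein-Uhlenbeck kernel $\sigma^2 \exp(-|t - t'|/\ell)$), the state space form is one-dimensional: $f(t_{n+1}) = a_n f(t_n) + \epsilon_n$ with $a_n = \exp(-(t_{n+1} - t_n)/\ell)$ and $\epsilon_n \sim \mathcal{N}(0, \sigma^2(1 - a_n^2))$. The sketch below compares an $O(T)$ Kalman filter log likelihood with the $O(T^3)$ dense computation; the two should agree to numerical precision. The forward filter suffices for the likelihood (the smoother adds a backward pass for posterior marginals), and higher-order Matern kernels need a larger state (the function and its derivatives), which this sketch does not cover. Function names and data are illustrative.

```python
import numpy as np

def ou_kernel(t1, t2, sigma2, ell):
    return sigma2 * np.exp(-np.abs(t1[:, None] - t2[None, :]) / ell)

def dense_log_lik(t, y, sigma2, ell, noise):
    """O(T^3): log N(y | 0, K + noise I) via a Cholesky factorization."""
    K = ou_kernel(t, t, sigma2, ell) + noise * np.eye(len(t))
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))
    return -0.5 * y @ alpha - np.log(np.diag(L)).sum() - 0.5 * len(t) * np.log(2 * np.pi)

def kalman_log_lik(t, y, sigma2, ell, noise):
    """O(T): the same quantity from a scalar Kalman filter (prediction error decomposition)."""
    m, P, ll = 0.0, sigma2, 0.0          # stationary prior on f(t_1)
    for n in range(len(t)):
        if n > 0:                         # predict: f(t_n) = a f(t_{n-1}) + process noise
            a = np.exp(-(t[n] - t[n - 1]) / ell)
            m, P = a * m, a**2 * P + sigma2 * (1 - a**2)
        S = P + noise                     # predictive variance of y_n
        r = y[n] - m                      # innovation
        ll += -0.5 * (np.log(2 * np.pi * S) + r**2 / S)
        k = P / S                         # Kalman gain
        m, P = m + k * r, (1 - k) * P     # update with y_n
    return ll

rng = np.random.default_rng(0)
t = np.sort(rng.uniform(0, 10, 200))
y = np.sin(t) + 0.1 * rng.standard_normal(200)
ll_dense = dense_log_lik(t, y, 1.0, 2.0, 0.01)
ll_kalman = kalman_log_lik(t, y, 1.0, 2.0, 0.01)
```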


18.8.1 Example: Mauna Loa 


Figure 18.29: (a) The observed Mauna Loa CO2 time series. (b) Forecasts from a GP. Generated by
gp_mauna_loa.ipynb.

In this section, we use the Mauna Loa CO2 dataset from Section 29.12.5.1. We show the raw
data in Figure 18.29(a). We see that there is a periodic (or quasi-periodic) signal with a year-long
period superimposed on a long term trend. Following [RW06, Sec 5.4.3], we will model this with a
composition of kernels:


$$K(r) = K_1(r) + K_2(r) + K_3(r) + K_4(r) \tag{18.188}$$

where $K_i(t, t') = K_i(t - t')$ for the $i$'th kernel.


To capture the long term smooth rising trend, we let $K_1$ be a squared exponential (SE) kernel,
where $\theta_0$ is the amplitude and $\theta_1$ is the length scale:

$$K_1(r) = \theta_0^2 \exp\left(-\frac{r^2}{2\theta_1^2}\right) \tag{18.189}$$


To model the periodicity, we can use the periodic or exp-sine-squared kernel from Equation (18.18)
with a period of 1 year. However, since it is not clear if the seasonal trend is exactly periodic, we
multiply this periodic kernel with another SE kernel to allow for a decay away from periodicity; the
result is $K_2$, where $\theta_2$ is the magnitude, $\theta_3$ is the decay time for the periodic component, $\theta_4 = 1$ is
the period, and $\theta_5$ is the smoothness of the periodic component.

$$K_2(r) = \theta_2^2 \exp\left(-\frac{r^2}{2\theta_3^2} - \theta_5 \sin^2\left(\frac{\pi r}{\theta_4}\right)\right) \tag{18.190}$$


To model the (small) medium term irregularities, we use a rational quadratic kernel (Equation (18.20)):

$$K_3(r) = \theta_6^2 \left(1 + \frac{r^2}{2\theta_7^2 \theta_8}\right)^{-\theta_8} \tag{18.191}$$

where $\theta_6$ is the magnitude, $\theta_7$ is the typical length scale, and $\theta_8$ is the shape parameter.
The magnitude of the independent noise can be incorporated into the observation noise of the
likelihood function. For the correlated noise, we use another SE kernel:

$$K_4(r) = \theta_9^2 \exp\left(-\frac{r^2}{2\theta_{10}^2}\right) \tag{18.192}$$

where $\theta_9$ is the magnitude of the correlated noise, and $\theta_{10}$ is the length scale. (Note that the
combination of $K_1$ and $K_4$ is non-identifiable, but this does not affect predictions.)

We can fit this model by optimizing the marginal likelihood with respect to $\theta$ (see Section 18.6.1). The
resulting forecast is shown in Figure 18.29(b).
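The four-part kernel of Equations (18.188) to (18.192) translates directly into code. In this sketch `theta` holds $\theta_0, \ldots, \theta_{10}$ in order, time is measured in years (so $\theta_4 = 1$), and the numerical values are arbitrary placeholders rather than fitted values; in practice they would be found by maximizing the log marginal likelihood as described above.

```python
import numpy as np

def mauna_loa_kernel(t1, t2, theta):
    """K = K1 + K2 + K3 + K4 (Equations 18.188 to 18.192); theta = (theta_0, ..., theta_10)."""
    th = np.asarray(theta, dtype=float)
    r = t1[:, None] - t2[None, :]
    k1 = th[0]**2 * np.exp(-r**2 / (2 * th[1]**2))                  # long-term smooth trend
    k2 = th[2]**2 * np.exp(-r**2 / (2 * th[3]**2)
                           - th[5] * np.sin(np.pi * r / th[4])**2)  # decaying seasonal component
    k3 = th[6]**2 * (1 + r**2 / (2 * th[7]**2 * th[8]))**(-th[8])   # medium-term irregularities
    k4 = th[9]**2 * np.exp(-r**2 / (2 * th[10]**2))                  # correlated noise
    return k1 + k2 + k3 + k4

theta = [1.0, 50.0, 0.5, 100.0, 1.0, 1.3, 0.1, 1.0, 0.8, 0.05, 0.2]  # placeholder values
t = np.linspace(0, 5, 60, endpoint=False)   # five years of monthly time stamps
K = mauna_loa_kernel(t, t, theta)
```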
19 Beyond the iid assumption


19.1 Introduction 


The standard approach to supervised ML assumes the training and test sets both contain independent 
and identically distributed (iid) samples from the same distribution. However, there are many 
settings in which the test distribution may be different from the training distribution; this is known 
as distribution shift, as we discuss in Section 19.2. 

In some cases, we may have data from multiple related distributions, not just train and test, as 
we discuss in Section 19.6. We may also encounter data in a streaming setting, where the data 
distribution may be changing continuously, or in a piecewise constant fashion, as we discuss in 
Section 19.7. Finally, in Section 19.8, we discuss settings in which the test distribution is chosen by 
an adversary to minimize performance of a prediction system. 


19.2 Distribution shift 


Suppose we have a labeled training set from a source distribution $p(x, y)$ which we use to fit a
predictive model $p(y|x)$. At test time we encounter data from the target distribution $q(x, y)$.
If $p \neq q$, we say that there has been a distribution shift or dataset shift [QC+08; BD+10].
This can adversely affect the performance of predictive models, as we illustrate in Section 19.2.1. 
In Section 19.2.2 we give a taxonomy of some kinds of distribution shift using the language of 
causal graphical models. We then proceed to discuss a variety of strategies that can be adopted to 
ameliorate the harm caused by distribution shift. In particular, in Section 19.3, we discuss techniques 
for detecting shifts, so that we can abstain from giving an incorrect prediction if the model is not 
confident. In Section 19.4, we discuss techniques to improve robustness to shifts; in particular, given 
labeled data from p(x, y), we aim to create a model that approximates q(y|x). In Section 19.5, we 
discuss techniques to adapt the model to the target distribution given some labeled or unlabeled 
data from the target. 


19.2.1 Motivating examples 


Figure 19.1 shows how shifting the test distribution slightly, by adding a small amount of Gaussian
noise, can hurt performance of an otherwise high accuracy image classifier. Similar effects occur
with other kinds of common corruptions, such as image blurring [HD19]. Analogous problems
can also occur in the text domain [Ryc+19], and the speech domain (cf. male vs female speakers in
Figure 34.1). These examples illustrate that high performing predictive models can be very sensitive
to small changes in the input distribution.

Figure 19.1: Effect of Gaussian noise of increasing magnitude on an image classifier. The model is a
ResNet-50 CNN trained on ImageNet. From Figure 23 of [For+19]. Used with kind permission of Justin
Gilmer.

[Figure 19.2 panel labels: (A) Cow: 0.99, Pasture: 0.99, Grass: 0.99, No Person: 0.98, Mammal: 0.98. (B) No Person: 0.99, Water: 0.98, Beach: 0.97, Outdoors: 0.97, Seashore: 0.97. (C) No Person: 0.97, Mammal: 0.96, Water: 0.94, Beach: 0.94, Two: 0.94.]

Figure 19.2: Illustration of how image classifiers generalize poorly to new environments. (a) In the training
data, most cows occur on grassy backgrounds. (b-c) In these test images, the cow occurs “out of context”, namely
on a beach. The background is considered a “spurious correlation”. In (b), the cow is not detected. In (c), it is
classified with a generic “mammal” label. Top five labels and their confidences are produced by ClarifAI.com,
which is a state of the art commercial vision system. From Figure 1 of [BVHP18]. Used with kind permission
of Sara Beery.

Performance can also drop on “clean” images that exhibit other kinds of shift. Figure 19.2
gives an amusing example of this. In particular, it illustrates how a CNN image
classifier can be very accurate on in-domain data, but very inaccurate on out-of-domain
data, such as images with a different background, or taken at a different time or location (see e.g.,
[Koh+20b]) or from a novel viewing angle (see e.g., [KH22]).

The root cause of many of these problems is the fact that discriminative models often leverage 
features that are predictive of the output in the training set, but which are not reliable in general. 


For example, in an image classification dataset, we may find that green grass in the background
is very predictive of the class label “cow”, but this is not a feature that is stable across different
distributions; such features are called spurious correlations or shortcut features. Unfortunately, such
features are often easier for models to learn, for reasons explained in [Gei+20a; Xia+21b; Sha+20;
Pez+21].

19.2. DISTRIBUTION SHIFT

Figure 19.3: Models for distribution shift from source s to target t. Here $\mathcal{D}^s_L$ is the labeled training set from
the source, $\mathcal{D}^t_L$ is an optional labeled training set from the target, $\mathcal{D}^t_U$ is an optional unlabeled training set
from the target, and $\mathcal{D}^t_{\text{test}}$ is a labeled test set from the target. In the latter case, $\hat{y}_n$ is the prediction on
the n’th test case (generated by the model), $y^*_n$ is the true value, and $\ell_n = \ell(y^*_n, \hat{y}_n)$ is the corresponding
loss. (Note that we don’t evaluate the loss on the source distribution.) (a) Discriminative (causal) model. (b)
Generative (anti-causal) model.

Relying on these shortcuts can have serious real-world consequences. For example, [Zec+18a] 
found that a CNN trained to recognize pneumonia was relying on hospital-specific metal tokens in 
the chest X-ray scans, rather than focusing on the lungs themselves, and thus the model did not 
generalize to new hospitals. 
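The following toy simulation (our own illustration, not taken from the cited works) shows how a learner that greedily picks the most predictive feature on the training set latches onto a shortcut. We assume two binary features: a hypothetical “core” feature that agrees with the label 85% of the time, and a “spurious” feature (think of a grassy background or a hospital-specific token) that always agrees with the label during training but is uninformative at test time. The rates, sample sizes and the decision-stump “learner” are arbitrary choices.

```python
import random

def make_data(n, spurious_agrees, rng):
    """Toy binary data: 'core' matches the label 85% of the time,
    'spurious' matches it with probability spurious_agrees (hypothetical rates)."""
    data = []
    for _ in range(n):
        y = rng.random() < 0.5
        core = y if rng.random() < 0.85 else not y
        spurious = y if rng.random() < spurious_agrees else not y
        data.append(((core, spurious), y))
    return data

def accuracy(data, feature_index):
    # Accuracy of the decision stump "predict y = x[feature_index]".
    return sum(x[feature_index] == y for x, y in data) / len(data)

rng = random.Random(0)
train = make_data(2000, spurious_agrees=1.0, rng=rng)  # shortcut always works in training
test = make_data(2000, spurious_agrees=0.5, rng=rng)   # shortcut is uninformative at test time

# "Learner": pick the single feature with the best training accuracy.
best = max(range(2), key=lambda j: accuracy(train, j))
print("chosen feature:", ["core", "spurious"][best])
print("train acc:", accuracy(train, best), "test acc:", accuracy(test, best))
print("core feature test acc:", accuracy(test, 0))
```

The spurious feature wins precisely because it is perfectly predictive on the source data; nothing in the training objective signals that it will fail on the target.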

Analogous problems arise with other kinds of ML models, as well as other data types, such as 
text (e.g., changing “he” to “she” can flip the output of a sentiment analysis system), audio (e.g., 
adding background noise can easily confuse speech recognition systems), and medical records [Ros22]. 
Furthermore, the changes to the input needed to change the output can often be imperceptible, as 
we discuss in the section on adversarial robustness (Section 19.8). 


19.2.2 A causal view of distribution shift 


In the sections below, we briefly summarize some canonical kinds of distribution shift. We adopt
a causal view of the problem, following [Sch+12a; Zha+13b; BP16; Mei18a; CWG20; Bud+21;
SCS822].¹ (See Section 4.7 for a brief discussion of causal DAGs, and Chapter 36 for more details.)
We assume the inputs to the model (the covariates) are X and the outputs to be predicted (the
labels) are Y. If we believe that X causes Y, denoted X → Y, we call it causal prediction. If we
believe that Y causes X, denoted Y → X, we call it anti-causal prediction [Sch+12a].
The decision about which model to use depends on our assumptions about the underlying data


1. In the causality literature, the question of whether a model can generalize to a new distribution is called the question 
of external validity. If a model is externally valid, we say that it is transportable from one distribution to another 
[BP16]. 
Figure 19.4: Illustration of the four main kinds of distribution shift for a 2d binary classification problem.
Adapted from Figure 1 of [al21].


generating process. For example, suppose X is a medical image, and Y is an image segmentation
created by a human expert or an algorithm. If we change the image, we will change the annotation,
and hence X → Y. Now suppose X is a medical image and Y is the ground truth disease state of
the patient, as estimated by some other means (e.g., a lab test). In this case, we have Y → X, since
changing the disease state will change the appearance of the image. As another example, suppose X
is a text review of a movie, and Y is a measure of how informative the review is. Clearly we have
X → Y. Now suppose Y is the star rating of the movie, representing the degree to which the user
liked it; this will affect the words that they write, and hence Y → X.

Based on the above discussion, we can factor the joint distribution in two possible ways. One way
is to define a discriminative (causal) model:

$$p_\theta(x, y) = p_\psi(x)\, p_w(y|x) \tag{19.1}$$

See Figure 19.3a. Alternatively we can define a generative (anti-causal) model:

$$p_\theta(x, y) = p_\pi(y)\, p_\phi(x|y) \tag{19.2}$$

See Figure 19.3b. For each of these two model types, different parts of the distribution may
change from source to target. This gives rise to four canonical types of shift, as we discuss in Section 19.2.3.


19.2.3 The four main types of distribution shift 


The four main types of distribution shift are summarized in Table 19.1 and are illustrated in
Figure 19.4. We give more details below.


19.2.3.1 Covariate shift 


In a causal (discriminative) model, if $p_\psi(x)$ changes (so $\psi^s \neq \psi^t$), we call it covariate shift, also
called domain shift. For example, the training distribution may be clean images of coffee pots, and
the test distribution may be images of coffee pots with Gaussian noise, as shown in Figure 19.1; or the


Name                        Source        Target        Causal / anti-causal
Covariate / domain shift    p(X)p(Y|X)    q(X)p(Y|X)    Causal
Concept shift               p(X)p(Y|X)    p(X)q(Y|X)    Causal
Label (prior) shift         p(Y)p(X|Y)    q(Y)p(X|Y)    Anti-causal
Manifestation shift         p(Y)p(X|Y)    p(Y)q(X|Y)    Anti-causal


Table 19.1: The four main types of distribution shift.


training distribution may be photos of objects in a catalog, with uncluttered white backgrounds, and 
the test distribution may be photos of the same kinds of objects collected “in the wild”; or the training 
data may be synthetically generated images, and the test distribution may be real images. Similar 
shifts can occur in the text domain; for example, the training distribution may be movie reviews 
written in English, and the test distribution may be translations of these reviews into Spanish. 

Some standard strategies to combat covariate shift include importance weighting (Section 19.5.2) 
and domain adaptation (Section 19.5.3). 


19.2.3.2 Concept shift 


In a causal (discriminative) model, if $p_w(y|x)$ changes (so $w^s \neq w^t$), we call it concept shift, also
called annotation shift. For example, consider the medical imaging context: the conventions for
annotating images might be different between the training distribution and test distribution. Another
example of concept shift occurs when a new label can occur in the target distribution that was not
part of the source distribution. This is related to open world recognition, discussed in Section 19.3.4.

Since concept shift is a change in what we “mean” by a label, it is impossible to fix this problem 
without seeing labeled examples from the target distribution, which defines each label by means of 
examples. 


19.2.3.3 Label / prior shift 


In an anti-causal (generative) model, if $p_\pi(y)$ changes (i.e., $\pi^s \neq \pi^t$), we call it label shift, also
called prior shift or prevalence shift. For example, consider the medical imaging context, where
Y = 1 if the patient has some disease and Y = 0 otherwise. If the training distribution is an urban
hospital and the test distribution is a rural hospital, then the prevalence of the disease, represented
by p(Y = 1), might very well be different.

A standard strategy to combat label shift is to reweight the output of a discriminative
classifier using an estimate of the new label distribution, as we discuss in Section 19.5.4.
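Under the label shift assumption $q(x|y) = p(x|y)$, Bayes' rule gives $q(y|x) \propto p(y|x)\, q(y)/p(y)$. A minimal sketch of this correction, assuming the source classifier is calibrated and the target prior $q(y)$ is known (in practice it has to be estimated, e.g., as in Section 19.5.4); the function name and the prevalence numbers are hypothetical:

```python
def adjust_for_label_shift(probs_src, prior_src, prior_tgt):
    """Convert source posteriors p(y|x) into target posteriors q(y|x) under label shift.

    Uses q(y|x) proportional to p(y|x) * q(y) / p(y), then renormalizes over classes.
    """
    unnorm = [p * qt / ps for p, ps, qt in zip(probs_src, prior_src, prior_tgt)]
    total = sum(unnorm)
    return [u / total for u in unnorm]


# Hypothetical example: disease prevalence p(Y=1) is 10% at the urban (source)
# hospital and 30% at the rural (target) hospital.
probs_src = [0.6, 0.4]  # p(Y=0|x), p(Y=1|x) from the source classifier
probs_tgt = adjust_for_label_shift(probs_src, prior_src=[0.9, 0.1], prior_tgt=[0.7, 0.3])
print(probs_tgt)  # the probability of disease rises from 0.4 to about 0.72
```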


19.2.3.4 Manifestation shift 


In an anti-causal (generative) model, if $p_\phi(x|y)$ changes (i.e., $\phi^s \neq \phi^t$), we call it manifestation
shift [CWG20] or conditional shift [Zha+13b]. This is, in some sense, the inverse of concept
shift. For example, consider the medical imaging context: the way that the same disease Y manifests
itself in the shape of a tumor X might differ between source and target. This is usually due to the presence of a hidden
confounding factor that has changed between source and target (e.g., different age of the patients).
Figure 19.5: Causal diagrams for different sample selection strategies. Undirected edges can be oriented in
either direction. The selection variable S is set to 1 if its parent nodes match the desired criterion; only these
samples are included in the dataset. (a) No selection. (b) Selection on X. (c) Selection on Y. (d) Selection
on X and Y. Adapted from Figure 4 of [CWG20].


19.2.4 Selection bias 


In some cases, we may induce a shift in the distribution just due to the way the data is collected,
without any changes to the underlying distributions. In particular, let S = 1 if a sample from the
population is included in the training set, and S = 0 otherwise. Thus the source distribution is
$p(X, Y \mid S = 1)$, but the target distribution is $q(X, Y) = p(X, Y \mid S \in \{0, 1\}) = p(X, Y)$, which
involves no selection.

In Figure 19.5 we visualize the four kinds of selection. For example, suppose we select based on X
meeting certain criteria, e.g., images of a certain quality, or exhibiting a certain pattern; this can
induce domain shift or covariate shift. Now suppose we select based on Y meeting certain criteria,
e.g., we are more likely to select rare examples where Y = 1, in order to balance the dataset (for
reasons of computational efficiency); this can induce label shift. Finally, suppose we select based on
both X and Y; this can induce non-causal dependencies between X and Y, a phenomenon known as
selection bias (see Section 4.2.4.2 for details).
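A small simulation makes the last point concrete. In the sketch below (an illustration with made-up settings), X and Y are independent fair coin flips in the population, and a sample is included (S = 1) only if X or Y is 1. Conditioning on S = 1 induces a negative correlation between X and Y (about -0.5 in this setup), even though neither causes the other; this special case is often called Berkson's paradox.

```python
import random

def correlation(pairs):
    """Pearson correlation between the two coordinates of a list of (x, y) pairs."""
    n = len(pairs)
    xs = [float(x) for x, _ in pairs]
    ys = [float(y) for _, y in pairs]
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(xs, ys)) / n
    vx = sum((a - mx) ** 2 for a in xs) / n
    vy = sum((b - my) ** 2 for b in ys) / n
    return cov / (vx * vy) ** 0.5

rng = random.Random(0)
# Population: X and Y are independent binary variables.
population = [(rng.random() < 0.5, rng.random() < 0.5) for _ in range(20000)]
# Selection on both X and Y: S = 1 iff at least one of them is 1.
selected = [(x, y) for x, y in population if x or y]

print("corr in population:", round(correlation(population), 3))    # close to 0
print("corr after selection:", round(correlation(selected), 3))    # close to -0.5
```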


19.3 Detecting distribution shifts


In general it will not be possible to make a model robust to all of the ways a distribution can shift
at test time, nor will we always have access to test samples at training time. As an alternative, it
may be sufficient for the model to detect that a shift has happened, and then to respond in the
appropriate way. There are several ways of detecting distribution shift, some of which we summarize
below. (See also Section 29.5.6, where we discuss changepoint detection in time series data.) The
main distinction between methods is based on whether we have a set of samples from the target
19.3. DETECTING DISTRIBUTION SHIFTS


distribution, or just a single sample, and whether the test samples are labeled or unlabeled. We 
discuss these different scenarios below. 


19.3.1 Detecting shifts using two-sample testing 


Suppose we collect a set of samples from the source and target distributions. We can then use standard
techniques for two-sample testing to decide whether to reject the null hypothesis $p(x, y) = q(x, y)$.
(If we only have unlabeled samples, we just test whether $p(x) = q(x)$.) For example, we can use MMD
(Section 2.7.3) to measure the distance between the two sets of input samples (see e.g., [Liu+20a]). Or we
can measure (Euclidean) distances in the embedding space of a classifier trained on the source (see
e.g., [KM22]).
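As a rough sketch of an MMD-based two-sample test, the code below computes a biased (V-statistic) estimate of the squared MMD with a Gaussian kernel on scalar inputs, and obtains a p-value by permuting the pooled samples. The kernel bandwidth, sample sizes, number of permutations and the mean-shifted toy data are all assumptions chosen for illustration; in practice the kernel would usually be applied to learned embeddings.

```python
import math
import random

def rbf(a, b, bandwidth):
    # Gaussian (RBF) kernel on scalars.
    return math.exp(-((a - b) ** 2) / (2.0 * bandwidth ** 2))

def mmd2(xs, ys, bandwidth=1.0):
    """Biased (V-statistic) estimate of the squared MMD between two samples."""
    kxx = sum(rbf(a, b, bandwidth) for a in xs for b in xs) / len(xs) ** 2
    kyy = sum(rbf(a, b, bandwidth) for a in ys for b in ys) / len(ys) ** 2
    kxy = sum(rbf(a, b, bandwidth) for a in xs for b in ys) / (len(xs) * len(ys))
    return kxx + kyy - 2.0 * kxy

def mmd_permutation_test(xs, ys, num_perms=100, seed=0):
    """Return (statistic, p-value) for H0: p(x) = q(x), using random relabelings."""
    rng = random.Random(seed)
    observed = mmd2(xs, ys)
    pooled = list(xs) + list(ys)
    exceed = 0
    for _ in range(num_perms):
        rng.shuffle(pooled)  # under H0 the source/target labels are exchangeable
        if mmd2(pooled[:len(xs)], pooled[len(xs):]) >= observed:
            exceed += 1
    return observed, (exceed + 1) / (num_perms + 1)

rng = random.Random(1)
source = [rng.gauss(0.0, 1.0) for _ in range(60)]
target = [rng.gauss(1.0, 1.0) for _ in range(60)]  # hypothetical mean-shifted target
stat, p_value = mmd_permutation_test(source, target)
print(f"MMD^2 = {stat:.3f}, p = {p_value:.3f}")
```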

In some cases it may be possible to just test whether the distribution of the labels p(y) has changed, which
is an easier problem than testing for changes in the distribution of inputs p(x). In particular, if the
label shift assumption (Section 19.2.3.3) holds (i.e., $q(x|y) = p(x|y)$), plus some other assumptions,
then we can use the blackbox shift estimation technique from Section 19.5.4 to estimate q(y). If
we find that $q(y) = p(y)$, then we can conclude that $q(x, y) = p(x, y)$. In [RGL19], they showed
experimentally that this method worked well for detecting distribution shifts even when the label
shift assumption does not hold.

It is also possible to use conformal prediction (Section 14.3) to develop “distribution free” methods
for detecting covariate shift, given only access to a calibration set and some conformity scoring
function [HL20].


19.3.2 Detecting single out-of-distribution (OOD) inputs 


Now suppose we just have one unlabeled sample from the target distribution, x ∼ q, and we want to
know if x is in-distribution (ID) or out-of-distribution (OOD). We will call this problem out-of-distribution
detection, although it is also called anomaly detection or novelty detection.²

The OOD detection problem requires making a binary decision about whether the test sample is ID 
or OOD. If it is ID, we may optionally require that we return its class label, as shown in Figure 19.6. 
In the sections below, we give a brief overview of techniques that have been proposed for tackling 
this problem, but for more details, see e.g., [Pan+21; Ruf+21; Bul+20; Yan+21; Sal+21; Hen+19b]. 


19.3.2.1 Supervised ID/OOD methods (outlier exposure) 


The simplest method for OOD detection assumes we have access to labeled ID and OOD samples at
training time. Then we just fit a binary classifier to distinguish the OOD or background class (called
“known unknowns”) from the ID class (called “known knowns”). This technique is called outlier
exposure (see e.g., [HMD19; Thu+21; Bit+21]) and can work well. However, in most cases we will
not have enough examples from the OOD distribution, since the OOD set is basically the set of all
possible inputs except for the ones of interest.


2. The task of outlier detection is somewhat different from anomaly or OOD detection, despite the similar name. 
In the outlier detection literature, the assumption is that there is a single unlabeled dataset, and the goal is to identify 
samples which are “untypical” compared to the majority. This is often used for data cleaning. (Note that this is a 
transductive learning task, where the model is trained and evaluated on the same data. We focus on inductive 
tasks, where we train a model on one dataset, and then test it on another.) 
Figure 19.6: Illustration of a two-stage decision problem. First we must decide if the input image is
out-of-distribution (OOD) or not. If it is not, we must return the set of class labels that have high probability. From
[AB21]. Used with kind permission of Anastasios Angelopoulos.


19.3.2.2 Classification confidence methods 


Instead of trying to solve the binary ID/OOD classification problem, we can directly try to predict
the class of the input. Let the probabilities over the C labels be $p_c = p(y = c|x)$, and let the logits be
$\ell_c$, so that $\ell_c = \log p_c$ up to an additive constant. We can derive a confidence score or uncertainty metric in a variety of ways from
these quantities, e.g., the max probability $s = \max_c p_c$, the margin $s = \max_c \ell_c - \max'_c \ell_c$ (where
max′ means the second largest element), the entropy $s = \mathbb{H}(p_{1:C})$,³ the “energy score” $s = \log \sum_c \exp(\ell_c)$
[Liu+21b], etc. In [Mil+21; Vaz+22] they show that the simple max probability baseline performs
very well in practice.
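The scores listed above are easy to compute from a single logit vector. The sketch below follows those definitions; the function name is hypothetical, and we take the energy score to be the log-sum-exp of the logits (sign conventions differ between papers, so it is worth checking which direction indicates OOD before thresholding).

```python
import math

def confidence_scores(logits):
    """Compute several ID-confidence scores from a vector of (unnormalized) logits.

    Larger max_prob, margin and energy indicate more confidence; larger entropy
    indicates less confidence. Requires at least two classes.
    """
    m = max(logits)
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))  # stable log-sum-exp
    probs = [math.exp(l - log_z) for l in logits]                # softmax
    top1, top2 = sorted(logits, reverse=True)[:2]
    return {
        "max_prob": max(probs),
        "margin": top1 - top2,
        "entropy": -sum(p * math.log(p) for p in probs if p > 0.0),
        "energy": log_z,
    }

peaked = confidence_scores([4.0, 1.0, 0.0])  # high max_prob, low entropy
flat = confidence_scores([0.0, 0.0, 0.0])    # max_prob = 1/3, entropy = log 3
print(peaked)
print(flat)
```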


19.3.2.3 Conformal prediction 


It is possible to create a method for OOD detection and ID classification that has provably bounded 
risk using conformal prediction (Section 14.3). The details are in [Ang+21], but we sketch the basic 
idea here. 

We want to solve the two-stage decision problem illustrated in Figure 19.6. We define the
prediction set as follows:

$$\mathcal{T}(x) = \begin{cases} \emptyset & \text{if } \mathrm{OOD}(x) > \lambda_1 \\ \mathrm{APS}(x) & \text{otherwise} \end{cases} \tag{19.3}$$

where OOD(x) is some heuristic OOD score that is large for inputs that look OOD (such as one minus the max class probability), and APS(x) is the
adaptive prediction set method of Section 14.3.1, which returns the set of the top K class labels,
such that the sum of their probabilities exceeds threshold $\lambda_2$. (Formally, $\mathrm{APS}(x) = \{\pi_1, \ldots, \pi_K\}$,
where $\pi$ sorts $f(x)_{1:C}$ in descending order, and $K = \min\{K' : \sum_{c=1}^{K'} f(x)_{\pi_c} > \lambda_2\}$.)
We choose the thresholds $\lambda_1$ and $\lambda_2$ using a calibration set and a frequentist hypothesis testing


3. [Kir+21] argues against using entropy, since it confuses uncertainty about which of the C labels to use with
uncertainty about whether any of the labels is suitable, compared to a “none-of-the-above” option.
Figure 19.7: Likelihoods from a Glow normalizing flow model (Section 23.2.1) trained on CIFAR10 and
evaluated on different test sets. The SVHN (street view house numbers) dataset has lower visual complexity, and hence higher
likelihood. Qualitatively similar results are obtained for other generative models and datasets. From Figure 1
of [Ser+20]. Used with kind permission of Joan Serra.


method (see [Ang+21]). The resulting thresholds jointly control the following risks:

$$R_1(\lambda_1) = p\big(\mathrm{OOD}(x) > \lambda_1\big) \tag{19.4}$$

$$R_2(\lambda_1, \lambda_2) = p\big(y \notin \mathrm{APS}(x) \mid \mathrm{OOD}(x) \le \lambda_1\big) \tag{19.5}$$

where $(x, y) \sim p(x, y)$, the true but unknown source distribution (of ID samples; no OOD samples required),
$R_1$ is the chance that an ID sample will be incorrectly rejected as OOD (type-I error), and $R_2$ is the
chance (conditional on the decision to classify) that the true label is not in the predicted set. The goal
is to set $\lambda_1$ as small as possible (so we can detect OOD examples when they arise) while controlling
the type-I error (e.g., we may want to ensure that we falsely flag (as OOD) no more than 10% of
in-distribution samples). We then set $\lambda_2$ in the usual way for the APS method in Section 14.3.1.
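The following is a simplified sketch, not the joint calibration procedure of [Ang+21]: it picks $\lambda_1$ as a split-conformal quantile of OOD scores on held-out ID data, which bounds the type-I error $R_1$ by alpha for an exchangeable ID test point, and it uses a hand-picked $\lambda_2$ for the APS step. We assume the OOD score is larger for more OOD-like inputs (e.g., one minus the max class probability); all names and numbers are hypothetical.

```python
import math

def calibrate_lambda1(cal_ood_scores, alpha):
    """Split-conformal threshold: an exchangeable ID test point satisfies
    P(OOD(x) > lambda_1) <= alpha. Scores are larger for more OOD-like inputs."""
    n = len(cal_ood_scores)
    k = max(1, math.ceil((n + 1) * (1 - alpha)))  # rank of the conformal quantile
    if k > n:
        return float("inf")  # too few calibration points: never reject
    return sorted(cal_ood_scores)[k - 1]

def aps(probs, lambda_2):
    """Smallest top-K set of labels whose total probability exceeds lambda_2."""
    order = sorted(range(len(probs)), key=lambda c: probs[c], reverse=True)
    chosen, total = [], 0.0
    for c in order:
        chosen.append(c)
        total += probs[c]
        if total > lambda_2:
            break
    return chosen

def prediction_set(probs, ood_score, lambda_1, lambda_2):
    # Equation (19.3): the empty set (reject as OOD) or the APS set.
    if ood_score > lambda_1:
        return []
    return aps(probs, lambda_2)

# Hypothetical calibration scores, e.g., 1 - max class probability on ID data.
cal_scores = [0.05 * i for i in range(1, 19)]
lam1 = calibrate_lambda1(cal_scores, alpha=0.1)
print(lam1)
print(prediction_set([0.5, 0.3, 0.2], ood_score=0.2, lambda_1=lam1, lambda_2=0.7))     # [0, 1]
print(prediction_set([0.4, 0.35, 0.25], ood_score=0.99, lambda_1=lam1, lambda_2=0.7))  # []
```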


19.3.2.4 Unsupervised methods 


If we don’t have labeled examples, a natural approach to OOD detection is to fit an unconditional
density model (such as a VAE) to the ID samples, and then to evaluate the likelihood p(x) and
compare this to some threshold value. Unfortunately, for many kinds of deep models and datasets, we
sometimes find that p(x) is lower for samples that are from the source distribution than from a novel
target distribution. For example, if we train a pixel-CNN model (Section 22.3.2) or a normalizing-flow
model (Chapter 23) on Fashion-MNIST and evaluate it on MNIST, we find it gives higher likelihood
to the MNIST samples [Nal+19a; Ren+19; KIW20; ZGR21]. This phenomenon occurs for several
other models and datasets (see Figure 19.7).

One solution to this is to use a log likelihood ratio relative to a baseline density model, $R(x) =
\log p(x)/q(x)$, as opposed to the raw log likelihood, $L(x) = \log p(x)$. (This technique was explored
in [Ren+19], amongst other papers.) An important advantage of this is that the ratio is invariant to
invertible transformations of the data. To see this, let $x' = \phi(x)$ be some invertible, but possibly nonlinear,
transformation. By the change of variables formula, we have $p(x') = p(x)\,|\det \mathrm{Jac}(\phi^{-1})(x')|$. Thus $L(x')$ will
differ from $L(x)$ in a way that depends on the transformation. By contrast, we have $R(x') = R(x)$,
regardless of $\phi$, since

$$R(x') = \log p(x') - \log q(x') = \log p(x) + \log|\det \mathrm{Jac}(\phi^{-1})(x')| - \log q(x) - \log|\det \mathrm{Jac}(\phi^{-1})(x')| = R(x) \tag{19.6}$$
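The invariance argument can be checked numerically. In the sketch below we assume, purely for illustration, a 1-D ID model $p = \mathcal{N}(0, 1)$, a baseline $q = \mathcal{N}(1, 2^2)$, and the transformation $\phi(x) = \exp(x)$; the log likelihood changes by the log-Jacobian term, while the log ratio $R$ does not.

```python
import math

def log_gauss(x, mu, sigma):
    return -0.5 * math.log(2.0 * math.pi * sigma ** 2) - (x - mu) ** 2 / (2.0 * sigma ** 2)

# Hypothetical models: p(x) is the ID density model, q(x) the baseline model.
def log_p(x):
    return log_gauss(x, 0.0, 1.0)

def log_q(x):
    return log_gauss(x, 1.0, 2.0)

# Invertible, nonlinear transformation x' = phi(x) = exp(x), so phi^{-1}(x') = log x'
# and |d phi^{-1} / dx'| = 1 / x'. Change of variables: log p(x') = log p(log x') - log x'.
def log_p_prime(xp):
    return log_p(math.log(xp)) - math.log(xp)

def log_q_prime(xp):
    return log_q(math.log(xp)) - math.log(xp)

x = 0.7
xp = math.exp(x)
L, L_prime = log_p(x), log_p_prime(xp)
R, R_prime = log_p(x) - log_q(x), log_p_prime(xp) - log_q_prime(xp)
print(f"L(x) = {L:.4f}, L(x') = {L_prime:.4f}")  # differ by the log-Jacobian term
print(f"R(x) = {R:.4f}, R(x') = {R_prime:.4f}")  # identical
```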
Various other strategies have been proposed, such as computing the log-likelihood adjusted by a
measure of the complexity (coding length computed by a lossless compression algorithm) of the input
[Ser+20], computing the likelihood of model features instead of inputs [Mor+21a], etc.

A closely related technique relies on reconstruction error. The idea is to fit an autoencoder or 
VAE (Section 21.2) to the ID samples, and then measure the reconstruction error of the input: a 
sample that is OOD is likely to incur larger error (see e.g. [Pol+19]). However, this suffers from the 
same problems as density estimation methods. 
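To illustrate the reconstruction-error idea without a deep network, the sketch below uses a linear autoencoder with a one-dimensional code, which is equivalent to PCA, on 2-D toy data lying near a line. This substitution, the data and the test points are our own assumptions; with a deep autoencoder or VAE the same score would be computed from the network's reconstruction.

```python
import math
import random

def fit_linear_autoencoder(points):
    """Fit a linear autoencoder with a 1-D code to 2-D points (equivalent to PCA):
    returns the data mean and the leading principal direction."""
    n = len(points)
    mx = sum(p[0] for p in points) / n
    my = sum(p[1] for p in points) / n
    sxx = sum((p[0] - mx) ** 2 for p in points) / n
    syy = sum((p[1] - my) ** 2 for p in points) / n
    sxy = sum((p[0] - mx) * (p[1] - my) for p in points) / n
    theta = 0.5 * math.atan2(2.0 * sxy, sxx - syy)  # angle of the major axis
    return (mx, my), (math.cos(theta), math.sin(theta))

def reconstruction_error(point, mean, direction):
    # Encode: project onto the principal direction. Decode: map the code back to 2-D.
    dx, dy = point[0] - mean[0], point[1] - mean[1]
    code = dx * direction[0] + dy * direction[1]
    rx, ry = mean[0] + code * direction[0], mean[1] + code * direction[1]
    return (point[0] - rx) ** 2 + (point[1] - ry) ** 2

rng = random.Random(0)
train = []
for _ in range(200):  # hypothetical ID data: points near the line y = 2x
    x = rng.uniform(-1.0, 1.0)
    train.append((x, 2.0 * x + rng.gauss(0.0, 0.05)))
mean, direction = fit_linear_autoencoder(train)
err_id = reconstruction_error((0.5, 1.0), mean, direction)    # ID-like point: small error
err_ood = reconstruction_error((1.0, -2.0), mean, direction)  # off the line: large error
print(err_id, err_ood)
```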

An alternative to trying to estimate the likelihood, or to reconstruct the input, is to use a GAN
(Chapter 26), whose discriminator is trained to distinguish “real” from “fake” data. This has been extended to the
open set recognition setting in the OpenGAN method of [KR21b].


19.3.3 Selective prediction 


Suppose the system has a confidence level of p in its prediction for a given input (see Section 19.3.2 for a
discussion of some ways to compute such confidence scores). If p is below some threshold, the system
may choose to abstain from classifying the input with a specific label. By varying the threshold, we can
control the tradeoff between accuracy and abstention rate. This is called selective prediction (see
e.g., [EW10; GEY19; Ziy+19; JKG18]), and is useful for applications where an error can be more
costly than asking a human expert for help (e.g., medical image classification).
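A minimal sketch of how the accuracy/abstention tradeoff is traced out: for each threshold we keep only the inputs whose confidence exceeds it, and report the fraction kept (coverage) and the accuracy on that subset. The function name, confidences and correctness flags below are made up.

```python
def selective_prediction_curve(confidences, correct, thresholds):
    """For each threshold t, predict only on inputs with confidence > t and
    report (t, coverage, selective accuracy)."""
    curve = []
    for t in thresholds:
        kept = [ok for conf, ok in zip(confidences, correct) if conf > t]
        coverage = len(kept) / len(confidences)
        accuracy = sum(kept) / len(kept) if kept else float("nan")
        curve.append((t, coverage, accuracy))
    return curve

# Hypothetical max-probability confidences and whether each prediction was right.
confidences = [0.95, 0.90, 0.80, 0.60, 0.55, 0.40]
correct = [True, True, True, False, True, False]
curve = selective_prediction_curve(confidences, correct, [0.0, 0.5, 0.85])
for t, cov, acc in curve:
    print(f"t={t:.2f}  coverage={cov:.2f}  accuracy={acc:.2f}")
```

Raising the threshold lowers coverage and, when the confidences are informative, raises accuracy on the retained inputs.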


19.3.3.1 Example: SGLD vs SGD for MLPs 


One way to improve performance of OOD detection is to “be Bayesian” about the parameters of the
model, so that the uncertainty in their values is reflected in the posterior predictive distribution.
This can result in better performance in selective prediction tasks.
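Being Bayesian here means that predictions average over parameter samples rather than relying on a single point estimate. A minimal sketch, assuming we already have per-sample class probabilities (for example from SGLD samples); the function name and numbers are hypothetical and chosen to show that two confident but disagreeing samples yield an uncertain averaged prediction.

```python
def posterior_predictive(sample_probs):
    """Average class probabilities over S posterior parameter samples
    (Monte Carlo approximation of the posterior predictive distribution)."""
    num_samples = len(sample_probs)
    num_classes = len(sample_probs[0])
    return [sum(p[c] for p in sample_probs) / num_samples for c in range(num_classes)]

# Two hypothetical posterior samples that each make a confident but conflicting prediction.
samples = [[0.9, 0.1], [0.1, 0.9]]
bayes = posterior_predictive(samples)
print(bayes)                                  # [0.5, 0.5]
print(max(bayes), [max(p) for p in samples])  # averaged vs per-sample confidence
```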


In this section, we give a simple example of this, where we fit a shallow MLP to the MNIST dataset
using either standard SGD (specifically RMSprop) or Stochastic Gradient Langevin Dynamics (see
Section 12.7.1), which is a form of MCMC inference. We use 6,000 training steps, where each step
uses a minibatch of size 1,000. After fitting the model to the training set, we evaluate its predictions
on the test set. To assess how well calibrated the model is, we select a subset of predictions whose
confidence is above a threshold t. (The confidence value is just the probability assigned to the MAP
class.) As we increase the threshold t from 0 to 1, we make predictions on fewer examples, but the
accuracy should increase. This is shown in Figure 19.8: the green curve is the fraction of the test set
for which we make a prediction, and the blue curve is the accuracy. On the left we show SGD, and
on the right we show SGLD. In this case, performance is quite similar, although SGD has slightly
higher accuracy. However, the story changes somewhat when there is distribution shift.


To study the effects under distribution shift, we apply both models to FashionMNIST data. We
show the results in Figure 19.9. The accuracy of both models is very low (less than the chance level
of 10%), but SGD remains quite confident in many more of its predictions than SGLD, which is
more conservative. To see this, consider a confidence threshold of 0.5: the SGD approach predicts on
about 97% of the examples (recall that the green curve corresponds to the right hand axis), whereas
SGLD only predicts on about 70% of the examples.


More details on the behavior of Bayesian neural networks under distribution shift can be found in
Section 17.4.6.2.
Figure 19.8: Accuracy vs confidence plots for an MLP fit to the MNIST training set, and then evaluated on
one batch from the MNIST test set. Scale for blue accuracy curve is on the left, scale for green percentage
predicted curve is on the right. (a) Plugin approach, computed using SGD. (b) Bayesian approach, computed
using 10 samples from SGLD. Generated by bnn_mnist_sgld.ipynb.
Figure 19.9: Similar to Figure 19.8, except that performance is evaluated on the Fashion MNIST dataset. (a)
SGD. (b) SGLD. Generated by bnn_mnist_sgld.ipynb.


19.3.4 Open set and open world recognition 


In Section 19.3.3, we discussed methods that “refuse to classify” if the system is not confident enough 
about its predicted output. If the system detects that this lack of confidence is due to the input 
coming from a novel class, rather than just being a novel instance of an existing class, we call the 
problem open set recognition (see e.g., [GHC20] for a review). 

Rather than “flagging” novel classes as OOD, we can instead allow the set of classes to grow over 
time; this is called open world classification [BB15a]. Note that open world classification is most 
naturally tackled in the context of a continual learning system, which we discuss in Section 19.7.3. 


19.4 Robustness to distribution shifts 


In this section, we discuss techniques to improve the robustness of a model to distribution shifts. 
In particular, given labeled data from p(x, y), we aim to create a model that approximates q(y|x). 
19.4.1 Data augmentation


A simple approach to potentially increasing the robustness of a predictive model to distribution shifts 
is to simulate samples from the target distribution by modifying the source data. This is called data 
augmentation, and is widely used in the deep learning community. For example, it is standard to 
apply small perturbations to images (e.g., shifting them or rotating them), while keeping the label 
the same (assuming that the label should be invariant to such changes); see e.g., [SK19; Hen+20] for 
details. Similarly, in NLP (natural language processing), it is standard to change words that should 
not affect the label (e.g., replacing “he” with “she” in a sentiment analysis system), or to use back 
translation (from a source language to a target language and back) to generate paraphrases; see 
e.g., [Fen+21] for a review of such techniques. For a causal perspective on data augmentation, see 
e.g., [Kau+21]. 
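The two kinds of augmentation mentioned above can be sketched in a few lines. The code below is an illustration under simple assumptions: images are small 2-D lists, translations fill vacated pixels with zeros, and the word-swap table is hypothetical. Whether a transformation really preserves the label is a modeling assumption that has to be checked for each task.

```python
import random

def shift_image(image, dy, dx):
    """Translate a 2-D image (list of rows) by (dy, dx) pixels, filling vacated pixels with 0."""
    h, w = len(image), len(image[0])
    return [
        [image[i - dy][j - dx] if 0 <= i - dy < h and 0 <= j - dx < w else 0 for j in range(w)]
        for i in range(h)
    ]

def augment_image(image, label, rng, max_shift=1):
    # Random small translation; the label is kept unchanged (assumed shift-invariant).
    dy = rng.randint(-max_shift, max_shift)
    dx = rng.randint(-max_shift, max_shift)
    return shift_image(image, dy, dx), label

def swap_words(tokens, swaps):
    # Text augmentation: replace words that should not affect the label.
    return [swaps.get(tok, tok) for tok in tokens]

rng = random.Random(0)
image = [[0, 0, 0, 0],
         [0, 1, 1, 0],
         [0, 1, 1, 0],
         [0, 0, 0, 0]]
augmented = [augment_image(image, "square", rng) for _ in range(3)]
for img, label in augmented:
    print(label, img)
swapped = swap_words("he loved this movie".split(), {"he": "she", "she": "he"})
print(swapped)
```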


19.4.2 Distributionally robust optimization 


We can make a discriminative model that is robust to (some forms of) covariate shift by solving the
following distributionally robust optimization (DRO) problem:

$$\min_f \max_{w \in \mathcal{W}} \frac{1}{N} \sum_{n=1}^{N} w_n\, \ell(f(x_n), y_n) \tag{19.7}$$

where the samples are from the source distribution, $(x_n, y_n) \sim p$. This is an example of a min-max
optimization problem, in which we want to minimize the worst case risk. The specification of the
robustness set, $\mathcal{W}$, is a key factor that determines how well the method works, and how difficult the
optimization problem is. Typically it is specified in terms of an $\ell_p$ ball around the inputs, but this
could also be defined in a feature (embedding) space. It is also possible to define the robustness set
in terms of local changes to a structural causal model [Mei18a]. For more details on DRO, see e.g.,
[WYG14; Hu+18; CP20a; LFG21; Sag+20].
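The choice of $\mathcal{W}$ determines how hard the inner maximization is. As one concrete and easily solved instance (our choice for illustration, not the only or standard one), suppose the weights come from a distribution $\pi$ over the N training examples and we penalize its KL divergence from the uniform distribution with strength $\tau$. The inner maximum then has the closed form $\pi_n \propto \exp(\ell_n/\tau)$, which the sketch below computes; the losses are hypothetical. An outer loop would then update f to reduce this reweighted loss.

```python
import math

def kl_dro_weights(losses, tau):
    """Worst-case weights for a KL-penalized robustness set.

    Solves max_pi sum_n pi_n * loss_n - tau * KL(pi || uniform) over the simplex,
    whose solution is pi_n proportional to exp(loss_n / tau). Returns w_n = N * pi_n,
    so the weights average to one, matching the (1/N) sum_n w_n * loss_n form of (19.7).
    """
    m = max(losses)  # subtract the max for numerical stability
    unnorm = [math.exp((l - m) / tau) for l in losses]
    total = sum(unnorm)
    n = len(losses)
    return [n * u / total for u in unnorm]

def worst_case_risk(losses, tau):
    # Reweighted empirical risk (1/N) sum_n w_n * loss_n under the worst-case weights.
    weights = kl_dro_weights(losses, tau)
    return sum(w * l for w, l in zip(weights, losses)) / len(losses)

losses = [0.1, 0.2, 2.0]  # hypothetical per-example losses l(f(x_n), y_n)
for tau in [100.0, 1.0, 0.01]:
    print(tau, round(worst_case_risk(losses, tau), 4))
```

Large $\tau$ recovers ordinary ERM (uniform weights), while small $\tau$ approaches the loss of the single worst example.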


19.5 Adapting to distribution shifts 


In this section, we discuss techniques to adapt the model to the target distribution. If we have some
labeled data from the target distribution, we can use transfer learning, as we discuss in Section 19.5.1.
However, getting labeled data from the target distribution is often not an option. Therefore, in the
other sections, we discuss techniques that just rely on unlabeled data from the target distribution.


19.5.1 Supervised adaptation using transfer learning 


Suppose we have labeled training data from a source distribution, $\mathcal{D}^s_L = \{(x_n, y_n) \sim p : n = 1:N_s\}$,
and also some labeled data from the target distribution, $\mathcal{D}^t_L = \{(x_n, y_n) \sim q : n = 1:N_t\}$. Our
goal is to minimize the risk on the target distribution q, which can be computed using

$$R(f, q) = \mathbb{E}_{q(x, y)}[\ell(y, f(x))] \tag{19.8}$$

We can approximate the risk empirically using

$$\hat{R}(f, \mathcal{D}^t_L) = \frac{1}{|\mathcal{D}^t_L|} \sum_{(x_n, y_n) \in \mathcal{D}^t_L} \ell(y_n, f(x_n)) \tag{19.9}$$
19.5. ADAPTING TO DISTRIBUTION SHIFTS


If $\mathcal{D}^t_L$ is large enough, we can directly optimize this using standard empirical risk minimization (ERM).
However, if $\mathcal{D}^t_L$ is small, we might want to use $\mathcal{D}^s_L$ somehow as a regularizer. This is called transfer
learning, since we hope to “transfer knowledge” from p to q. There are many approaches to transfer
learning (see e.g., [Zhu+21] for a review). We briefly mention a few below.


19.5.1.1 Pre-train and fine-tune 


The simplest and most widely used approach to transfer learning is the pre-train and fine-tune
approach. We first fit a model to the source distribution by computing $f^s = \arg\min_f \hat{R}(f, \mathcal{D}^s_L)$. (Note
that the source data may be unlabeled, in which case we can use self-supervised learning methods.)
We then adapt the model to work on the target distribution by computing

$$f^t = \arg\min_f \hat{R}(f, \mathcal{D}^t_L) + \lambda \|f - f^s\| \tag{19.10}$$

where $\|f - f^s\|$ is some distance between the functions, and $\lambda > 0$ controls the degree of regularization.
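To see how $\lambda$ trades off the two terms in Equation (19.10), consider a 1-D linear model $f(x) = w x$ with squared loss, and replace $\|f - f^s\|$ by the squared parameter distance $(w - w^s)^2$ (an assumption that makes the problem solvable in closed form; we also sum rather than average the losses, which only rescales $\lambda$). Setting the derivative to zero gives $w^t = (\sum_n x_n y_n + \lambda w^s)/(\sum_n x_n^2 + \lambda)$. The data below are hypothetical.

```python
def fine_tune_slope(xs, ys, w_src, lam):
    """Closed-form minimizer of sum_n (y_n - w x_n)^2 + lam * (w - w_src)^2.

    The squared parameter distance to the source solution w_src stands in for
    ||f - f^s|| in Equation (19.10).
    """
    sxy = sum(x * y for x, y in zip(xs, ys))
    sxx = sum(x * x for x in xs)
    return (sxy + lam * w_src) / (sxx + lam)

xs_tgt, ys_tgt = [1.0, 2.0], [3.0, 6.0]  # tiny hypothetical target set (slope 3)
w_src = 1.0                               # slope learned on the source distribution
for lam in [0.0, 5.0, 1e6]:
    print(lam, fine_tune_slope(xs_tgt, ys_tgt, w_src, lam))
```

With $\lambda = 0$ we recover the target-only fit, and as $\lambda$ grows the solution is pulled back toward the source model.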

Since we assume that we have very few samples from the target distribution, we typically “freeze” 
most of the parameters of the source model. (This makes an implicit assumption that the features that 
are useful for the source distribution also work well for the target.) We can then solve Equation (19.10) 
by “chopping off the head” from f* and replacing it with a new linear layer, to map to the new set of 
labels for the target distribution, and then compute a new MAP estimate for the parameters on the 
target distribution. (We can also compute a prior for the parameters of the source model, and use it 
to compute a posterior for the parameters of the target model, as discussed in Section 17.2.3.) 

This approach is very widely used in practice, since it is simple and effective. In particular, it is
common to take a large pre-trained model, such as a transformer, that has been trained (often using
self-supervised learning, Section 32.3.3) on a lot of data, such as the entire web, and then to use this
model as a feature extractor (see e.g., [Kol+20]). The features are fed to the downstream model,
which may be a linear classifier or a shallow MLP, which is trained on the target distribution.
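A sketch of the linear-probe variant: the feature extractor is frozen (here a hypothetical fixed map $\phi(x) = [x, x^2, 1]$ standing in for a pre-trained network), and only a logistic-regression head is trained on the target data by gradient descent. The toy target task and all hyperparameters are assumptions.

```python
import math

def frozen_features(x):
    # Hypothetical stand-in for a frozen pre-trained feature extractor.
    return [x, x * x, 1.0]

def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def train_linear_probe(xs, ys, lr=0.5, steps=1000):
    """Train only a logistic-regression head on frozen features (full-batch gradient descent)."""
    feats = [frozen_features(x) for x in xs]  # computed once; the extractor is never updated
    w = [0.0] * len(feats[0])
    n = len(xs)
    for _ in range(steps):
        grad = [0.0] * len(w)
        for phi, y in zip(feats, ys):
            err = sigmoid(sum(wi * fi for wi, fi in zip(w, phi))) - y
            for k, fk in enumerate(phi):
                grad[k] += err * fk / n
        w = [wi - lr * gi for wi, gi in zip(w, grad)]
    return w

def predict(w, x):
    return sum(wi * fi for wi, fi in zip(w, frozen_features(x))) > 0.0

# Hypothetical target task: label 1 iff |x| > 1 (linear in the frozen features).
xs = [-2.0 + 4.0 * i / 199 for i in range(200)]
ys = [1.0 if abs(x) > 1.0 else 0.0 for x in xs]
w = train_linear_probe(xs, ys)
acc = sum(predict(w, x) == (y == 1.0) for x, y in zip(xs, ys)) / len(xs)
print(f"probe accuracy on the target data: {acc:.2f}")
```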


19.5.1.2 Prompt tuning (in-context learning) 


Recently, another approach to transfer learning has been developed that leverages large models, such
as transformers (Section 22.4), which are trained on massive web datasets, usually in an unsupervised
way, and then adapted to a small, task-specific target distribution. The interesting thing about
this approach is that the parameters of the original model are not changed; instead, the model is simply
“conditioned” on new training data, usually in the form of a text prompt z. That is, we compute

$$f^t(x) = f^s(x \cup z) \tag{19.11}$$

where $x \cup z$ denotes the input combined with the prompt, and where we (manually or automatically) optimize z while keeping $f^s$ frozen. This approach is called
prompt tuning or in-context learning (see e.g., [Liu+21a]), and is an instance of few-shot
learning (see Figure 22.5 for an example).

Here z acts like a small training dataset, and $f^s$ uses attention (Section 16.2.7) to “look at” all
its inputs, comparing x with the examples in z, and uses this to make a prediction. This works
because the text training data often has a similar hierarchical structure (see [Xie+22] for a Bayesian
interpretation).
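Concretely, the prompt z is often just a few labeled examples serialized as text and placed before the new input x. The template below is a hypothetical example of such a format (the instruction wording and field names are our own); the frozen model $f^s$ itself is not included.

```python
def build_prompt(examples, query, instruction="Classify the sentiment of each review."):
    """Serialize labeled examples z and a new input x into a single few-shot text prompt.

    The template is a hypothetical choice; the frozen model f^s would be run on the result.
    """
    lines = [instruction, ""]
    for text, label in examples:
        lines.append(f"Review: {text}\nSentiment: {label}\n")
    lines.append(f"Review: {query}\nSentiment:")
    return "\n".join(lines)

z = [("A delightful, moving film.", "positive"),
     ("Two hours I will never get back.", "negative")]
prompt = build_prompt(z, "The plot was thin but the acting was superb.")
print(prompt)
```

Changing the examples in z changes the “training set” the model conditions on, without touching its parameters.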
19.5.2 Weighted ERM for covariate shift

In this section we reconsider the risk minimization objective in Equation (19.8), but leverage unlabeled data from the target distribution to estimate it. If we make the covariate shift assumption (i.e., $q(x, y) = q(x) p(y|x)$), then we have

$$
\begin{aligned}
R(f, q) &= \int q(x)\, q(y|x)\, \ell(y, f(x))\, dx\, dy && (19.12) \\
&= \int q(x)\, p(y|x)\, \ell(y, f(x))\, dx\, dy && (19.13) \\
&= \int p(x)\, p(y|x)\, \frac{q(x)}{p(x)}\, \ell(y, f(x))\, dx\, dy && (19.14) \\
&\approx \frac{1}{N} \sum_{(x_n, y_n) \in \mathcal{D}^p_L} w_n\, \ell(y_n, f(x_n)) && (19.15)
\end{aligned}
$$

where the weights are given by the ratio

$$ w_n = w(x_n) = \frac{q(x_n)}{p(x_n)} \qquad (19.16) $$

Thus we can solve the covariate shift problem by using weighted ERM [Shi00a; SKM07].
However, this raises two questions. First, why do we need to use this technique, since a discriminative model $p(y|x)$ should work for any input $x$, regardless of which distribution it comes from? Second, given that we do need to use this method, how should we estimate the weights $w_n = w(x_n) = q(x_n)/p(x_n)$ in practice? We discuss these issues below.

19.5.2.1 Why is covariate shift a problem for discriminative models?

For a discriminative model of the form $p(y|x)$, it might seem that a change in $p(x)$ will not affect the predictions. If the predictor $p(y|x)$ is the correct model for all parts of the input space, then this conclusion is warranted. However, most models will only be accurate in certain parts of the input space. This is illustrated in Figure 19.10b, where we show that a linear model fit to the source distribution may perform much worse on the target distribution than a model that weights target points more heavily during training.

19.5.2.2 How should we estimate the ERM weights?

One approach to estimating the ERM weights $w_n = w(x_n) = q(x_n)/p(x_n)$ is to learn a density model for the source and target. However, density estimation is difficult for high dimensional features. An alternative approach is to try to approximate the density ratio, by fitting a binary classifier to distinguish the two distributions, as discussed in Section 2.7.5. In particular, suppose we have an equal number of samples from $p(x)$ and $q(x)$. Let us label the first set with $c = -1$ and the second set with $c = 1$. Then we have

$$ p(c = 1 | x) = \frac{q(x)}{q(x) + p(x)} \qquad (19.17) $$
and hence $p(c = 1|x)/p(c = -1|x) = q(x)/p(x)$. If the classifier has the form $f(x) = p(c = 1|x) = \sigma(h(x)) = \frac{1}{1 + \exp(-h(x))}$, where $h(x)$ is the prediction function that returns the logits, then the importance weights are given by

$$ w_n = \frac{1/(1 + \exp(-h(x_n)))}{\exp(-h(x_n))/(1 + \exp(-h(x_n)))} = \exp(h(x_n)) \qquad (19.18) $$

Of course this method requires that $x$ values that may occur in the test distribution should also be possible in the training distribution, i.e., $q(x) > 0 \implies p(x) > 0$. Hence there are no guarantees about this method being able to extrapolate beyond the training distribution.

Figure 19.10: (a) Illustration of covariate shift. Light gray represents training distribution, dark gray represents test distribution. We see the test distribution has shifted to the right but the underlying input-output function is constant. (b) Dashed line: fitting a linear model across the full support of X. Solid black line: fitting the same model only on parts of input space that have high likelihood under the test distribution. From Figures 1-2 of [Sto09]. Used with kind permission of Amos Storkey.
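The two steps above (estimate $w_n$ with a domain classifier, then minimize the weighted empirical risk) can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the method of any cited paper: the domain classifier is a logistic regression trained by plain gradient descent, the source and target sets have equal size (so Equation (19.18) needs no extra correction factor), the labels $c = -1$ and $c = 1$ from the text are encoded as 0 and 1, and the weighted ERM step is weighted least squares with a linear model. All function names are hypothetical.

```python
import numpy as np

def fit_logistic(X, c, lr=0.1, steps=2000):
    """Logistic regression for p(c=1|x) = sigmoid(h(x)), with h(x) = x @ beta + beta0.
    Labels c are 0 (source sample, the text's c = -1) or 1 (target sample)."""
    beta = np.zeros(X.shape[1])
    beta0 = 0.0
    for _ in range(steps):
        h = X @ beta + beta0
        p = 1.0 / (1.0 + np.exp(-h))
        r = p - c                        # gradient of the summed log loss wrt the logits
        beta -= lr * X.T @ r / len(c)
        beta0 -= lr * r.mean()
    return beta, beta0

def importance_weights(X_src, X_tgt):
    """w(x) = q(x)/p(x) ~= exp(h(x)) (Equation 19.18), evaluated at the source points.
    Assumes equal numbers of source and target samples."""
    X = np.vstack([X_src, X_tgt])
    c = np.r_[np.zeros(len(X_src)), np.ones(len(X_tgt))]
    beta, beta0 = fit_logistic(X, c)
    return np.exp(X_src @ beta + beta0)

def weighted_least_squares(X, y, w):
    """Minimize sum_n w_n (y_n - [x_n, 1] @ theta)^2, a simple instance of weighted ERM."""
    A = np.column_stack([X, np.ones(len(X))])
    sw = np.sqrt(w)
    theta, *_ = np.linalg.lstsq(A * sw[:, None], y * sw, rcond=None)
    return theta
```

Calling `weighted_least_squares` with unit weights recovers ordinary ERM on the source data; with the estimated weights, the linear fit concentrates on the region where the target inputs lie, loosely mirroring the two lines in Figure 19.10b.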


19.5.3 Unsupervised domain adaptation for covariate shift 


We now turn to methods that only need access to unlabeled examples from the target distribution. 

The technique of unsupervised domain adaptation or UDA assumes access to a labeled dataset from the source distribution, $\mathcal{D}_1 = \mathcal{D}^p_L \sim p(x, y)$, and an unlabeled dataset from the target distribution, $\mathcal{D}_2 = \mathcal{D}^q_U \sim q(x)$. It then uses the unlabeled target data to improve robustness or invariance of the predictor, rather than using a weighted ERM method.

There are many forms of UDA (see e.g., [KL21; CB20] for reviews). Here we just focus on one method, called domain adversarial learning [Gan+16a]. Let $f_\alpha : \mathcal{X}_1 \cup \mathcal{X}_2 \to \mathcal{H}$ be a feature extractor defined on the two input domains, let $c_\beta : \mathcal{H} \to \{1, 2\}$ be a classifier that maps from the feature space to the domain from which the input was taken, either domain 1 or 2 (source or target), and let $g_\gamma : \mathcal{H} \to \mathcal{Y}$ be a classifier that maps from the feature space to the label space. We want to train the feature extractor so that it cannot distinguish whether the input is coming from the source or target distribution; in this case, it will only be able to use features that are common to both domains. Hence we optimize

$$ \min_{\alpha, \gamma} \max_{\beta} \; -\frac{1}{N_1 + N_2} \sum_{x_n \in \mathcal{D}_1 \cup \mathcal{D}_2} \ell(d_n, c_\beta(f_\alpha(x_n))) + \frac{1}{N_1} \sum_{(x_n, y_n) \in \mathcal{D}_1} \ell(y_n, g_\gamma(f_\alpha(x_n))) \qquad (19.19) $$

where $d_n \in \{1, 2\}$ is the domain label of $x_n$. The objective in Equation (19.19) minimizes the loss on the desired task of classifying $y$, but maximizes the loss on the auxiliary task of classifying the domain label $d$ (with respect to the feature extractor; the domain classifier $c_\beta$ itself is trained to minimize this loss). This can be implemented by the gradient sign reversal trick, and is related to GANs (Section 26.7.6).
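The gradient sign reversal trick can be illustrated without an autodiff library by hand-deriving gradients for a tiny model. In the sketch below (a toy under our own assumptions, not the architecture of [Gan+16a]), a linear map `A` plays the role of $f_\alpha$ and a logistic regression `b` plays the role of $c_\beta$; only the domain-classification term of Equation (19.19) is shown. In a full implementation, the task-loss gradient from $g_\gamma$ would be added to the update of `A` with its usual sign, and the reversal is typically packaged as a layer that is the identity in the forward pass and multiplies gradients by $-\lambda$ in the backward pass.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def domain_loss_and_grads(X, d, A, b):
    """Domain-classification term of a toy DANN-style model.

    Features are H = X @ A (a linear stand-in for f_alpha) and the domain
    classifier is sigmoid(H @ b) (a stand-in for c_beta). Domain labels d are
    0 (source) or 1 (target). Returns the mean binary cross-entropy and its
    gradients with respect to A and b.
    """
    N = X.shape[0]
    H = X @ A
    p = sigmoid(H @ b)
    eps = 1e-12
    loss = -np.mean(d * np.log(p + eps) + (1 - d) * np.log(1 - p + eps))
    r = (p - d) / N                    # d loss / d logits
    grad_b = H.T @ r
    grad_A = X.T @ np.outer(r, b)      # chain rule through logits = X @ A @ b
    return loss, grad_A, grad_b

def adversarial_step(X, d, A, b, lr=0.1, lam=1.0):
    """One simultaneous update: the domain classifier descends on the domain
    loss, while the feature extractor receives the reversed (negated,
    lam-scaled) gradient, so it ascends on the domain loss."""
    loss, grad_A, grad_b = domain_loss_and_grads(X, d, A, b)
    b_new = b - lr * grad_b
    A_new = A - lr * (-lam * grad_A)   # gradient reversal: flip the sign
    return A_new, b_new, loss
```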


19.5.4 Unsupervised techniques for label shift 


In this section, we describe an approach known as black box shift estimation, due to [LWS18], which can be used to tackle the label shift problem in an unsupervised way. We assume that the only thing that changes in the target distribution is the label prior, i.e., if the source distribution is denoted by $p(x, y)$ and the target distribution is denoted by $q(x, y)$, we assume $q(x, y) = p(x|y) q(y)$.
First note that, for any deterministic function $f : \mathcal{X} \to \mathcal{Y}$, we have

$$ p(x|y) = q(x|y) \implies p(f(x)|y) = q(f(x)|y) \implies p(\hat{y}|y) = q(\hat{y}|y) \qquad (19.20) $$

where $\hat{y} = f(x)$ is the predicted label. Let $\mu_i = q(\hat{y} = i)$ be the empirical fraction of times the model predicts class $i$ on the test set, let $q(y = i)$ be the true but unknown label distribution on the test set, and let $C_{ij} = p(\hat{y} = i | y = j)$ be the class confusion matrix estimated on the training set. Then we have

$$ \mu_{\hat{y}} = \sum_y q(\hat{y}|y)\, q(y) = \sum_y p(\hat{y}|y)\, q(y) = \sum_y \frac{p(\hat{y}, y)}{p(y)}\, q(y) \qquad (19.21) $$

We can write this in matrix-vector form as follows:

$$ \mu = C q \qquad (19.22) $$

Hence we can solve $q = C^{-1} \mu$, provided that $C$ is not singular (this will be the case if $C$ is strongly diagonal, i.e., the model predicts class $y_i$ correctly more often than any other class $y_j$). We also require that for every $q(y) > 0$ we have $p(y) > 0$, which means we see every label at training time.

Once we know the new label distribution, $q(y)$, we can adjust our discriminative classifier to take the new label prior into account as follows:

$$ q(y|x) = \frac{q(x|y)\, q(y)}{q(x)} = \frac{p(x|y)\, q(y)}{q(x)} = \frac{p(y|x)\, p(x)}{p(y)} \frac{q(y)}{q(x)} = p(y|x) \frac{q(y)}{p(y)} \frac{p(x)}{q(x)} \qquad (19.23) $$

We can safely ignore the $p(x)/q(x)$ term, which is constant with respect to $y$, and we can plug in our estimates of the label distributions to compute the $q(y)/p(y)$ ratio; renormalizing over $y$ then gives $q(y|x)$.

In summary, there are three requirements for this method: (1) the confusion matrix is invertible; (2) no new labels appear at test time; (3) the only thing that changes is the label prior. If these three conditions hold, the above approach is a valid estimator. See [LWS18] for more details, and [Gar+20] for an alternative approach, based on maximum likelihood (rather than moment matching) for estimating the new marginal label distribution.
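Black box shift estimation reduces to three array operations: estimate $C$ on labeled source data, count the model's predictions on the unlabeled test set to get $\mu$, and solve $C q = \mu$. The NumPy sketch below follows Equations (19.21) to (19.23); the function names are hypothetical, and clipping negative entries is a simple safeguard we add for finite-sample noise rather than part of the derivation above.

```python
import numpy as np

def conditional_confusion_matrix(y_true, y_pred, num_classes):
    """C[i, j] = p(yhat = i | y = j), estimated on labeled source data.
    Every class must occur in y_true, otherwise a column would be 0/0."""
    C = np.zeros((num_classes, num_classes))
    for t, p in zip(y_true, y_pred):
        C[p, t] += 1
    return C / C.sum(axis=0, keepdims=True)   # normalize each column (true class j)

def bbse_target_prior(C, test_preds, num_classes):
    """Moment-matching estimate q = C^{-1} mu of the target label prior."""
    mu = np.bincount(test_preds, minlength=num_classes) / len(test_preds)
    q = np.linalg.solve(C, mu)
    q = np.clip(q, 0.0, None)    # finite-sample estimates can be slightly negative
    return q / q.sum()

def reweight_posterior(probs, p_y, q_y):
    """Equation (19.23): multiply p(y|x) by q(y)/p(y), then renormalize over y."""
    unnorm = probs * (q_y / p_y)
    return unnorm / unnorm.sum(axis=-1, keepdims=True)
```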


19.5.5 Test-time adaptation


In some settings, it is possible to continuously update the model parameters. This allows the model to adapt to changes in the input distribution. This is called test time adaptation or TTA. The difference from the unsupervised domain adaptation methods of Section 19.5.3 is that, in the online setting, we only have access to the model that was trained on the source, not to the source data itself.

In [Sun+20] they proposed an approach called TTT (“test-time training”) for adapting a discriminative model. In this approach, a self-supervised proxy task is used to create pseudo-labels, which can then be used to adapt the model at run time. In more detail, suppose we create a Y-structured network, where we first perform feature extraction, $x \to h$, and then use $h$ to predict the output $y$ and some proxy output $r$, such as the angle of rotation of the input image. The rotation angle is known if we use data augmentation. Hence we can apply this technique at test time, even if $y$ is unknown, and update the $x \to h \to r$ part of the network, which influences the prediction for $y$ via the shared bottleneck (feature layer) $h$.

Of course, if the proxy output, such as the rotation angle, is not known, we cannot use proxy-supervised learning methods such as TTT. In [Wan+20a], they propose an approach, inspired by semi-supervised learning methods, which they call TENT, which stands for “test-time adaptation by entropy minimization”. The idea is to update the classifier parameters to minimize the entropy of the predictive distribution on a batch of test examples. In [Goy+22], they give a justification for this heuristic from the meta-learning perspective. In [ZL21], they present a Bayesian version of TENT, which they call BACS, which stands for “Bayesian adaptation under covariate shift”. In [ZLF21], they propose a method called MEMO (“marginal entropy minimization with one test point”) that can be used for any architecture. The idea is, once again, to apply data augmentation at test time to the input $x$, to create a set of inputs, $\tilde{x}_1, \ldots, \tilde{x}_B$. Now we update the parameters so as to minimize the predictive entropy produced by the averaged distribution

$$ \bar{p}(y|x, w) = \frac{1}{B} \sum_{b=1}^{B} p(y|\tilde{x}_b, w) \qquad (19.24) $$

This ensures that the model gives the same predictions for each perturbation of the input, and that 
the predictions are confident (low entropy). 
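The two entropy objectives can be written directly from their definitions. The NumPy sketch below computes them from logits; the function names are hypothetical, and the gradient update on $w$ is omitted (in practice it would be computed with an autodiff library). `tent_objective` averages per-example entropies over a batch of test inputs, while `memo_objective` takes the entropy of the averaged distribution in Equation (19.24) over the augmented copies of a single input.

```python
import numpy as np

def softmax(logits):
    z = logits - logits.max(axis=-1, keepdims=True)   # subtract max for stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def entropy(p, eps=1e-12):
    return -np.sum(p * np.log(p + eps), axis=-1)

def tent_objective(batch_logits):
    """TENT-style loss: mean entropy of the per-example predictive
    distributions over a batch of test inputs (logits of shape [N, C])."""
    return entropy(softmax(batch_logits)).mean()

def memo_objective(aug_logits):
    """MEMO-style loss: entropy of the marginal distribution of Equation (19.24),
    from the logits of B augmented copies of ONE test input (shape [B, C])."""
    p_bar = softmax(aug_logits).mean(axis=0)
    return entropy(p_bar)
```

The difference shows up when augmented copies disagree: each copy may be individually confident, so the per-example (TENT-style) entropy is unchanged, but the marginal (MEMO-style) entropy rises, which is what pushes the model towards consistent predictions across perturbations.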

An alternative to entropy based methods is to use pseudo-labels (predicted outputs from the source 
model), and then to self-train on these. This has been shown to work well for unsupervised domain 
adaptation [KML20]. In the context of TTA, [Che+22] use this technique, combined with contrastive 
learning, to adapt a classifier to a new distribution. 


19.6 Learning from multiple distributions 


In Section 19.2, we discussed the setting in which a model is trained on a single source distribution, and then evaluated on a distinct target distribution. In this section, we generalize this to a setting in which the model is trained on data from $J \geq 2$ source distributions, before being tested on data from a target distribution. This includes a variety of different problem settings, depending on the value of $J$, as we summarize in Figure 19.11.


19.6.1 Multi-task learning 


In multi-task learning (MTL) [Car97], we have labeled data from $J$ different distributions, $\mathcal{D}^j = \{(x^j_n, y^j_n) : n = 1 : N_j\}$, and the goal is to learn a model that predicts well on all $J$ of them simultaneously, where $f(x, j) : \mathcal{X} \to \mathcal{Y}_j$ is the output for the $j$'th task. For example, we might want to map a color image of size $H \times W \times 3$ to a set of semantic labels per pixel, $\mathcal{Y}_1 = \{1, \ldots, C\}^{HW}$, as well as a set of predicted depth values per pixel, $\mathcal{Y}_2 = \mathbb{R}^{HW}$. We can do this using ERM where we have multiple samples for each task:

$$ f^* = \arg\min_f \sum_{j=1}^{J} \sum_{n=1}^{N_j} \ell_j(y^j_n, f(x^j_n, j)) \qquad (19.25) $$

where $\ell_j$ is the loss function for task $j$ (suitably scaled).

Figure 19.11: Schematic overview of techniques for learning from 1 or more different distributions. Adapted from slide 3 of [Sca21].


There are many approaches to solving MTL. The simplest is to fit a single model with multiple “output heads”, as illustrated in Figure 19.12. This is called a “shared trunk network”. Unfortunately this often leads to worse performance than training J single task networks. In [Mis+16], they propose to take a weighted combination of the activations of each single task network, an approach they called “cross-stitch networks”. See [ZY21] for a more detailed review of neural approaches, and [BLS11] for a theoretical analysis of this problem.
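A shared trunk network with two output heads, together with the summed per-task loss of Equation (19.25), can be sketched as follows. This is a forward pass only, under illustrative assumptions: the trunk is a single ReLU layer, task 1 is classification (a stand-in for per-pixel semantic labels), task 2 is scalar regression (a stand-in for depth), every input carries labels for both tasks, and `task_weights` plays the role of the “suitable scaling” of each $\ell_j$. Parameter and function names are hypothetical.

```python
import numpy as np

def init_params(d_in, d_hidden, n_classes, seed=0):
    """Shared trunk (one ReLU layer) plus two task-specific linear heads."""
    rng = np.random.default_rng(seed)
    return {
        "trunk_W": rng.normal(scale=0.1, size=(d_in, d_hidden)),
        "trunk_b": np.zeros(d_hidden),
        "cls_W": rng.normal(scale=0.1, size=(d_hidden, n_classes)),  # task 1 head
        "cls_b": np.zeros(n_classes),
        "reg_W": rng.normal(scale=0.1, size=(d_hidden, 1)),          # task 2 head
        "reg_b": np.zeros(1),
    }

def forward(params, x):
    h = np.maximum(0.0, x @ params["trunk_W"] + params["trunk_b"])  # shared features
    logits = h @ params["cls_W"] + params["cls_b"]
    depth = (h @ params["reg_W"] + params["reg_b"])[:, 0]
    return logits, depth

def mtl_loss(params, x, y_cls, y_reg, task_weights=(1.0, 1.0)):
    """Weighted sum of per-task losses, in the spirit of Equation (19.25)."""
    logits, depth = forward(params, x)
    z = logits - logits.max(axis=1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    ce = -log_probs[np.arange(len(y_cls)), y_cls].mean()   # task 1: cross-entropy
    mse = np.mean((depth - y_reg) ** 2)                    # task 2: squared error
    return task_weights[0] * ce + task_weights[1] * mse
```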


Note that multi-task learning does not always help performance on each task because sometimes there can be “task interference” or “negative transfer” (see e.g., [MAP17; Sta+20; WZR20]). In such cases, we should use separate networks, rather than using one model with multiple output heads.


19.6.2 Domain generalization


The problem of domain generalization assumes we train on J different labeled source distributions or “environments” (also called “domains”), and then test on a new target distribution (denoted by $t = J + 1$). In some cases each environment is just identified with a meaningless integer id. In more realistic settings, each different distribution has associated meta-data or context variables that characterize the environment in which the data was collected, such as the time, location, imaging device, etc.

Figure 19.12: Illustration of multi-headed network for multi-task learning.

Figure 19.13: Hierarchical Bayesian discriminative model for learning from J different environments (distributions), and then testing on a new target distribution $t = J + 1$. Here $\hat{y}_n$ is the prediction for test example $x_n$, $y_n$ is the true output, and $\ell_n = \ell(y_n, \hat{y}_n)$ is the associated loss. The parameters of the distribution over input features $p_\phi(x)$ are shown with dotted edges, since these distributions do not need to be learned in a discriminative model.

Figure 19.14: Illustration of invariant causal prediction. The hammer symbol represents variables whose distribution is perturbed in the given environment. An invariant predictor must use features {X2, X4}. Considering indirect causes instead of direct ones (e.g. {X2, X5}) or an incomplete set of direct causes (e.g. {X4}) may not be sufficient to guarantee invariant prediction. From Figure 1 of [PBM16b]. Used with kind permission of Jonas Peters.

Domain generalization (DG) is similar to multi-task learning, but differs in what we want to predict. In particular, in DG, we only care about prediction accuracy on the target distribution, not the J training distributions. Furthermore, we assume we don’t have any labeled data from the target distribution. We therefore have to make some assumptions about how $p^t(x, y)$ relates to $p^j(x, y)$ for $j = 1 : J$.

One way to formalize this is to create a hierarchical Bayesian model, as proposed in [Bax00], and illustrated in Figure 19.13. This encodes the assumption that $p^t(x, y) = p(x|\phi^t)\, p(y|x, w^t)$, where $w^t$ is derived from a common “population level” model $w^0$, shared across all distributions, and similarly for $\phi^t$. (Note, however, that in a discriminative model, we don’t need to model $p(x|\phi^t)$.) See Section 15.5 for discussion of hierarchical Bayesian GLMs, and Section 17.6 for discussion of hierarchical Bayesian MLPs.
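The effect of the shared population-level parameter can be seen in the simplest hierarchical model we can write down. This is our own toy simplification, not the GLMs and MLPs of Sections 15.5 and 17.6 nor the model of [Bax00]: a scalar mean $w^j$ per environment, Gaussian observations with known noise $\sigma$, and $w^j \sim \mathcal{N}(w^0, \tau^2)$. Each environment's posterior mean is pulled towards $w^0$, and for a new environment with no labeled data the prediction falls back on the population level. Estimating $w^0$ as the average of the environment means is a crude assumption made to keep the sketch short, and the function names are hypothetical.

```python
import numpy as np

def hierarchical_gaussian_means(env_data, sigma, tau):
    """Toy hierarchical model: y ~ N(w_j, sigma^2) within environment j and
    w_j ~ N(w0, tau^2). sigma and tau are assumed known; w0 is crudely
    estimated as the average of the per-environment sample means."""
    ybar = np.array([np.mean(y) for y in env_data])
    n = np.array([len(y) for y in env_data])
    w0 = ybar.mean()
    prec = n / sigma**2 + 1.0 / tau**2                 # posterior precision of w_j
    post_mean = (n / sigma**2 * ybar + w0 / tau**2) / prec
    return w0, post_mean, 1.0 / prec

def predict_new_environment(w0, sigma, tau):
    """With no labeled target data, integrate over w_t ~ N(w0, tau^2):
    the predictive for y^t is N(w0, tau^2 + sigma^2)."""
    return w0, tau**2 + sigma**2
```

Each posterior mean is a precision-weighted compromise between the environment's own sample mean and $w^0$, so environments with little data borrow more strength from the population level.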


Many other techniques have been proposed for DG. Note, however, that [GLP21] found that none of these methods worked consistently better than the baseline approach of performing empirical risk minimization across all the provided datasets. For more information, see e.g., [GLP21; She+21; Wan+21; Chr+21].


19.6.3 Invariant risk minimization


One approach to domain generalization that has received a lot of attention is called invariant risk minimization or IRM [Arj+19]. The goal is to learn a predictor that works well across all environments, yet is less prone to depending on the kinds of “spurious features” we discussed in Section 19.2.1.

IRM is an extension of an earlier method called invariant causal prediction (ICP) [PBM16b]. This uses hypothesis testing methods to find the set of predictors (features) that directly cause the outcome in each environment, rather than features that are indirect causes, or are just correlated with the outcome. See Figure 19.14 for an illustration.

In [Arj+19], they proposed an extension of ICP to handle the case of high dimensional inputs, where the individual variables do not have any causal meaning (e.g., they correspond to pixels). Their approach requires finding a predictor that works well on average, across all environments, while also being optimal for each individual environment. That is, we want to find

$$
\begin{aligned}
&\min_{f \in \mathcal{F}} \sum_{j=1}^{J} \frac{1}{N_j} \sum_{n=1}^{N_j} \ell(y^j_n, f(x^j_n)) && (19.26) \\
&\text{such that } f \in \arg\min_{g \in \mathcal{F}} \frac{1}{N_j} \sum_{n=1}^{N_j} \ell(y^j_n, g(x^j_n)) \text{ for all } j \in \mathcal{E} && (19.27)
\end{aligned}
$$

where $\mathcal{E}$ is the set of environments, and $\mathcal{F}$ is the set of prediction functions. The intuition behind this is as follows: there may be many functions that achieve low empirical loss on any given environment, since the problem may be underspecified, but if we pick the one that also works well on all environments, it is more likely to rely on causal features rather than spurious features.
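The intuition behind the constraint in Equation (19.27) can be checked numerically with a fixed representation and a linear head. In the hypothetical data below, $y$ is caused by feature $x_1$, while $x_2$ is a noisy effect of $y$ whose reliability differs between environments. If the head uses only $x_1$, the least-squares head that is optimal in one environment is (up to sampling noise) also optimal in the other; if it also uses $x_2$, the per-environment optima disagree. This sketch only evaluates the constraint for two hand-picked representations; it is not the IRM training procedure of [Arj+19], and all names are hypothetical.

```python
import numpy as np

def make_environment(rng, n, spurious_noise):
    """Toy data: y is caused by x1 only; x2 is a noisy effect of y whose
    reliability (spurious_noise) differs across environments."""
    x1 = rng.normal(size=n)
    y = x1 + 0.5 * rng.normal(size=n)
    x2 = y + spurious_noise * rng.normal(size=n)
    return np.column_stack([x1, x2]), y

def per_environment_optimal_heads(envs, feature_idx):
    """Least-squares linear head fit separately in each environment, using
    only the columns in feature_idx (i.e., a fixed representation)."""
    heads = []
    for X, y in envs:
        w, *_ = np.linalg.lstsq(X[:, feature_idx], y, rcond=None)
        heads.append(w)
    return np.array(heads)

def head_spread(heads):
    """Largest deviation of any environment's optimal head from the average.
    Zero spread means one head is optimal in every environment at once,
    which is what the constraint in Equation (19.27) asks for."""
    return np.max(np.abs(heads - heads.mean(axis=0)))
```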

Unfortunately, more recent work has shown that the IRM principle often does not work well for 
covariate shift, both in theory [RRR21] and practice [GLP21], although it can work well in some 
anti-causal (generative) models [Ahu+21]. 


19.6.4 Meta-learning 


The goal of meta-learning is to “learn the learning algorithm” [TP97]. A common way to do this is 
to provide the meta-learner with a set of datasets from different distributions. This is very similar 
to domain generalization (Section 19.6.2), except that we partition each training distribution into 
training and test, so we can “practice” learning to generalize from a training set to a test set. A 
general review of meta-learning can be found in [Hos+20a]. Here we present a unifying summary 
based on the hierarchical Bayesian framework proposed in [Gor+19]. 


19.6.4.1 Meta-learning as probabilistic inference for prediction 


We assume there are J tasks (distributions), each of which has a training set $\mathcal{D}^j_{\text{train}} = \{(x^j_n, y^j_n) : n = 1 : N^j\}$ and a test set $\mathcal{D}^j_{\text{test}} = \{(\tilde{x}^j_m, \tilde{y}^j_m) : m = 1 : M^j\}$. In addition, $w^j$ are the task specific parameters, and $w^0$ are the shared parameters, as shown in Figure 19.15. This is very similar to the domain generalization model in Figure 19.13, except for two differences: first, there is the trivial difference due to the use of plate notation; second, in meta learning, we have both training and test partitions for all distributions, whereas in DG, we only have a test set for the target distribution.

We will learn a point estimate for the global parameters $w^0$, since it is shared across all datasets, and thus has little uncertainty. However, we will compute an approximate posterior for $w^j$, since each task often has little data. We denote this posterior by $p(w^j | \mathcal{D}^j_{\text{train}}, w^0)$. From this, we can compute the posterior predictive distribution for each task:

$$ p(\tilde{y}^j | \tilde{x}^j, \mathcal{D}^j_{\text{train}}, w^0) = \int p(\tilde{y}^j | \tilde{x}^j, w^j)\, p(w^j | \mathcal{D}^j_{\text{train}}, w^0)\, dw^j \qquad (19.28) $$
Figure 19.15: Hierarchical Bayesian model for meta-learning. There are J tasks, each of which has a training set $\mathcal{D}^j_{\text{train}} = \{(x^j_n, y^j_n) : n = 1 : N^j\}$ and a test set $\mathcal{D}^j_{\text{test}} = \{(\tilde{x}^j_m, \tilde{y}^j_m) : m = 1 : M^j\}$. $w^j$ are the task specific parameters, and $w^0$ are the shared parameters. Adapted from Figure 1 of [Gor+19].


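Since computing the posterior is in general intractable, we will learn an amortized approximation (see Section 10.3.6) to the predictive distribution, denoted by $q_\phi(\tilde{y}^j | \tilde{x}^j, \mathcal{D}^j_{\text{train}}, w^0)$. We choose the parameters of the prior $w^0$ and the inference network $\phi$ to make this predictive posterior as accurate as possible for any given input dataset:

$$
\begin{aligned}
\phi^* &= \arg\min_\phi \mathbb{E}_{p(\mathcal{D}_{\text{train}}, \tilde{x})} \left[ D_{\mathbb{KL}}\left( p(\tilde{y} | \tilde{x}, \mathcal{D}_{\text{train}}, w^0) \,\|\, q_\phi(\tilde{y} | \tilde{x}, \mathcal{D}_{\text{train}}, w^0) \right) \right] && (19.29) \\
&= \arg\max_\phi \mathbb{E}_{p(\mathcal{D}_{\text{train}}, \tilde{x})} \left[ \mathbb{E}_{p(\tilde{y} | \tilde{x}, \mathcal{D}_{\text{train}}, w^0)} \left[ \log q_\phi(\tilde{y} | \tilde{x}, \mathcal{D}_{\text{train}}, w^0) \right] \right] && (19.30) \\
&= \arg\max_\phi \mathbb{E}_{p(\mathcal{D}_{\text{train}}, \tilde{x}, \tilde{y})} \left[ \log \int p(\tilde{y} | \tilde{x}, w)\, q_\phi(w | \mathcal{D}_{\text{train}}, w^0)\, dw \right] && (19.31)
\end{aligned}
$$

where we made the approximation $p(\tilde{y} | \tilde{x}, \mathcal{D}_{\text{train}}, w^0) \approx p(\tilde{y} | \tilde{x}, \mathcal{D}_{\text{train}})$. We can then make a Monte Carlo approximation to the outer expectation by sampling J tasks (distributions) from $p(\mathcal{D})$, each of which gets partitioned into a train and test set, $\{(\mathcal{D}^j_{\text{train}}, \mathcal{D}^j_{\text{test}}) \sim p(\mathcal{D}) : j = 1 : J\}$, where $\mathcal{D}^j_{\text{test}} = \{(\tilde{x}^j_m, \tilde{y}^j_m)\}$. We can make an MC approximation to the inner expectation (the integral) by drawing S samples from the task-specific parameter posterior $w^j_s \sim q_\phi(w | \mathcal{D}^j_{\text{train}}, w^0)$. The resulting objective has the following form (where we assume each test set has M samples for notational simplicity):

$$ \mathcal{L}_{\text{meta}}(w^0, \phi) = \frac{1}{MJ} \sum_{m=1}^{M} \sum_{j=1}^{J} \log \left( \frac{1}{S} \sum_{s=1}^{S} p(\tilde{y}^j_m | \tilde{x}^j_m, w^j_s) \right) \qquad (19.32) $$

Note that this is different from standard (amortized) variational inference, which focuses on approximating the expected accuracy of the parameter posterior given all of the data for a task, $\mathcal{D}^j = \mathcal{D}^j_{\text{train}} \cup \mathcal{D}^j_{\text{test}}$, rather than focusing on predictive accuracy of a test set given a training set.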
Indeed, the standard objective has the form

$$ \mathcal{L}_{\text{VI}}(w^0, \phi) = \frac{1}{J} \sum_{j=1}^{J} \left[ \sum_{(x, y) \in \mathcal{D}^j} \frac{1}{S} \sum_{s=1}^{S} \log p(y | x, w^j_s) - D_{\mathbb{KL}}\left( q_\phi(w^j | \mathcal{D}^j, w^0) \,\|\, p(w^j | w^0) \right) \right] \qquad (19.33) $$

where $w^j_s \sim q_\phi(w^j | \mathcal{D}^j, w^0)$. We see that the standard formulation takes the average of a log, but the meta-learning formulation takes the log of an average. The latter can give provably better predictive accuracy, as pointed out in [MAD20]. Another difference is that the meta-learning formulation optimizes the forward KL, not the reverse KL. Finally, in the meta-learning formulation, we do not have the KL penalty term on the parameter posterior.
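The contrast between the “log of an average” (Equation 19.32) and the “average of a log” (the likelihood term of Equation 19.33) is easy to see numerically. Given per-sample log-likelihoods $\log p(\tilde{y}|\tilde{x}, w_s)$ for $s = 1 : S$, the sketch below computes both, using a numerically stable log-mean-exp for the former; by Jensen's inequality the first is never smaller than the second. This only illustrates the inequality, not the analysis in [MAD20], and the function names are hypothetical.

```python
import numpy as np

def log_mean_exp(a, axis=-1):
    """Numerically stable log((1/S) * sum_s exp(a_s)) along `axis`."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.mean(np.exp(a - m), axis=axis))

def meta_term(loglik_samples):
    """Log of an average, as inside Equation (19.32).
    loglik_samples[..., s] = log p(y | x, w_s)."""
    return log_mean_exp(loglik_samples)

def vi_term(loglik_samples):
    """Average of a log, as in the first term of Equation (19.33)."""
    return np.mean(loglik_samples, axis=-1)
```

The two terms coincide only when all $S$ samples assign the same likelihood, i.e., when the approximate posterior has effectively collapsed to a point.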

Below we show how this framework includes several common approaches to meta-learning. 


19.6.4.2 Neural processes 


In the special case that the task-specific inference network computes a point estimate, $q_\phi(w^j | \mathcal{D}^j, w^0) = \delta(w^j - A_\phi(\mathcal{D}^j, w^0))$, the posterior predictive distribution becomes

$$ q_\phi(\tilde{y}^j | \tilde{x}^j, \mathcal{D}^j, w^0) = \int p(\tilde{y}^j | \tilde{x}^j, w, w^0)\, q_\phi(w | \mathcal{D}^j, w^0)\, dw = p(\tilde{y}^j | \tilde{x}^j, A_\phi(\mathcal{D}^j, w^0), w^0) \qquad (19.34) $$

where $A_\phi(\mathcal{D}^j, w^0)$ is a function that takes in a set, and returns some parameters. We can evaluate this predictive distribution empirically, and directly optimize it (with respect to $\phi$ and $w^0$) using standard supervised maximum likelihood methods. This approach is called a neural process [Gar+18e; Gar+18d]. (See [DGF20] for a good tutorial.)
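A key requirement on $A_\phi$ is that it maps a set to parameters, so its output should not depend on the order of the examples in $\mathcal{D}^j$. One simple way to achieve this, assumed here for illustration, is to embed each $(x_n, y_n)$ pair, average the embeddings, and map the average to task parameters. The NumPy sketch below shows only this forward computation, with a linear decoder that outputs a predictive mean; the decoder's shared parameters $w^0$, the maximum likelihood training loop, and all names are hypothetical simplifications.

```python
import numpy as np

def init_encoder(d_x, d_hidden, d_out, seed=0):
    rng = np.random.default_rng(seed)
    return {
        "W1": rng.normal(scale=0.3, size=(d_x + 1, d_hidden)),  # embeds one (x, y) pair
        "W2": rng.normal(scale=0.3, size=(d_hidden, d_out)),    # pooled embedding -> params
    }

def set_encoder(phi, X, y):
    """Hypothetical A_phi(D): embed each (x_n, y_n), average over the set
    (so the result does not depend on the ordering of D), then map the
    average to task-specific parameters."""
    pairs = np.column_stack([X, y])
    r = np.tanh(pairs @ phi["W1"]).mean(axis=0)   # permutation-invariant pooling
    return r @ phi["W2"]

def predict(task_params, X_query):
    """Toy decoder: task_params are the weights of a linear predictor (mean only)."""
    return X_query @ task_params
```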

However, a more common approach to meta-learning is to specify the form of the estimator $A_\phi(\mathcal{D}^j, w^0)$ using various heuristics, some of which we discuss below.


19.6.4.3 Gradient-based meta-learning (MAML) 


In gradient-based meta-learning, we define the task specific inference procedure as follows:

$$ \hat{w}^j = A(\mathcal{D}^j, w^0) = w^0 + \eta \nabla_w \sum_{n=1}^{N_j} \log p(y^j_n | x^j_n, w) \Big|_{w = w^0} \qquad (19.35) $$

That is, we set the task specific parameters to be the shared parameters $w^0$, modified by one step (of size $\eta$) along the gradient of the log conditional likelihood. This approach is called model-agnostic meta-learning or MAML [FAL17]. It is also possible to take multiple gradient steps, by feeding the gradient into an RNN [RL17].
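For a model with a closed-form gradient, the inner update in Equation (19.35) is a one-liner. The sketch below assumes linear regression with a unit-variance Gaussian likelihood, for which $\nabla_w \sum_n \log p(y_n|x_n, w) = X^\top (y - X w)$. Only the task-specific adaptation is shown; the outer MAML update of $w^0$ differentiates through this step and would normally be done with an autodiff library. The `num_steps > 1` option is a plain multi-step variant, not the RNN-based approach of [RL17], and the function names are hypothetical.

```python
import numpy as np

def loglik(w, X, y):
    """sum_n log N(y_n | x_n @ w, 1), dropping the additive constant."""
    r = y - X @ w
    return -0.5 * np.sum(r ** 2)

def grad_loglik(w, X, y):
    return X.T @ (y - X @ w)

def maml_adapt(w0, X_train, y_train, eta=0.01, num_steps=1):
    """Task-specific parameters as in Equation (19.35): gradient ascent on the
    log likelihood of the task's training set, starting from shared w0."""
    w = w0.copy()
    for _ in range(num_steps):
        w = w + eta * grad_loglik(w, X_train, y_train)
    return w
```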


19.6.4.4 Metric-based few-shot learning (prototypical networks) 
Now suppose $w^0$ corresponds to the parameters of a shared neural feature extractor, $h_{w^0}(x)$, and the task specific parameters are the weights and biases of the last linear layer of a classifier, $w^j = \{w^j_c, b^j_c\}_{c=1}^{C}$. Let us compute the average of the feature vectors for each class in each task’s training set:

$$ \mu^j_c = \frac{1}{|\mathcal{D}^j_c|} \sum_{x_n \in \mathcal{D}^j_c} h_{w^0}(x_n) \qquad (19.36) $$

Now define the task specific inference procedure as follows. We first compute the vector containing the centroid and norm for each class:

$$ w^j_c = A(\mathcal{D}^j, w^0)_c = \left[ \mu^j_c, -\tfrac{1}{2} \|\mu^j_c\|^2 \right] \qquad (19.37) $$

The predictive distribution becomes

$$ q(\tilde{y}^j = c | \tilde{x}^j, \mathcal{D}^j, w^0) \propto \exp\left( -d(h_{w^0}(\tilde{x}^j), \mu^j_c) \right) \propto \exp\left( h_{w^0}(\tilde{x}^j)^\top \mu^j_c - \tfrac{1}{2} \|\mu^j_c\|^2 \right) \qquad (19.38) $$

where $d(u, v) = \frac{1}{2}\|u - v\|^2$ is (half) the squared Euclidean distance; the second proportionality drops the factor $\exp(-\frac{1}{2}\|h_{w^0}(\tilde{x}^j)\|^2)$, which does not depend on $c$. This is equivalent to the technique known as prototypical networks [SSZ17].
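Equations (19.36) to (19.38) translate directly into array code. The sketch below takes the embeddings $h_{w^0}(x)$ as already computed (the feature extractor itself is out of scope) and shows that the “linear layer” logits $h^\top \mu_c - \frac{1}{2}\|\mu_c\|^2$ give the same class probabilities as a softmax over negative half squared distances to the prototypes. Function names are hypothetical.

```python
import numpy as np

def prototypes(H, y, num_classes):
    """Class centroids mu_c of the embedded support set (Equation 19.36).
    H: [N, D] embeddings, y: [N] integer labels; every class must be present."""
    return np.stack([H[y == c].mean(axis=0) for c in range(num_classes)])

def proto_logits(h_query, mu):
    """Logits h^T mu_c - 0.5 ||mu_c||^2 from Equation (19.38)."""
    return h_query @ mu.T - 0.5 * np.sum(mu ** 2, axis=1)

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)
```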


19.7 Continual learning 


In this section, we discuss continual learning (see e.g., [Had+20; Del+21; Qu+21; LCR21; Mai+22; Lin+22]), also called life-long learning (see e.g., [Thr98; CL18]), in which the system learns from a sequence of different distributions, $p_1, p_2, \ldots$. In particular, at each time step $t$, the model receives a batch of labeled data,

$$ \mathcal{D}^{\text{train}}_t = \{(x_n, y_n) : n = 1 : N^{\text{train}}_t,\; x_n \sim p_t(x),\; y_n = f_t(x_n)\} \qquad (19.39) $$

where $p_t(x)$ is the unknown input distribution, and $f_t : \mathcal{X}_t \to \mathcal{Y}_t$ is the unknown prediction function. (We focus on noise-free outputs, for notational simplicity.) The learner is then expected to update its belief state about the true function $f_t$, and to use its beliefs to make predictions on an independent test set,

$$ \mathcal{D}^{\text{test}}_t = \{(x_n, y_n) : n = 1 : N^{\text{test}}_t,\; x_n \sim p^{\text{test}}_t(x),\; y_n = f^{\text{test}}_t(x_n)\} \qquad (19.40) $$

Depending on how we assume $p_t(x)$ and $f_t$ evolve over time, and how the test set is defined, we can create a variety of different CL scenarios, as we discuss below.


19.7.1 Domain drift


The problem of domain drift refers to the setting in which $p_t(x)$ changes over time (i.e., covariate shift), but the functional mapping $f_t : \mathcal{X} \to \mathcal{Y}$ is constant. For example, the vision system of a self driving car may have to classify cars vs pedestrians under shifting lighting conditions (see e.g., [Sun+22]).

To evaluate such a model, we assume $f^{\text{test}}_t = f_t$, and define $p^{\text{test}}_t(x)$ to be the current input distribution $p_t$ (e.g., if it is currently night time, we want the detector to work well on dark images). Alternatively we can define $p^{\text{test}}_t(x)$ to be the union of all the input distributions seen so far, $p^{\text{test}}_t = \cup_{s=1}^{t} p_s$ (e.g., we want the detector to work well on dark and light images). This latter assumption is illustrated in Figure 19.16.
Figure 19.17: An illustration of concept drift.


19.7.2 Concept drift 


The problem of concept drift refers to the setting where the functional mapping $f_t : \mathcal{X} \to \mathcal{Y}$ changes over time, but the input distribution $p(x)$ is constant [WK96]. For example, we can imagine a setting in which people engage in certain behaviors, and at step $t$ some of these are classified as illegal, and at step $t' > t$, the definition of what is legal changes, and hence the decision boundary changes. This is illustrated in Figure 19.17.

As another example, we might initially be faced with a sort-by-color task, where red objects go on the left and blue objects on the right, and then a sort-by-shape task, where square objects go on the left and circular objects go on the right.⁴ We can think of this as a problem where $p(y|x, \text{task})$ is stationary, but the task is unobserved, so $p(y|x)$ changes.

In the concept drift scenario, we see that the prediction for the same underlying input point $x \in \mathcal{X}$ will change depending on when the prediction is performed. This means that the test distribution also needs to change over time for meaningful identification. Alternatively, we can “tag” each input with the corresponding time stamp or task id.

4. This example is from Mike Mozer.

19.7.3 Class incremental learning

A very widely studied form of continual learning focuses on the setting in which new class labels are “revealed” over time. That is, there is assumed to be a true static prediction function $f : \mathcal{X} \to \mathcal{Y}$, but at step $t$, the learner only sees samples from $(\mathcal{X}, \mathcal{Y}_t)$, where $\mathcal{Y}_t \subset \mathcal{Y}$. For example, $\mathcal{X}$ may be the space of images, and $\mathcal{Y}_1$ might be {cats, dogs}, and $\mathcal{Y}_2$ might be {cars, bikes, trucks}. Learning to classify with an increasing number of categories is called class incremental learning (see e.g., [Mas+20]). This is also called task incremental learning, since each distribution is considered as a different task. See Figure 19.18 for an illustration.

The problem of class incremental learning has been studied under a variety of different assumptions, 
as discussed in [Hsu+18; VT18; FG18; Del+21]. The most common scenarios are shown in Figure 19.19. 
If we assume there are no well defined boundaries between tasks, we have continuous task-agnostic 
learning (see e.g., [SKM21; Zen+21]). If there are well defined boundaries (i.e., discontinuous changes 
of the training distribution), then we can distinguish two subcases. If the boundaries are not known 
during training (similar to detecting distribution shift), we have discrete task-agnostic learning. 
Finally, if the boundaries are given to the training algorithm, we have a task-aware learning 
problem. 

A common experimental setup in the task-aware setting is to define each task to be a different version of the MNIST dataset, e.g., with all 10 classes present but with the pixels randomly permuted (this is called permuted MNIST) or with a subset of 2 classes present at each step (this is called split MNIST).⁵ In the task-aware setting, the task label may or may not be known at test time. If it is, the problem is essentially equivalent to multi-task learning (see Section 19.6.1). If it is not, the model must predict the task and corresponding class label within that task (which is a standard supervised problem with a hierarchical label space); this is commonly done by using a multi-headed DNN, with $CT$ outputs, where $C$ is the number of classes, and $T$ is the number of tasks.

In the multi-headed approach, the number of “heads” is usually specified as input to the algorithm, 
because the softmax imposes a sum-to-one constraint that prevents incremental estimation of the 
output weights in the open-class setting. An alternative approach is to wait until a new class label 
is encountered for the first time, and then train the model with an enlarged output head. This 
requires storing past data from each class, as well as data for the new class (see e.g., [PTD20]). 
Alternatively, we can use generative classifiers where we do not need to worry about “output heads”. If 
we use a “deep” nearest neighbor classifier, with a shared feature extractor (embedding function), the 
main challenge is to efficiently update the stored prototypes for past classes as the feature extractor 
parameters change (see e.g., [DLT21]). If we fit a separate generative model per class (e.g., a VAE, 
as in [VLT21]), then online learning becomes easier, but the method may be less sample efficient. 
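The remark that a nearest-neighbor style classifier sidesteps output heads can be made concrete with a nearest-class-mean classifier that simply adds a prototype whenever a new label appears. The sketch below assumes the embedding is frozen (here, the identity on the given vectors), which avoids the prototype-update problem discussed above; handling a changing feature extractor, as in [DLT21], is not attempted. The class name is hypothetical.

```python
import numpy as np

class NearestClassMean:
    """Class-incremental classifier with no fixed output head: one running-mean
    prototype per class seen so far, in a fixed (frozen) embedding space."""

    def __init__(self):
        self.sums = {}
        self.counts = {}

    def update(self, H, y):
        for h, c in zip(H, y):
            c = int(c)
            if c not in self.sums:            # a new class: just add a prototype
                self.sums[c] = np.zeros_like(h, dtype=float)
                self.counts[c] = 0
            self.sums[c] += h
            self.counts[c] += 1

    def predict(self, H):
        classes = sorted(self.sums)
        mu = np.stack([self.sums[c] / self.counts[c] for c in classes])
        d2 = ((H[:, None, :] - mu[None, :, :]) ** 2).sum(axis=-1)
        return np.array(classes)[np.argmin(d2, axis=1)]
```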

At the time of writing, most of the CL literature focuses on the task-aware setting. However, from 
a practical point of view, the assumption that task boundaries are provided at training or test time is 
very unrealistic. For example, consider the problem of training a robot to perform various activities: 
The data just streams in, and the robot must learn what to do, without anyone telling it that it is 
now being given an example from a new task or distribution (see e.g., [Fon+21; Wol+21]). Thus 
future research should focus on the task-agnostic setting, with either discrete or continuous changes. 


19.7.4 Catastrophic forgetting 


In the class incremental learning literature, it is common to train on a sequence of tasks, but to test (at each step) on all tasks. In this scenario, there are two main possible failure modes. The first possible problem is called “catastrophic forgetting” (see e.g., [Rob95b; Fre99; Kir+17]). This refers to the phenomenon in which performance on a previous task drops when trained on a new task (see Figure 19.20(a)). Another possible problem is that only the first task is learned, and the model does not adapt to new tasks (see Figure 19.20(b)).

5. In the split MNIST setup, for task 1, digits (0,1) get labeled as (0,1), but in task 2, digits (2,3) get labeled as (0,1). So the “meaning” of the output label depends on what task we are solving. Thus the output space is really hierarchical, namely the cross product of task id and class label.

Figure 19.20: Some failure modes in class incremental learning. We train on task 1 (blue) and evaluate on tasks 1-3 (blue, orange, yellow); we then train on task 2 and evaluate on tasks 1-3; etc. (a) Catastrophic forgetting refers to the phenomenon in which performance on a previous task drops when trained on a new task. (b) Too little plasticity (e.g., due to too much regularization) refers to the phenomenon in which only the first task is learned. Adapted from Figure 2 of [Had+20].

Figure 19.21: What success looks like for class incremental learning. We train on task 1 (blue) and evaluate on tasks 1-3 (blue, orange, yellow); we then train on task 2 and evaluate on tasks 1-3; etc. (a) No forgetting refers to the phenomenon in which performance on previous tasks does not degrade over time. (b) Forward transfer refers to the phenomenon in which training on past tasks improves performance on future tasks beyond what would have been obtained by training from scratch. (c) Backwards transfer refers to the phenomenon in which training on future tasks improves performance on past tasks beyond what would have been obtained by training from scratch. Adapted from Figure 2 of [Had+20].


If we avoid these problems, we should expect to see the performance profile in Figure 19.21(a), where performance of incremental training is equal to training on each task separately. However, we might hope to do better by virtue of the fact that we are training on multiple tasks, which are often assumed to be related. In particular, we might hope to see forward transfer, in which training on past tasks improves performance on future tasks beyond what would have been obtained by training from scratch (see Figure 19.21(b)). Additionally, we might hope to see backwards transfer, in which training on future tasks improves performance on past tasks (see Figure 19.21(c)).


We can quantify the degree of transfer as follows, following [LPR17]. If $R_{i,j}$ is the performance on task $j$ after the model was trained on task $i$, $R^a_j$ is the performance on task $j$ when trained just on $j$, and there are $T$ tasks, then the amount of forward transfer is

$$\text{FWT} = \frac{1}{T} \sum_{j=1}^{T} \left( R_{j,j} - R^a_j \right) \tag{19.41}$$

and the amount of backwards transfer is

$$\text{BWT} = \frac{1}{T} \sum_{j=1}^{T} \left( R_{T,j} - R_{j,j} \right) \tag{19.42}$$
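As a minimal sketch of Equations (19.41) and (19.42), the function below computes both quantities from a matrix of scores. It assumes `R[i][j]` holds the performance on task j after sequential training up to task i (0-indexed, so the last row is the final model), and `R_alone[j]` the performance of a model trained on task j alone; the function name and all numbers in the example are hypothetical.

```python
def transfer_metrics(R, R_alone):
    """Forward and backward transfer, Eqs. (19.41) and (19.42).

    R[i][j]    : performance on task j after sequential training up to task i
                 (0-indexed, so R[T - 1][j] is the score of the final model).
    R_alone[j] : performance on task j when trained on task j alone.
    """
    T = len(R)
    fwt = sum(R[j][j] - R_alone[j] for j in range(T)) / T
    bwt = sum(R[T - 1][j] - R[j][j] for j in range(T)) / T
    return fwt, bwt


# Hypothetical accuracies for T = 3 tasks (made-up numbers, for illustration only).
R = [
    [0.90, 0.10, 0.05],   # after training on task 1
    [0.80, 0.85, 0.20],   # after training on task 2
    [0.70, 0.80, 0.95],   # after training on task 3
]
R_alone = [0.90, 0.80, 0.90]
fwt, bwt = transfer_metrics(R, R_alone)
print(f"FWT = {fwt:+.3f}, BWT = {bwt:+.3f}")  # prints FWT = +0.033, BWT = -0.083
```

Here FWT is positive (training on earlier tasks helped the later ones), while BWT is negative, which is the numerical signature of some forgetting.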


There are many methods that have been devised to overcome the problem of catastrophic forgetting, but we can group them into three main types. The first is regularization methods, which add a loss to preserve information that is relevant to old tasks. (For example, online Bayesian inference is of this type, since the posterior for the parameters is derived from the new data and the prior, which is the posterior from the previous step; see e.g., the elastic weight consolidation method discussed in Section 17.5.1, or the variational continual learning method discussed in Supplementary Section 10.2.) The second is memory methods, which rely on some kind of experience replay or rehearsal of past data (see e.g., [Hen+21]), or some kind of generative model of past data. The third is architectural methods, which add capacity to the network whenever a task boundary is encountered, such as a new class label (see e.g., [Rus+16]).

Of course, these techniques can be combined. For example, we can create a semi-parametric model, in which we store some past data (exemplars) while also learning parameters online in a Bayesian (regularized) way (see e.g., [Kur+20]). The “right” method depends, as usual, on what inductive bias you want to use, and what your computational budget is in terms of time and memory.


19.7.5 Online learning 


The problem of online learning is similar to continual learning, except the loss metric is different, and we usually assume that learning and evaluation occur at each step. More precisely, we assume the data generating distribution, $p_t(x, y) = p(x|\phi_t) p(y|x, w_t)$, evolves over time, as shown in Figure 19.22. At each step $t$ nature generates a data sample, $(x_t, y_t) \sim p_t$. The agent sees $x_t$ and is asked to predict $y_t$ by computing the posterior predictive distribution

$$p_{t|t-1} = p(y|x_t, \mathcal{D}_{1:t-1}) \tag{19.43}$$

where $\mathcal{D}_{1:t-1} = \{(x_s, y_s) : s = 1:t-1\}$ is all past data. It then incurs a loss of

$$L_t = \ell(p_{t|t-1}, y_t) \tag{19.44}$$

For classification problems, we often use 0-1 loss, $L_t = \mathbb{I}(\hat{y}_t \neq y_t)$. In this case the optimal action is to predict the most probable label, $\hat{y}_t = \arg\max_c p(y = c|x_t, \mathcal{D}_{1:t-1})$.

In contrast to the continual learning scenarios studied above, the loss incurred at each step is what matters, rather than loss on a fixed test set. That is, we want to minimize

$$L = \sum_{t=1}^{T} L_t \tag{19.45}$$
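To make the protocol of Equations (19.43) to (19.45) concrete, here is a sketch with an intentionally simple learner. It assumes binary inputs and labels and a separate Beta(1, 1) prior on p(y = 1|x) for each input value, so the posterior predictive is just a smoothed count; the function name and the drifting stream are illustrative assumptions, not part of the text.

```python
def online_01_loss(stream):
    """Cumulative 0-1 loss (Eq. 19.45) of a simple online Bayesian predictor.

    Assumes binary inputs x and labels y. The posterior predictive
    p(y = 1 | x_t, D_{1:t-1}) (Eq. 19.43) uses a Beta(1, 1) prior on
    p(y = 1 | x) for each value of x, i.e. Laplace-smoothed counts.
    """
    counts = {0: [0, 0], 1: [0, 0]}       # counts[x][y] over past data D_{1:t-1}
    losses = []
    for x, y in stream:
        n0, n1 = counts[x]
        p1 = (n1 + 1) / (n0 + n1 + 2)     # posterior predictive for y = 1
        y_hat = 1 if p1 >= 0.5 else 0     # most probable label (ties go to 1)
        losses.append(int(y_hat != y))    # 0-1 loss L_t, incurred before seeing y_t
        counts[x][y] += 1                 # then y_t is revealed and added to D
    return sum(losses), losses


# Hypothetical drifting stream: y = x for 10 steps, then the labeling flips to y = 1 - x.
stream = [(t % 2, t % 2) for t in range(10)] + [(t % 2, 1 - t % 2) for t in range(10)]
total, per_step = online_01_loss(stream)
print(total)  # 11
```

On this stream the learner makes one mistake before the drift and then gets all 10 post-drift labels wrong, since its counts implicitly assume a stationary world; the cumulative loss of Equation (19.45) exposes this directly.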
Figure 19.22: Online learning illustrated as an influence diagram (Section 34.2). Here $\hat{y}_t = \arg\max_y p(y|x_t, \mathcal{D}_{1:t-1})$ is the action (MAP predicted output) at time $t$, and $L_t = \ell(y_t, \hat{y}_t)$ is the corresponding loss (utility) function. The data at time $t$ is assumed to come from a model of the form $p(x_t, y_t) = p(x_t|\phi_t) p(y_t|x_t, w_t)$. The parameters of the world model can change arbitrarily over time.


Since it is hard to interpret this number, it is common to compare it to the optimal value one could have obtained in hindsight. This yields a quantity called the regret:

$$\text{regret} = \sum_{t=1}^{T} \left[ \ell(p_{t|t-1}, y_t) - \ell(p_{t|T}, y_t) \right] \tag{19.46}$$

where $p_{t|t-1} = p(y|x_t, \mathcal{D}_{1:t-1})$ is the online prediction, and $p_{t|T} = p(y|x_t, \mathcal{D}_{1:T})$ is the optimal estimate at the end of training. It is possible to convert bounds on regret, which are backwards looking, into bounds on risk (i.e., expected future loss), which is forwards looking. See [HT15] for details.
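The sketch below evaluates Equation (19.46) for a count-based learner under log loss, $\ell(p, y) = -\log p(y)$, so that the loss depends on the whole predictive distribution rather than only on its argmax. It assumes no inputs $x$, 0/1 labels and a Beta(a, b) prior (all illustrative choices); the hindsight predictor $p_{t|T}$ is the posterior predictive given all $T$ labels, as in the text.

```python
import math


def log_loss_regret(ys, a=1.0, b=1.0):
    """Regret (Eq. 19.46) of an online Beta-Bernoulli predictor under log loss.

    The online prediction p_{t|t-1} conditions on y_1..y_{t-1}; the hindsight
    prediction p_{t|T} conditions on all T labels. Assumes 0/1 labels, no
    inputs x, and a Beta(a, b) prior; these are illustrative choices.
    """
    T, n1 = len(ys), sum(ys)
    p1_T = (n1 + a) / (T + a + b)               # p(y = 1 | D_{1:T}), same for every t
    regret, seen1 = 0.0, 0
    for t, y in enumerate(ys):                  # t = number of labels seen so far
        p1_online = (seen1 + a) / (t + a + b)   # p(y = 1 | D_{1:t-1})
        loss_online = -math.log(p1_online if y == 1 else 1.0 - p1_online)
        loss_hindsight = -math.log(p1_T if y == 1 else 1.0 - p1_T)
        regret += loss_online - loss_hindsight
        seen1 += y
    return regret


print(log_loss_regret([1, 0, 1, 1, 0, 1, 1, 1]))
```

For example, `log_loss_regret([1, 1])` equals log(27/16) ≈ 0.52: the online learner pays log 2 and then log 1.5, while the hindsight predictor pays log(4/3) twice.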


Online learning is very useful for decision and control problems, such as multi-armed bandits (Section 34.4) and reinforcement learning (see Chapter 35), where the agent “lives forever”, and where there is no fixed training phase followed by a test phase. (See e.g., Section 17.5 where we discuss online Bayesian inference for neural networks.)


The previous continual learning scenarios can be derived as special cases of online learning, by defining a suitable sequence of distributions, and by requiring the agent to either train or test at each step on a suitable minibatch of data. (We leave the details of this mapping as an exercise to the reader.)


19.8 Adversarial examples


This section is coauthored with Justin Gilmer.
[Figure: right panel labeled “gibbon”, 99.3% confidence]


Figure 19.23: Example of an adversarial attack on an image classifier. Left column: original image which is 
correctly classified. Middle column: small amount of structured noise which is added to the input (magnitude 
of noise is magnified by 10x). Right column: new image, which is confidently misclassified as a “gibbon”, 
even though it looks just like the original “panda” image. Here $\epsilon = 0.007$. From Figure 1 of [GSS15]. Used
with kind permission of Ian Goodfellow. 


In Section 19.2, we discussed what happens to a predictive model when the input distribution shifts for some reason. In this section, we consider the case where an adversary deliberately chooses inputs to minimize the performance of a predictive model. That is, suppose an input $x$ is classified as belonging to class $c$. We then choose a new input $x_{\text{adv}}$ which minimizes the probability of this label, subject to the constraint that $x_{\text{adv}}$ is “perceptually similar” to the original input $x$. This gives rise to the following objective:

$$x_{\text{adv}} = \arg\min_{x' \in \mathcal{A}(x)} \log p(y = c|x') \tag{19.47}$$

where $\mathcal{A}(x)$ is the set of images that are “similar” to $x$ (we discuss different notions of similarity below).

Equation (19.47) is an example of an adversarial attack. We illustrate this in Figure 19.23. The input image $x$ is on the left, and is predicted to be a panda with probability 57%. By adding a tiny amount of carefully chosen noise (shown in the middle) to the input, we generate the adversarial image $x_{\text{adv}}$ on the right: this “looks like” the input, but is now classified as a gibbon with probability 99%.

The ability to create adversarial images was first noted in [Sze+14]. It is surprisingly easy to create such examples, which seems paradoxical, given that modern classifiers seem to work so well on normal inputs, and the perturbed images “look” the same to humans. We explain this paradox in Section 19.8.5.

The existence of adversarial images also raises security concerns. For example, [Sha+16] showed they could force a face recognition system to misclassify person A as person B, merely by asking person A to wear a pair of sunglasses with a special pattern on them, and [Eyk+18] showed that it is possible to attach small “adversarial stickers” to stop signs so that they are classified as speed limit signs.

Below we briefly discuss how to create adversarial attacks, why they occur, and how we can try to 
defend against them. We focus on the case of deep neural nets for images, although it is important 
to note that many other kinds of models (including logistic regression and generative models) can 
also suffer from adversarial attacks. Furthermore, this is not restricted to the image domain, but occurs with many kinds of high dimensional inputs. For example, [Li+19] describes an audio attack and [Dal+04; Jia+19] describe text attacks. More details on adversarial examples can be found in e.g., [Wiy+19; Yua+19].


19.8.1 Whitebox (gradient-based) attacks 


To create an adversarial example, we must find a “small” perturbation $\delta$ to add to the input $x$ to create $x_{\text{adv}} = x + \delta$ so that $f(x_{\text{adv}}) = y'$, where $f(\cdot)$ is the classifier, and $y'$ is the label we want to force the system to output. This is known as a targeted attack. Alternatively, we may just want to find a perturbation that causes the predicted label to change from its current value to any other value, so that $f(x + \delta) \neq f(x)$, which is known as an untargeted attack.

In general, we define the objective for the adversary as maximizing the following loss:

$$x_{\text{adv}} = \arg\max_{x' \in \mathcal{A}(x)} \mathcal{L}(x', y; \theta) \tag{19.48}$$

where $y$ is the true label. For the untargeted case, we can define $\mathcal{L}(x', y; \theta) = -\log p(y|x')$, so we minimize the probability of the true label; and for the targeted case, we can define $\mathcal{L}(x', y; \theta) = \log p(y'|x')$, where we maximize the probability of the desired label $y' \neq y$.

To define what we mean by “small” perturbation, we impose the constraint that $x_{\text{adv}} \in \mathcal{A}(x)$, which is the set of “perceptually similar” images to the input $x$. Most of the literature has focused on a simplistic setting in which the adversary is restricted to making bounded $\ell_p$ perturbations of a clean input $x$, that is

$$\mathcal{A}(x) = \{x' : \|x' - x\|_p \le \epsilon\} \tag{19.49}$$

Typically people assume $p = 2$ or $p = \infty$. We will discuss more realistic threat models in Section 19.8.3. In this section, we assume that the attacker knows the model parameters $\theta$; this is called a whitebox attack, and lets us use gradient based optimization methods. (We relax this assumption in Section 19.8.2.)
To solve the optimization problem in Equation (19.48), we can use any kind of constrained optimization method. In [Sze+14] they used bound-constrained BFGS. [GSS15] proposed the more efficient fast gradient sign (FGS) method, which performs iterative updates of the form

$$x_{t+1} = x_t + \delta_t \tag{19.50}$$

$$\delta_t = \epsilon \, \text{sign}\left( \nabla_x \log p(y'|x, \theta) \big|_{x_t} \right) \tag{19.51}$$

where $\epsilon > 0$ is a small learning rate. (Note that this gradient is with respect to the input pixels, not the model parameters.) Figure 19.23 gives an example of this process.
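A minimal sketch of the FGS updates (19.50) and (19.51), assuming a toy two-feature logistic regression model stands in for the image classifier, so that the input gradient of $\log p(y'|x)$ has the closed form $(1 - p)w$ for target 1 and $-p\,w$ for target 0. With `steps=1` this reduces to a single gradient sign step; the helper names and numbers are hypothetical.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def prob_y1(x, w, b):
    """p(y = 1 | x) for a toy logistic regression classifier (a stand-in for f)."""
    return sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)


def grad_x_log_p(x, w, b, target):
    """Gradient of log p(y = target | x) with respect to the INPUT x (not w, b)."""
    p1 = prob_y1(x, w, b)
    coef = (1.0 - p1) if target == 1 else -p1    # d/dz of log p(target | z)
    return [coef * wi for wi in w]


def fgs_attack(x, w, b, target, eps, steps):
    """Targeted fast gradient sign updates, Eqs. (19.50)-(19.51)."""
    x = list(x)
    for _ in range(steps):
        g = grad_x_log_p(x, w, b, target)
        x = [xi + eps * ((gi > 0) - (gi < 0)) for xi, gi in zip(x, g)]  # eps * sign(g)
    return x


w, b = [2.0, -1.0], 0.0
x = [0.5, 0.5]                                   # p(y = 1 | x) = sigmoid(0.5), about 0.62
x_adv = fgs_attack(x, w, b, target=0, eps=0.1, steps=3)
print(x_adv, prob_y1(x_adv, w, b))
```

Three steps of size ε = 0.1 move x from (0.5, 0.5) to (0.2, 0.8), and p(y = 1|x) falls from about 0.62 to about 0.40, so the predicted label flips to the target.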


More recently, [Mad+18] proposed the more powerful projected gradient descent (PGD) attack; this can be thought of as an iterated version of FGS. There is no “best” variant of PGD for solving Equation (19.48). Instead, what matters more is the implementation details, e.g., how many steps are used, the step size, and the exact form of the loss. To avoid local optima, we may use random restarts, choosing random points in the constraint set $\mathcal{A}$ to initialize the optimization. The algorithm should be carefully tuned to the specific problem, and the loss should be monitored to check for optimization issues. For best practices, see [Car+19].
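The following sketch shows the ingredients just listed for the untargeted $\ell_\infty$ case: random restarts inside $\mathcal{A}(x)$, signed-gradient ascent steps of size α on the loss of Equation (19.48), and projection back onto the ε-ball, which for $\ell_\infty$ is coordinate-wise clipping. The toy logistic model and all parameter values are assumptions for illustration; an attack on a deep network would obtain the input gradient by automatic differentiation instead of the closed form used here.

```python
import math
import random


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def loss_and_grad_x(x, y, w, b):
    """Untargeted loss -log p(y | x) of a toy logistic regression, and its gradient wrt x."""
    p1 = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)
    loss = -math.log(p1 if y == 1 else 1.0 - p1)
    return loss, [(p1 - y) * wi for wi in w]     # d loss / dz = p1 - y for y in {0, 1}


def pgd_linf(x, y, w, b, eps, alpha, steps, restarts, seed=0):
    """l_inf PGD with random restarts: signed-gradient ASCENT steps on the loss of
    Eq. (19.48), each followed by projection onto A(x) = {x' : ||x' - x||_inf <= eps}."""
    rng = random.Random(seed)
    best_x, best_loss = list(x), loss_and_grad_x(x, y, w, b)[0]
    for _ in range(restarts):
        # random initialization inside the constraint set A(x)
        xa = [xi + rng.uniform(-eps, eps) for xi in x]
        for _ in range(steps):
            _, g = loss_and_grad_x(xa, y, w, b)
            xa = [v + alpha * ((gi > 0) - (gi < 0)) for v, gi in zip(xa, g)]
            # projection onto the l_inf ball is coordinate-wise clipping
            xa = [min(max(v, xi - eps), xi + eps) for v, xi in zip(xa, x)]
        loss = loss_and_grad_x(xa, y, w, b)[0]
        if loss > best_loss:                     # keep the strongest restart
            best_x, best_loss = xa, loss
    return best_x, best_loss


w, b = [2.0, -1.0], 0.0
x, y = [0.5, 0.5], 1
x_adv, adv_loss = pgd_linf(x, y, w, b, eps=0.3, alpha=0.05, steps=20, restarts=3)
print(x_adv, adv_loss)
```

Because this model is linear, every restart ends at the same corner of the ε-box, (0.2, 0.8); for a deep network different restarts can end in different local optima, which is why the strongest one is kept.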
Figure 19.24: Images that look like random noise but which cause the CNN to confidently predict a specific class. From Figure 1 of [NYC15]. Used with kind permission of Jeff Clune.

Figure 19.25: Synthetic images that cause the CNN to confidently predict a specific class. From Figure 1 of [NYC15]. Used with kind permission of Jeff Clune.


19.8.2 Blackbox (gradient-free) attacks 


In this section, we no longer assume that the adversary knows the parameters $\theta$ of the predictive model $f$. This is known as a blackbox attack. In such cases, we must use derivative-free optimization (DFO) methods (see Section 6.9).

Evolutionary algorithms (EA) are one class of DFO solvers. These were used in [NYC15] to create blackbox attacks. Figure 19.24 shows some images that were generated by applying an EA to a random noise image. These are known as fooling images, as opposed to adversarial images, since they are not visually realistic. Figure 19.25 shows some fooling images that were generated by applying an EA to the parameters of a compositional pattern-producing network (CPPN) [Sta07].⁶ By suitably perturbing the CPPN parameters, it is possible to generate structured images with high fitness (classifier score), but which do not look like natural images [Aue12].

In [SVK19], they used differential evolution to attack images by modifying a single pixel. This is equivalent to bounding the $\ell_0$ norm of the perturbation, so that $\|x_{\text{adv}} - x\|_0 = 1$.

In [Pap+17], they learned a differentiable surrogate model of the blackbox, by just querying its predictions $y$ for different inputs $x$. They then used gradient-based methods to generate adversarial attacks on their surrogate model, and showed that these attacks transferred to the real model. In this way, they were able to attack the image classification APIs of various cloud service providers, including Google, Amazon and MetaMind.

6. A CPPN is a set of elementary functions (such as linear, sine, sigmoid, and Gaussian) which can be composed in order to specify the mapping from each coordinate to the desired color value. CPPNs were originally developed as a way to encode abstract properties such as symmetry and repetition, which are often seen during biological development.

Figure 19.26: An adversarially modified image to evade spam detectors. The image is constructed from scratch, and does not involve applying a small perturbation to any given image. This is an illustrative example of how large the space of possible adversarial inputs $\mathcal{A}$ can be when the attacker has full control over the input. From [Big+11]. Used with kind permission of Battista Biggio.


19.8.3 Real world adversarial attacks 


Typically, the space of possible adversarial inputs $\mathcal{A}$ can be quite large, and will be difficult to define exactly in mathematical terms, as it depends on the semantics of the input and on the attacker’s goals [BR18]. (The set of variations $\mathcal{A}$ that we want the model to be invariant to is called the threat model.)

Consider for example the content-constrained threat model discussed in [Gil+18a]. One instance of this threat model involves image spam, where the attacker wishes to upload an image attachment in an email that will not be classified as spam by a detection model. In this case $\mathcal{A}$ is incredibly large, as it consists of all possible images which contain some semantic concept the attacker wishes to upload (in this case an advertisement). To explore $\mathcal{A}$, spammers can use different fonts or word orientations, or add random objects to the background, as is the case for the adversarial example in Figure 19.26 (see [Big+11] for more examples). Of course, optimization based methods may still be used here to explore parts of $\mathcal{A}$. However, in practice it may be preferable to design an adversarial input by hand, as this can be significantly easier to execute with only limited-query black-box access to the underlying classifier.


19.8.4 Defenses based on robust optimization 


As discussed in Section 19.8.3, securing a system against adversarial inputs in more general threat models seems extraordinarily difficult, due to the vast space of possible adversarial inputs $\mathcal{A}$. However, there is a line of research focused on producing models which are invariant to perturbations within a small constraint set $\mathcal{A}(x)$, with a focus on $\ell_p$-robustness where $\mathcal{A}(x) = \{x' : \|x - x'\|_p \le \epsilon\}$. Although solving this toy threat model has little application to security settings, enforcing smoothness priors has in some cases improved robustness to random image corruptions [SHS], led to models which transfer better [Sal+20], and can bias models towards different features in the data [Yin+19a].

Perhaps the most straightforward method for improving $\ell_p$-robustness is to directly optimize for it through robust optimization [BTEGN09], also known as adversarial training [GSS15]. Adversarial training minimizes the adversarial risk:

$$\min_{\theta} \; \mathbb{E}_{(x,y) \sim p(x,y)} \left[ \max_{x' \in \mathcal{A}(x)} \mathcal{L}(x', y; \theta) \right] \tag{19.52}$$

The min-max formulation in Equation (19.52) poses unique challenges from an optimization perspective: it requires solving both the non-concave inner maximization and the non-convex outer minimization problems. Even worse, the inner max is NP-hard to solve in general [Kat+17]. However, in practice it may be sufficient to compute the gradient of the outer objective, $\nabla_\theta \mathcal{L}(x_{\text{adv}}, y; \theta)$, at an approximately maximal point of the inner problem, $x_{\text{adv}} \approx \arg\max_{x'} \mathcal{L}(x', y; \theta)$ [Mad+18]. Currently, best practice is to approximate the inner problem using a few steps of PGD.
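Here is a minimal sketch of adversarial training (Equation 19.52) for a logistic regression model under the $\ell_\infty$ threat model. The inner maximization is approximated by one signed-gradient (FGS) step; for a linear model this single step happens to solve the inner problem exactly, because the loss is monotone in the margin. The outer minimization is plain SGD on the loss at the adversarial point. The data, learning rate and epoch count are illustrative assumptions, not recommendations.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def predict(w, b, x):
    """p(y = 1 | x) under a logistic regression model."""
    return sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)


def fgsm(x, y, w, b, eps):
    """One signed-gradient step that increases -log p(y | x) within the l_inf ball."""
    p1 = predict(w, b, x)
    g = [(p1 - y) * wi for wi in w]                 # gradient of -log p(y|x) wrt x
    return [xi + eps * ((gi > 0) - (gi < 0)) for xi, gi in zip(x, g)]


def adversarial_risk(data, w, b, eps):
    """Empirical version of the expectation in Eq. (19.52), at the FGS point."""
    total = 0.0
    for x, y in data:
        p1 = predict(w, b, fgsm(x, y, w, b, eps))
        total += -math.log(p1 if y == 1 else 1.0 - p1)
    return total / len(data)


def adversarial_training(data, eps, lr=0.5, epochs=200):
    w, b = [0.0] * len(data[0][0]), 0.0
    for _ in range(epochs):
        for x, y in data:
            x_adv = fgsm(x, y, w, b, eps)           # inner (approximate) maximization
            p1 = predict(w, b, x_adv)
            # outer minimization: SGD step on -log p(y | x_adv) wrt (w, b)
            w = [wi - lr * (p1 - y) * xi for wi, xi in zip(w, x_adv)]
            b -= lr * (p1 - y)
    return w, b


# Hypothetical, linearly separable toy data with labels in {0, 1}.
data = [([1.0, 1.0], 1), ([2.0, 0.5], 1), ([-1.0, -1.0], 0), ([-2.0, -0.5], 0)]
eps = 0.3
w, b = adversarial_training(data, eps)
print(adversarial_risk(data, [0.0, 0.0], 0.0, eps), adversarial_risk(data, w, b, eps))
```

The adversarial risk starts at log 2 (with zero weights the attack has no gradient to follow) and ends well below it, with every FGS-perturbed training point still classified correctly.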

Other methods seek to certify that a model is robust within a given region $\mathcal{A}(x)$. One method for certification uses randomized smoothing [CRK19], a technique for converting a model robust to random noise into a model which is provably robust to bounded worst-case perturbations in the $\ell_2$-metric. Another class of methods applies specifically to networks with ReLU activations, leveraging the property that the model is locally linear, and that certifying robustness in a region defined by linear constraints reduces to solving a series of linear programs, for which standard solvers can be applied [WK18].
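As a sketch of randomized smoothing, the code below estimates the top class of the smoothed classifier by Monte Carlo and converts a top-class probability $p_A > 1/2$ into a certified $\ell_2$ radius $\sigma \Phi^{-1}(p_A)$. This is the simplified form of the [CRK19] certificate obtained by bounding the runner-up probability by $1 - p_A$. The toy base classifier is an assumption, and a rigorous certificate would plug in a high-confidence lower bound on $p_A$ rather than the point estimate used here.

```python
import random
from statistics import NormalDist


def certified_radius(p_a_lower, sigma):
    """l2 radius within which the smoothed classifier's prediction cannot change,
    given a lower bound p_a_lower on the top-class probability under N(0, sigma^2 I)
    noise. Simplified form of the [CRK19] bound that takes p_B = 1 - p_A."""
    if p_a_lower <= 0.5:
        return 0.0                      # cannot certify; the method would abstain
    return sigma * NormalDist().inv_cdf(p_a_lower)


def estimate_top_class(f, x, sigma, n, seed=0):
    """Monte Carlo estimate of the smoothed classifier's top class and p_A.
    NOTE: a real certificate needs a high-confidence LOWER bound on p_A
    (e.g., from a binomial confidence interval), not this point estimate."""
    rng = random.Random(seed)
    votes = {}
    for _ in range(n):
        c = f([xi + rng.gauss(0.0, sigma) for xi in x])
        votes[c] = votes.get(c, 0) + 1
    top = max(votes, key=votes.get)
    return top, votes[top] / n


def f(v):
    """Toy base classifier with decision boundary at v[0] = 0."""
    return int(v[0] > 0.0)


top, p_a = estimate_top_class(f, [1.0, 0.0], sigma=0.5, n=2000)
radius = certified_radius(p_a, sigma=0.5)
print(top, p_a, radius)
```

For this linear base classifier the exact value is $p_A = \Phi(2)$, so the radius is $\sigma \cdot 2 = 1.0$, which is exactly the distance from x to the decision boundary; the Monte Carlo estimate lands close to it.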


19.8.5 Why models have adversarial examples 


The existence of adversarial inputs is paradoxical, since modern classifiers seem to do so well on normal inputs. However, the existence of adversarial examples is a natural consequence of the general lack of robustness to distribution shift discussed in Section 19.2. To see this, suppose a model’s accuracy drops on some shifted distribution of inputs $p_{te}(x)$ that differs from the training distribution $p_{tr}(x)$; in this case, the model will necessarily be vulnerable to an adversarial attack: if errors exist, there must be a nearest such error. Furthermore, if the input distribution is high dimensional, then we should expect the nearest error to be significantly closer than errors which are sampled randomly from some out-of-distribution $p_{te}(x)$.

A cartoon illustration of what is going on is shown in Figure 19.27a, where $x_0$ is the clean input image, B is an image corrupted by Gaussian noise, and A is an adversarial image. If we assume a linear decision boundary, then the error set $E$ is a half space a certain distance from $x_0$. We can relate the distance to the decision boundary, $d(x_0, E)$, to the error rate in noise at the input $x_0$, denoted by $\mu = P_{\delta \sim \mathcal{N}(0, \sigma^2 I)}[x_0 + \delta \in E]$. With a linear decision boundary, the relationship between these two quantities is given by

$$d(x_0, E) = -\sigma \Phi^{-1}(\mu) \tag{19.53}$$

where $\Phi^{-1}$ denotes the inverse cdf of the standard Gaussian distribution. When the input dimension $d$ is large, this distance will be significantly smaller than the distance to a randomly sampled noisy image $x_0 + \delta$ for $\delta \sim \mathcal{N}(0, \sigma^2 I)$, as the noise term will with high probability have norm $\|\delta\|_2 \approx \sigma \sqrt{d}$. As a concrete example, consider the ImageNet dataset, where $d = 224 \times 224 \times 3$, and suppose we set $\sigma = 0.2$.
Figure 19.27: (a) When the input dimension $n$ is large and the decision boundary is locally linear, even a small error rate in random noise will imply the existence of small adversarial perturbations. Here, $d(x_0, E)$ denotes the distance from a clean input $x_0$ to an adversarial example (A), while the distance from $x_0$ to a random noisy image $x_0 + \delta$ with $\delta \sim \mathcal{N}(0, \sigma^2 I)$ (B) will be approximately $\sigma\sqrt{n}$. As $n \to \infty$, the ratio of $d(x_0, A)$ to $d(x_0, B)$ goes to 0. (b) A 2d slice of the Inception V3 decision boundary through three points: a clean image (black), an adversarial example (red), and an error in random noise (blue). The adversarial example and the error in noise lie in the same region of the error set, which is misclassified as “miniature poodle” and closely resembles a halfspace as in Figure (a). Used with kind permission of Justin Gilmer.


Then if the error rate in noise is just $\mu = 0.01$, Equation (19.53) implies that $d(x_0, E) \approx 0.5$. Thus the nearest adversarial example will be more than 100 times closer than a typical noisy image, whose distance is $\sigma\sqrt{d} \approx 77.6$. This phenomenon of small volume error sets being close to most points in a data distribution $p(x)$ is called concentration of measure, and is a property common among many high dimensional data distributions [MDM19; Gil+18b].
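The arithmetic in this example can be checked directly. The sketch below evaluates Equation (19.53) with σ = 0.2 and μ = 0.01, compares the result with σ√d for an ImageNet-sized input, and confirms the typical noise norm with one Monte Carlo draw; only the function name is an addition.

```python
import math
import random
from statistics import NormalDist


def distance_to_error_set(mu, sigma):
    """Eq. (19.53): distance from x0 to a half-space error set E, given the
    error rate mu = P[x0 + delta in E] for delta ~ N(0, sigma^2 I)."""
    return -sigma * NormalDist().inv_cdf(mu)


d = 224 * 224 * 3                               # ImageNet input dimension
sigma, mu = 0.2, 0.01
adv_dist = distance_to_error_set(mu, sigma)     # distance to the nearest error
noise_dist = sigma * math.sqrt(d)               # typical ||delta||_2 of the noise
print(f"{adv_dist:.3f} {noise_dist:.1f} {noise_dist / adv_dist:.0f}")  # 0.465 77.6 167

# Monte Carlo check: the norm of one Gaussian noise vector is close to sigma * sqrt(d).
rng = random.Random(0)
sampled_norm = math.sqrt(sum(rng.gauss(0.0, sigma) ** 2 for _ in range(d)))
print(sampled_norm)
```

The ratio of about 167 is the "more than 100 times closer" quoted in the text.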

In summary, although the existence of adversarial examples is often discussed as an unexpected 
phenomenon, there is nothing special about the existence of worst-case errors for ML classifiers—they 
will always exist as long as errors exist. 
PART IV


Generation 


20 Generative models: an overview 


20.1 Introduction 


A generative model is a joint probability distribution $p(x)$, for $x \in \mathcal{X}$. In some cases, the model may be conditioned on inputs or covariates $c \in \mathcal{C}$, which gives rise to a conditional generative model of the form $p(x|c)$.

There are many kinds of generative model. We give a brief summary in Section 20.2, and go into 
more detail in subsequent chapters. See also [Tom22] for a recent book on this topic that goes into 
more depth. 


20.2 Types of generative model 


There are many kinds of generative model, some of which we list in Table 20.1. At a high level, we can distinguish between deep generative models (DGM), which use deep neural networks to learn a complex mapping from a single latent vector $z$ to the observed data $x$, and more “classical” probabilistic graphical models (PGM), which map a set of interconnected latent variables $z_1, \ldots, z_L$ to the observed variables $x_1, \ldots, x_D$ using simpler, often linear, mappings. Of course, many hybrids are possible. For example, PGMs can use neural networks, and DGMs can use structured state spaces. We discuss PGMs in general terms in Chapter 4, and give examples in Chapter 28, Chapter 29, Chapter 30. In this part of the book, we mostly focus on DGMs.

The main kinds of DGM are: variational autoencoders (VAE), autoregressive (AR) models, 
normalizing flows, diffusion models, energy based models (EBM), and generative adver- 
sarial networks (GAN). We can categorize these models in terms of the following criteria (see 
Figure 20.1 for a visual summary): 


• Density: does the model support pointwise evaluation of the probability density function p(x), and if so, is this fast or slow, exact, approximate or a bound, etc? For implicit models, such as GANs, there is no well-defined density p(x). For other models, we can only compute a lower bound on the density (VAEs), or an approximation to the density (EBMs, UPGMs).


• Sampling: does the model support generating new samples, x ~ p(x), and if so, is this fast or slow, exact or approximate? Directed PGMs, VAEs and GANs all support fast sampling. However, undirected PGMs, EBMs, AR, diffusion and flows are slow for sampling.


• Training: what kind of method is used for parameter estimation? For some models (such as AR, flows and directed PGMs), we can perform exact maximum likelihood estimation (MLE), although the objective is usually non-convex, so we can only reach a local optimum. For other models, we cannot tractably compute the likelihood. In the case of VAEs, we maximize a lower bound on the likelihood; in the case of EBMs and UGMs, we maximize an approximation to the likelihood. For GANs we have to use min-max training, which can be unstable, and there is no clear objective function to monitor.

Model      Chapter       Density           Sampling  Training  Latents   Architecture
PGM-D      Section 4.2   Exact, fast       Fast      MLE       Optional  Sparse DAG
PGM-U      Section 4.3   Approx, slow      Slow      MLE-A     Optional  Sparse graph
VAE        Chapter 21    LB, fast          Fast      MLE-LB    R^L       Encoder-Decoder
AR         Chapter 22    Exact, fast       Slow      MLE       None      Sequential
Flows      Chapter 23    Exact, slow/fast  Slow      MLE       R^D       Invertible
EBM        Chapter 24    Approx, slow      Slow      MLE-A     Optional  Discriminative
Diffusion  Chapter 25    LB                Slow      MLE-LB    R^D       Encoder-Decoder
GAN        Chapter 26    NA                Fast      Min-max   R^L       Generator-Discriminator

Table 20.1: Characteristics of common kinds of generative model. Here D is the dimensionality of the observed x, and L is the dimensionality of the latent z, if present. (We usually assume L ≪ D, although overcomplete representations can have L ≫ D.) Abbreviations: Approx = approximate, AR = autoregressive, EBM = energy based model, GAN = generative adversarial network, MLE = maximum likelihood estimation, MLE-A = MLE (approximate), MLE-LB = MLE (lower bound), NA = not available, PGM = probabilistic graphical model, PGM-D = directed PGM, PGM-U = undirected PGM, VAE = variational autoencoder.

[Figure 20.1 panels: EBM: approximate maximum likelihood. GAN: adversarial training, with generator G(z). VAE: maximize variational lower bound, with encoder q_φ(z|x) and decoder p_θ(x|z). Flow-based model: invertible transform of distributions. Diffusion model: gradually add Gaussian noise and then reverse. Autoregressive model: learn conditional of each variable given past.]

Figure 20.1: Summary of various kinds of deep generative model. Here x is the observed data, z is the latent code, and x' is a sample from the model. AR models do not have a latent code z. For diffusion models and flow models, the size of z is the same as x. For AR models, x^d is the d'th dimension of x. R represents real-valued output, 0/1 represents binary output. Adapted from Figure 1 of [Wen21].

Figure 20.2: Synthetic faces from a score-based generative model (Section 24.3.5). From Figure 12 of [Son+21]. Used with kind permission of Yang Song.


• Latents: does the model use a latent vector z to generate x or not, and if so, is it the same size as x or is it a potentially compressed representation? For example, AR models do not use latents; flows and diffusion use latents, but they are not compressed.¹ Graphical models, including EBMs, may or may not use latents.


• Architecture: what kind of neural network should we use, and are there restrictions? For flows, we are restricted to using invertible neural networks where each layer has a tractable Jacobian. For EBMs, we can use any model we like. The other models have different restrictions.


20.3 Goals of generative modeling 


There are several different kinds of tasks that we can use generative models for, as we discuss below. 


20.3.1 Generating data 


One of the main goals of generative models is to generate (create) new data samples. For example, if 
we fit a model p(x) to images of faces, we can sample new faces from it, as illustrated in Figure 20.2.²


1. Flow models define a latent vector z that has the same size as x, although the internal deterministic computation 
may use vectors that are larger or smaller than the input (see e.g., the DenseFlow paper [GGS21]). 

2. These images were made with a technique called score-based generative modeling (Section 24.3.5), although similar 
results can be obtained using many other techniques. See for example https://this-person-does-not-exist.com/en
which shows results from a GAN model (Chapter 26). 
[Figure 20.3 panels: (a) Teddy bears swimming at the Olympics 400m Butterfly event. (b) A cute corgi lives in a house made out of sushi. (c) A cute sloth holding a small treasure chest. A bright golden glow is coming from the chest.]

Figure 20.3: Some 1024 × 1024 images generated from text prompts by the Imagen diffusion model (Section 25.2.8). From Figure 1 of [Sah+22]. Used with kind permission of William Chan.


Figure 20.4: Converting a gray-scale image (left) to a colorized one (middle) using a conditional GAN (Chapter 26). Right column shows the original image. (Photo depicts author’s wife, Margaret Murphy, photo taken July 2022.) Generated by https://github.com/jantic/DeOldify.


Similar methods can be used to create samples of text, audio, etc. When this technology is abused to make fake content, the results are called deep fakes. For a review of this topic, see e.g., [Ngu+19].

To control what is generated, it is useful to use a conditional generative model of the form p(x|c). Here are some examples:

• c = text prompt, x = image. This is a text-to-image model (see Figure 20.3 and Figure 22.8 for examples).

• c = image, x = text. This is an image-to-text model, which is useful for image captioning.

• c = image, x = image. This is an image-to-image model (see Figure 20.4 and Figure 25.4 for examples).

• c = sequence of sounds, x = sequence of words. This is a speech-to-text model, which is useful for automatic speech recognition (ASR).

• c = sequence of English words, x = sequence of French words. This is a sequence-to-sequence model, which is useful for machine translation.

• c = initial prompt, x = continuation of the text. This is another sequence-to-sequence model, which is useful for automatic text generation (see Figure 22.4 for an example).

Figure 20.5: A nonparametric (Parzen) density estimator in 1d estimated from 6 data points, denoted by x. Top row: uniform kernel. Bottom row: Gaussian kernel. Left column: bandwidth parameter h = 1. Right column: bandwidth parameter h = 2. Adapted from http://en.wikipedia.org/wiki/Kernel_density_estimation. Generated by parzen_window_demo.ipynb.


Note that, in the conditional case, we sometimes denote the inputs by x and the outputs by y. In this case the model has the familiar form p(y|x). In the special case that y denotes a low dimensional quantity, such as an integer class label, $y \in \{1, \ldots, C\}$, we get a predictive (discriminative) model. The main difference between a discriminative model and a conditional generative model is this: in a discriminative model, we assume there is one correct output, whereas in a conditional generative model, we assume there may be multiple correct outputs. This makes it harder to evaluate generative models, as we discuss in Section 20.4.


20.3.2 Density estimation 


The task of density estimation refers to evaluating the probability of an observed data vector, i.e., computing p(x). This can be useful for outlier detection (Section 19.3.2), data compression (Section 5.4), generative classifiers, model comparison, etc.
                 Missing values      Variables replaced by means
Data sample      A     B     C       A     B     C
1                6     6     NA      6     6     7.5
2                NA    6     0       9     6     0
3                NA    6     NA      9     6     7.5
4                10    10    10      10    10    10
5                10    10    10      10    10    10
6                10    10    10      10    10    10
Average          9     8     7.5     9     8     7.5

Figure 20.6: Missing data imputation using the mean of each column.


A simple approach to this problem, which works in low dimensions, is to use kernel density estimation or KDE, which has the form

$$p(x|\mathcal{D}) = \frac{1}{N} \sum_{n=1}^{N} K_h(x - x_n) \tag{20.1}$$

Here $\mathcal{D} = \{x_1, \ldots, x_N\}$ is the data, and $K_h$ is a density kernel with bandwidth $h$, which is a function $K : \mathbb{R} \to \mathbb{R}_+$ such that $\int K(x)\,dx = 1$ and $\int x K(x)\,dx = 0$. We give a 1d example of this in Figure 20.5: in the top row, we use a uniform (boxcar) kernel, and in the bottom row we use a Gaussian kernel.
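A minimal 1d sketch of Equation (20.1), assuming the common convention $K_h(u) = K(u/h)/h$, with the boxcar and Gaussian kernels used in Figure 20.5. The six data points are hypothetical rather than taken from the figure, and the loop checks numerically that each estimate integrates to about 1.

```python
import math


def gaussian_kernel(u):
    return math.exp(-0.5 * u * u) / math.sqrt(2.0 * math.pi)


def boxcar_kernel(u):
    return 0.5 if abs(u) <= 1.0 else 0.0          # uniform density on [-1, 1]


def kde(x, data, h, kernel=gaussian_kernel):
    """1d kernel density estimate, Eq. (20.1), with K_h(u) = K(u / h) / h."""
    return sum(kernel((x - xn) / h) for xn in data) / (len(data) * h)


data = [-2.1, -1.3, -0.4, 1.9, 5.1, 6.2]          # six hypothetical points
grid = [i * 0.01 for i in range(-1500, 2000)]     # x from -15 to 20, step 0.01
for kernel in (boxcar_kernel, gaussian_kernel):
    for h in (1.0, 2.0):
        mass = sum(kde(g, data, h, kernel) for g in grid) * 0.01
        print(kernel.__name__, h, round(mass, 3))  # each mass is close to 1
```

Larger h gives smoother but more biased estimates; with the boxcar kernel the estimate is piecewise constant, matching the top row of Figure 20.5 qualitatively.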


In higher dimensions, KDE suffers from the curse of dimensionality (see e.g., [AHK01]), and we need to use parametric density models $p_\theta(x)$ of some kind.


20.3.3 Imputation


The task of imputation refers to “filling in” missing values of a data vector or data matrix. For
example, suppose $X$ is an $N \times D$ matrix of data (think of a spreadsheet) in which some entries, call
them $X_m$, may be missing, while the rest, $X_o$, are observed. A simple way to fill in the missing
data is to use the mean value of each feature, $\mathbb{E}[x_d]$; this is called mean value imputation, and is
illustrated in Figure 20.6. However, this ignores dependencies between the variables within each row,
and does not return any measure of uncertainty.
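A minimal sketch of mean value imputation, assuming the data are held in a NumPy array with `np.nan` marking missing entries (our convention for illustration). Using the data of Figure 20.6, it recovers the column averages 9, 8 and 7.5.

```python
import numpy as np

# Data from Figure 20.6 (columns A, B, C); np.nan marks a missing entry.
X = np.array([
    [6.0, 6.0, np.nan],
    [np.nan, 6.0, 0.0],
    [np.nan, 6.0, np.nan],
    [10.0, 10.0, 10.0],
    [10.0, 10.0, 10.0],
    [10.0, 10.0, 10.0],
])

def mean_impute(X):
    """Replace each missing entry by the mean of the observed values in its column."""
    col_means = np.nanmean(X, axis=0)                # per-feature mean over observed entries only
    X_filled = np.where(np.isnan(X), col_means[None, :], X)
    return X_filled, col_means

X_filled, col_means = mean_impute(X)                 # column means are 9, 8 and 7.5
```

Every missing value in a column receives the same number, regardless of the other entries in its row, which is exactly the limitation noted above.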


We can generalize this by fitting a generative model to the observed data, $p(X_o)$, and then
computing samples from $p(X_m|X_o)$. This is called multiple imputation. A generative model can
be used to fill in more complex data types, such as in-painting occluded pixels in an image (see
Figure 25.4).
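As a hedged illustration of multiple imputation, suppose the generative model is a single multivariate Gaussian $\mathcal{N}(\mu, \Sigma)$ whose parameters have already been fitted (fitting them when values are missing would itself need something like EM, which is not shown). Sampling from $p(X_m|X_o)$ then reduces to the standard Gaussian conditioning formulas; the helper name `impute_gaussian` is hypothetical.

```python
import numpy as np

def impute_gaussian(x, mu, Sigma, n_samples, rng):
    """Draw n_samples imputations of the NaN entries of x from p(x_m | x_o) under N(mu, Sigma)."""
    m = np.isnan(x)                                   # missing coordinates
    o = ~m                                            # observed coordinates
    S_oo = Sigma[np.ix_(o, o)]
    S_mo = Sigma[np.ix_(m, o)]
    S_mm = Sigma[np.ix_(m, m)]
    gain = np.linalg.solve(S_oo, S_mo.T).T            # S_mo S_oo^{-1}
    mu_cond = mu[m] + gain @ (x[o] - mu[o])           # conditional mean
    S_cond = S_mm - gain @ S_mo.T                     # conditional covariance
    draws = rng.multivariate_normal(mu_cond, S_cond, size=n_samples)
    filled = np.tile(x, (n_samples, 1))               # observed entries are copied unchanged
    filled[:, m] = draws
    return filled, mu_cond, S_cond

rng = np.random.default_rng(0)
mu = np.array([0.0, 0.0])
Sigma = np.array([[1.0, 0.8], [0.8, 1.0]])
x = np.array([1.0, np.nan])                           # second feature is missing
filled, mu_c, S_c = impute_gaussian(x, mu, Sigma, n_samples=5, rng=rng)
```

Here the conditional mean is 0.8 and the conditional variance is 0.36; the spread across the 5 imputations is what provides the measure of uncertainty that mean imputation lacks.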


See Section 21.3.4 for a more general discussion of missing data. 
Figure 20.7: (a) A design matrix (the Sachs multiparameter flow cytometry dataset) consisting of 5400 data points (rows) measuring the state (using flow
cytometry) of 11 proteins (columns) under different experimental conditions. The data has been discretized
into 3 states: low (black), medium (grey) and high (white). Some proteins were explicitly controlled using
activating or inhibiting chemicals. (b) A directed graphical model representing dependencies between various
proteins (blue circles) and various experimental interventions (pink ovals), which was inferred from this data.
We plot all edges for which $p(G_{ij} = 1|\mathcal{D}) > 0.5$. Dotted edges are believed to exist in nature but were not
discovered by the algorithm (1 false negative). Solid edges are true positives. The light colored edges represent
the effects of intervention. From Figure 6d of [EM07].


20.3.4 Structure discovery 


Some kinds of generative models have latent variables $z$, which are assumed to be the “causes”
that generated the observed data $x$. We can use Bayes rule to invert the model to compute
$p(z|x) \propto p(z)p(x|z)$. This can be useful for discovering latent, low-dimensional patterns in the data.

For example, suppose we perturb various proteins in a cell and measure the resulting phosphorylation
state using a technique known as flow cytometry, as in [Sac+05]. An example of such a dataset is
shown in Figure 20.7(a). Each row represents a data sample $x_n \sim p(\cdot|a_n, z)$, where $x \in \mathbb{R}^{11}$ is a
vector of outputs (phosphorylations), $a \in \{0,1\}^6$ is a vector of input actions (perturbations) and $z$ is
the unknown cellular signaling network structure. We can infer the graph structure $p(z|\mathcal{D})$ using
graphical model structure learning techniques (see Section 30.3). In particular, we can use the
dynamic programming method described in [EM07] to get the result shown in Figure 20.7(b). Here
we plot the median graph, which includes all edges for which $p(z_{ij} = 1|\mathcal{D}) > 0.5$. (For a more recent
approach to this problem, see e.g., [Bro+20b].)
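[EM07] computes exact edge marginals with dynamic programming. As a simpler, hypothetical illustration of the thresholding step only, suppose we had $S$ sampled adjacency matrices from $p(z|\mathcal{D})$ (for example from an MCMC sampler, not shown). Averaging them gives Monte Carlo estimates of $p(z_{ij} = 1|\mathcal{D})$, and keeping the edges whose estimate exceeds 0.5 gives the median graph.

```python
import numpy as np

def median_graph(graph_samples, threshold=0.5):
    """Estimate p(z_ij = 1 | D) by averaging sampled adjacency matrices, then keep edges above threshold."""
    samples = np.asarray(graph_samples, dtype=float)     # shape (S, d, d), entries in {0, 1}
    edge_marginals = samples.mean(axis=0)                # Monte Carlo estimate of p(z_ij = 1 | D)
    return edge_marginals, (edge_marginals > threshold).astype(int)

# Three hypothetical posterior samples over 3 nodes (entry [i, j] = 1 means edge i -> j).
samples = [
    [[0, 1, 0], [0, 0, 1], [0, 0, 0]],
    [[0, 1, 0], [0, 0, 0], [0, 0, 0]],
    [[0, 1, 1], [0, 0, 1], [0, 0, 0]],
]
marg, G = median_graph(samples)
```

The edge 0 -> 1 appears in every sample and 1 -> 2 in two of three, so both are kept; 0 -> 2 appears in only one sample and is dropped.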


20.3.5 Latent space interpolation 


One of the most interesting abilities of certain latent variable models is the ability to generate
samples that have certain desired properties by interpolating between existing data points in latent
space. To explain how this works, let $x_1$ and $x_2$ be two inputs (e.g. images), and let $z_1 = e(x_1)$ and
$z_2 = e(x_2)$ be their latent encodings. (The method used for computing these will depend on the
type of model; we discuss the details in later chapters.) We can regard $z_1$ and $z_2$ as two “anchors” in
latent space. We can now generate new images that interpolate between these points by computing
$z = \lambda z_1 + (1 - \lambda) z_2$, where $0 \le \lambda \le 1$, and then decoding by computing $x' = d(z)$, where $d()$ is the
decoder. This is called latent space interpolation, and will generate data that combines semantic
features from both $x_1$ and $x_2$. (The justification for taking a linear interpolation is that the learned
manifold often has approximately zero curvature, as shown in [SKTF18]. However, sometimes it is
better to use nonlinear interpolation [Whi16; MB21; Fad+20].)

Figure 20.8: Interpolation between two MNIST images in the latent space of a β-VAE (with β = 0.5).
Generated by mnist_vae_ae_comparison.ipynb.

Figure 20.9: Interpolation between two CelebA images in the latent space of a β-VAE (with β = 0.5).
Generated by celeba_vae_ae_comparison.ipynb.
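A minimal sketch of latent space interpolation, assuming the latent codes $z_1$ and $z_2$ are already available. The linear map below is a hypothetical stand-in for a trained VAE decoder $d()$. We also include spherical linear interpolation (slerp) as one commonly used nonlinear alternative; it is shown for comparison, not as the method of any particular cited paper.

```python
import numpy as np

def lerp(z1, z2, lam):
    """Linear interpolation z = lam * z1 + (1 - lam) * z2, with 0 <= lam <= 1."""
    return lam * z1 + (1.0 - lam) * z2

def slerp(z1, z2, lam, eps=1e-8):
    """Spherical interpolation with the same convention: lam = 1 gives z1, lam = 0 gives z2."""
    u1, u2 = z1 / np.linalg.norm(z1), z2 / np.linalg.norm(z2)
    omega = np.arccos(np.clip(np.dot(u1, u2), -1.0, 1.0))   # angle between the two codes
    if omega < eps:                                          # (nearly) parallel: fall back to lerp
        return lerp(z1, z2, lam)
    return (np.sin(lam * omega) * z1 + np.sin((1.0 - lam) * omega) * z2) / np.sin(omega)

rng = np.random.default_rng(0)
W_dec = rng.normal(size=(784, 8))

def decode(z):
    """Hypothetical linear decoder d(z); a real VAE decoder is a neural network."""
    return W_dec @ z

z1, z2 = rng.normal(size=8), rng.normal(size=8)
lams = np.linspace(1.0, 0.0, 7)                              # from z1 towards z2
frames = np.stack([decode(lerp(z1, z2, lam)) for lam in lams])
```

Each row of `frames` is one decoded point along the path, analogous to one column of Figures 20.8 and 20.9.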

We can see an example of this process in Figure 20.8, where we use a β-VAE model (Section 21.3.1)
fit to the MNIST dataset. We see that the model is able to produce plausible interpolations between
the digit 7 and the digit 2. As a more interesting example, we can fit a β-VAE to the CelebA
dataset [Liu+15].³ The results are shown in Figure 20.9, and look reasonable. (We can get much
better quality if we use a larger model trained on more data for a longer amount of time.)

It is also possible to perform interpolation in the latent space of text models, as illustrated in 
Figure 21.10. 


20.3.6 Latent space arithmetic


In some cases, we can go beyond interpolation, and can perform latent space arithmetic, in which
we can increase or decrease the amount of a desired “semantic factor of variation”. This was first
shown in the word2vec model [Mik+13], but it is also possible in other latent variable models. For
example, consider our VAE model fit to the CelebA dataset, which has faces of celebrities and some
corresponding attributes. Let $X_i^+$ be a set of images which have attribute $i$, and $X_i^-$ be a set of
images which do not have this attribute. Let $Z_i^+$ and $Z_i^-$ be the corresponding embeddings, and $\bar{z}_i^+$
and $\bar{z}_i^-$ be the averages of these embeddings. We define the offset vector as $\Delta_i = \bar{z}_i^+ - \bar{z}_i^-$. If we
add some positive multiple of $\Delta_i$ to a new point $z$, we increase the amount of the attribute $i$; if we
subtract some multiple of $\Delta_i$, we decrease the amount of the attribute $i$ [Whi16].

Figure 20.10: Arithmetic in the latent space of a β-VAE (with β = 0.5). The first column is an input
image, with embedding $z$. Subsequent columns show the decoding of $z + s\Delta$, where $s \in \{-2, -1, 0, 1, 2\}$ and
$\Delta = \bar{z}^+ - \bar{z}^-$ is the difference in the average embeddings of images with or without a certain attribute (here,
wearing sunglasses). Generated by celeba_vae_ae_comparison.ipynb.

3. CelebA contains about 200k images of famous celebrities. The images are also annotated with 40 attributes. We
reduce the resolution of the images to 64x64, as is conventional. See Section 21.2.6 for the details of the model.

We give an example of this in Figure 20.10. We consider the attribute of wearing sunglasses. The
$j$th reconstruction is computed using $\hat{x}_j = d(z + s_j \Delta)$, where $z = e(x)$ is the encoding of the
original image, and $s_j$ is a scale factor. When $s_j > 0$ we add sunglasses to the face. When $s_j < 0$ we
remove sunglasses; but this also has the side effect of making the face look younger and more female,
possibly a result of dataset bias.
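The offset-vector recipe can be sketched as follows. The embeddings here are synthetic (hypothetical), with the attribute placed along the first latent axis; in practice they would be encoder outputs $e(x)$ for labeled images, and each edited code would then be decoded with $d()$.

```python
import numpy as np

def attribute_offset(z_with, z_without):
    """Delta_i = mean embedding of images with attribute i minus mean embedding of those without it."""
    return np.mean(z_with, axis=0) - np.mean(z_without, axis=0)

def edit_latent(z, delta, scales):
    """Return the codes z + s * Delta for each scale factor s (one row per s)."""
    scales = np.asarray(scales, dtype=float)[:, None]
    return z[None, :] + scales * delta[None, :]

rng = np.random.default_rng(0)
z_plus = rng.normal(size=(100, 4)) + np.array([2.0, 0.0, 0.0, 0.0])   # hypothetical "with attribute"
z_minus = rng.normal(size=(100, 4))                                   # hypothetical "without attribute"
delta = attribute_offset(z_plus, z_minus)
edits = edit_latent(np.zeros(4), delta, scales=[-2, -1, 0, 1, 2])    # the scales used in Figure 20.10
```

Row 2 (scale 0) is the unedited code; positive scales move along the attribute direction and negative scales move against it.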


20.3.7 Generative design 


Another interesting use case for (deep) generative models is generative design, in which we use 
the model to generate candidate objects, such as molecules, which have desired properties (see 
e.g., [RNA22]). One approach is to fit a VAE to unlabeled samples, and then to perform Bayesian 
optimization (Section 6.8) in its latent space, as discussed in Section 21.3.6.2. 


20.3.8 Model-based reinforcement learning 


We discuss reinforcement learning (RL) in Chapter 35. The main success stories of RL to date have 
been in computer games, where simulators exist and data is abundant. However, in other areas, such 
as robotics, data is expensive to acquire. In this case, it can be useful to learn a generative “world
model”, so the agent can do planning and learning “in its head”. See Section 35.4 for more details.


20.3.9 Representation learning 


Representation learning refers to learning (possibly uninterpretable) latent factors z that generate 
the observed data x. The primary goal is for these features to be used in “downstream” supervised 
tasks. This is discussed in Chapter 32. 
20.3.10 Data compression


Models which can assign high probability to frequently occurring data vectors (e.g., images, sentences),
and low probability to rare vectors, can be used for data compression, since we can assign shorter
codes to the more common items. Indeed, the optimal coding length for a vector $x$ from some
stochastic source $p(x)$ is $l(x) = -\log p(x)$, as proved by Shannon. See Section 5.4 for details.
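As a small worked example of $l(x) = -\log p(x)$, take a toy source over four symbols whose probabilities are powers of two (our choice), so that the ideal code lengths in bits are integers and their expectation equals the entropy.

```python
import numpy as np

p = np.array([0.5, 0.25, 0.125, 0.125])      # toy source over 4 symbols
code_len_bits = -np.log2(p)                  # ideal code lengths l(x) = -log2 p(x): 1, 2, 3, 3 bits
expected_len = np.sum(p * code_len_bits)     # average code length = entropy H(p) = 1.75 bits
```

The most probable symbol gets the shortest code; using natural logarithms instead of base 2 would measure the same lengths in nats.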


20.4 Evaluating generative models 
This section is coauthored with Mihaela Rosca, Shakir Mohamed and Balaji Lakshminarayanan. 
Evaluating generative models requires metrics which capture

• sample quality: are samples generated by the model a part of the data distribution?

• sample diversity: are samples from the model distribution capturing all modes of the data
distribution?

• generalization: is the model generalizing beyond the training data?


There is no known metric which meets all these requirements, but various metrics have been proposed 
to capture different aspects of the learned distribution, some of which we discuss below. 


20.4.1 Likelihood-based evaluation
A standard way to measure how close a model $q$ is to a true distribution $p$ is in terms of the KL
divergence (Section 5.1):

$$D_{\mathrm{KL}}(p \,\|\, q) = \int p(x) \log \frac{p(x)}{q(x)} \, dx = -\mathbb{H}(p) + \mathbb{H}_{ce}(p, q) \qquad (20.2)$$

where $\mathbb{H}(p)$ is a constant, and $\mathbb{H}_{ce}(p, q)$ is the cross entropy. If we approximate $p(x)$ by the empirical
distribution, we can evaluate the cross entropy in terms of the empirical negative log likelihood
on the dataset:
$$\mathrm{NLL} = -\frac{1}{N} \sum_{n=1}^{N} \log q(x_n) \qquad (20.3)$$

Usually we care about negative log likelihood on a held-out test set.⁴


20.4.1.1 Computing log-likelihood


For models of discrete data, such as language models, it is easy to compute the (negative) log
likelihood. However, it is common to measure performance using a quantity called perplexity, which
is defined as $2^H$, where $H = \mathrm{NLL}$ is the cross entropy or negative log likelihood (measured in bits, i.e., using log base 2).


4. In some applications, we report bits per dimension, which is the NLL using log base 2, divided by the dimensionality
of $x$. (To compute this metric, recall that $\log_2 L = \frac{\log_e L}{\log_e 2}$.)
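The following sketch shows how the NLL of (20.3), bits per dimension, and perplexity are related, assuming we already have per-example (or per-token) log probabilities in nats from some model; the function names are ours. Perplexity is computed per token here, which is the usual convention for language models.

```python
import numpy as np

def nll_nats(log_probs):
    """Average negative log likelihood (20.3), given log q(x_n) in nats for each held-out example."""
    return -np.mean(log_probs)

def bits_per_dim(log_probs, dim):
    """NLL converted to base 2 and divided by the dimensionality D of x."""
    return nll_nats(log_probs) / (np.log(2.0) * dim)

def perplexity(token_log_probs):
    """2^H, where H is the per-token cross entropy in bits."""
    h_bits = -np.mean(token_log_probs) / np.log(2.0)
    return 2.0 ** h_bits

tok_lp = np.log(np.full(10, 0.25))    # a model giving probability 1/4 to each observed token
ppl = perplexity(tok_lp)              # H = 2 bits, so the perplexity is (up to rounding) 4
```

Intuitively, a perplexity of 4 means the model is as uncertain as a uniform choice among 4 tokens.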


For image and audio models, one complication is that the model is usually a continuous distribution
$p(x) \ge 0$ but the data is usually discrete (e.g., $x \in \{0, \ldots, 255\}^D$ if we use one byte per pixel).
Consequently the average log likelihood can be arbitrarily large, since the pdf can be bigger than
1. To avoid this it is standard practice to use uniform dequantization [TOB16], in which we add
uniform random noise to the discrete data, and then treat it as continuous-valued data. This gives a
lower bound on the average log likelihood of the discrete model on the original data.

To see this, let $z$ be a continuous latent variable, and $x$ be a vector of discrete observations computed
by rounding, so $p(x|z) = \delta(x - \mathrm{round}(z))$, computed elementwise. We have $p(x) = \int p(x|z)p(z)\,dz$.
Let $q(z|x)$ be a probabilistic inverse of $x$, that is, it has support only on values where $p(x|z) = 1$. In
this case, Jensen's inequality gives

$$\log p(x) \ge \mathbb{E}_{q(z|x)}[\log p(x|z) + \log p(z) - \log q(z|x)] \qquad (20.4)$$
$$= \mathbb{E}_{q(z|x)}[\log p(z) - \log q(z|x)] \qquad (20.5)$$

where the second line uses $\log p(x|z) = 0$ on the support of $q(z|x)$. Thus if we model the density of $z \sim q(z|x)$, which is a dequantized version of $x$, we will get a lower
bound on $p(x)$.
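A minimal sketch of uniform dequantization for 8-bit data. As an assumption, the continuous model is taken to be a density over $y \in [0,1)^D$ obtained by rescaling $x + u$ by $1/256$ (a common convention, not specified in the text). Because of that rescaling, a continuous log density of $y$ must be corrected by $-D \log 256$ before it bounds the discrete log likelihood; the helper names are hypothetical.

```python
import numpy as np

def uniform_dequantize(x_int, rng, num_levels=256):
    """Map integers in {0, ..., num_levels - 1} to y = (x + u) / num_levels with u ~ U[0, 1)."""
    u = rng.uniform(0.0, 1.0, size=x_int.shape)
    return (x_int + u) / num_levels

def discrete_bits_per_dim(log_density_y, dim, num_levels=256):
    """Upper bound on discrete bits/dim from a continuous log density of y (in nats).

    y = (x + u) / num_levels has Jacobian num_levels**(-dim), so log p(x) >= log p(y) - dim * log(num_levels)."""
    log_px_bound = log_density_y - dim * np.log(num_levels)
    return -log_px_bound / (dim * np.log(2.0))

rng = np.random.default_rng(0)
x = rng.integers(0, 256, size=(10, 12))          # hypothetical 8-bit data, 12 dimensions per example
y = uniform_dequantize(x, rng)
```

As a sanity check, the uniform density on $[0,1)^D$ (log density 0) gives exactly 8 bits per dimension, matching a uniform model over 256 levels.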


20.4.1.2 Challenges with using the likelihood 


Unfortunately, there are several challenges with using the likelihood to evaluate generative models, 
some of which we discuss below. 


20.4.1.3 Likelihood can be hard to compute 


For many models, computing the likelihood can be computationally expensive, since it requires
knowing the normalization constant of the probability model. One solution is to use variational inference
(Chapter 10), which provides a way to efficiently compute lower (and sometimes upper) bounds on 
the log likelihood. Another solution is to use annealed importance sampling (Section 11.5.4.1), which 
provides a way to estimate the log likelihood using Monte Carlo sampling. However, in the case of 
implicit generative models, such as GANs (Chapter 26), the likelihood is not even defined, so we 
need to find evaluation metrics that do not rely on likelihood. 


20.4.1.4 Likelihood is not related to sample quality 


A more subtle concern with likelihood is that it is often uncorrelated with the perceptual quality of 
the samples, at least for real-valued data, such as images and sound. In particular, a model can have 
great log-likelihood but create poor samples and vice versa. 

To see why a model can have good likelihoods but create bad samples, consider the following
argument from [TOB16]. Suppose $q_0$ is a density model for $D$-dimensional data $x$ which performs
arbitrarily well as judged by average log-likelihood, and suppose $q_1$ is a bad model, such as white
noise. Now consider samples generated from the mixture model

$$q_2(x) = 0.01\, q_0(x) + 0.99\, q_1(x) \qquad (20.6)$$

Clearly 99% of the samples will be poor. However, the log-likelihood per pixel will hardly change
between $q_2$ and $q_0$ if $D$ is large, since

$$\log q_2(x) = \log[0.01\, q_0(x) + 0.99\, q_1(x)] \ge \log[0.01\, q_0(x)] = \log q_0(x) - \log 100 \qquad (20.7)$$
For high-dimensional data, $|\log q_0(x)| \sim D \gg \log 100$, so $\log q_2(x) \approx \log q_0(x)$, and hence mixing in the
poor sampler does not significantly impact the log likelihood.
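The following toy computation checks the argument numerically. As an assumption for illustration, $q_0$ is taken to be the data distribution itself, $\mathcal{N}(0, I)$ in $D = 3072$ dimensions (the size of a 32x32 RGB image), and the "bad" model $q_1$ is a very broad Gaussian; the mixture log density is computed stably with `np.logaddexp`.

```python
import numpy as np

def log_gauss_iso(x, mean, std):
    """log N(x | mean, std^2 I) for a D-dimensional vector x."""
    d = x.size
    return -0.5 * np.sum(((x - mean) / std) ** 2) - d * np.log(std) - 0.5 * d * np.log(2 * np.pi)

rng = np.random.default_rng(0)
D = 3072
x = rng.normal(size=D)                                # a test point drawn from q0 = N(0, I)
log_q0 = log_gauss_iso(x, 0.0, 1.0)
log_q1 = log_gauss_iso(x, 0.0, 10.0)                  # q1: poor, very broad "white noise" model
log_q2 = np.logaddexp(np.log(0.01) + log_q0, np.log(0.99) + log_q1)
gap = log_q0 - log_q2                                 # total loss in log likelihood, about log(100)
print(gap / D)                                        # per-dimension loss, about 0.0015 nats
```

Even though 99% of the samples from $q_2$ would be noise, the per-dimension log likelihood barely moves.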

Now consider a case where the model has good samples but bad likelihoods. To achieve this, 
suppose q is a GMM centered on the training images: 


$$q(x) = \frac{1}{N} \sum_{n=1}^{N} \mathcal{N}(x | x_n, \epsilon I) \qquad (20.8)$$

If $\epsilon$ is small enough that the Gaussian noise is imperceptible, then samples from this model will look
good, since they correspond to the training set of real images. But this model will almost certainly
have poor likelihood on the test set due to overfitting. (In this case we say the model has effectively
just memorized the training set.)
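A small numerical sketch of this failure mode, using synthetic Gaussian data (our assumption) in place of images: the memorizing mixture of (20.8) with a tiny $\epsilon$ assigns a very high log likelihood to its own training points but a very low one to fresh test points.

```python
import numpy as np

def log_q_memorizer(x, train, eps):
    """log q(x) for q(x) = (1/N) sum_n N(x | x_n, eps I), computed with a stable log-sum-exp."""
    n, d = train.shape
    sq = np.sum((x[None, :] - train) ** 2, axis=1)
    log_comp = -0.5 * sq / eps - 0.5 * d * np.log(2 * np.pi * eps)   # log N(x | x_n, eps I)
    top = log_comp.max()
    return top + np.log(np.exp(log_comp - top).sum()) - np.log(n)

rng = np.random.default_rng(0)
train = rng.normal(size=(200, 10))
test = rng.normal(size=(200, 10))
eps = 1e-4
avg_train = np.mean([log_q_memorizer(x, train, eps) for x in train])   # large and positive
avg_test = np.mean([log_q_memorizer(x, train, eps) for x in test])     # hugely negative
```

Samples from this model are (noisy copies of) training points, so they would look perfect, yet the held-out likelihood exposes the memorization.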


20.4.2 Distances and divergences in feature space 


Due to the challenges associated with comparing distributions in high dimensional spaces, and the 
desire to compare distributions in a semantically meaningful way, it is common to use domain-specific 
perceptual distance metrics, that measure how similar data vectors are to each other or to the 
training data. However, most metrics used to evaluate generative models do not directly compare 
raw data (e.g. pixels) but use a neural network to obtain features from the raw data and compare 
the feature distribution obtained from model samples with the feature distribution obtained from 
the dataset. The neural network used to obtain features can be trained solely for the purpose of 
evaluation, or can be pretrained; a common choice is to use a pretrained classifier (see e.g., [Sal+16; 
Heu+17b; Bin+18; Kyn+19; SSG18a]). 

The Inception Score [Sal+16] measures the average KL divergence between the marginal distribution
of class labels obtained from the samples, $p_\theta(y) = \int p(y|x) p_\theta(x)\,dx$, and the distribution $p(y|x)$
obtained from a sample $x \sim p_\theta(x)$. This leads to the following score:

$$\mathrm{IS} = \exp\left[\mathbb{E}_{p_\theta(x)} D_{\mathrm{KL}}(p(y|x) \,\|\, p_\theta(y))\right] \qquad (20.9)$$

If a model produces high quality samples from all classes in the dataset, then $p_\theta(y)$ will often be
close to uniform, while $p(y|x)$ should be a sharp distribution corresponding to the class associated
with $x$; this leads to a high $D_{\mathrm{KL}}(p(y|x) \,\|\, p_\theta(y))$ and thus a high Inception Score.
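A minimal sketch of (20.9), assuming the rows of `probs` are class probabilities $p(y|x_n)$ produced by a pretrained classifier on $N$ model samples (the classifier itself is not shown). The marginal $p_\theta(y)$ is estimated by averaging these rows. Published implementations may add details, such as averaging over splits, that are omitted here.

```python
import numpy as np

def inception_score(probs, eps=1e-12):
    """IS = exp(mean_n KL(p(y|x_n) || p(y))), with p(y) estimated as the average row of probs."""
    probs = np.asarray(probs, dtype=float)                # shape (num_samples, num_classes)
    p_y = probs.mean(axis=0, keepdims=True)               # marginal label distribution
    kl = np.sum(probs * (np.log(probs + eps) - np.log(p_y + eps)), axis=1)
    return float(np.exp(kl.mean()))

# Confident predictions spread evenly over 10 classes versus all samples in one class.
probs_diverse = np.eye(10)[np.arange(100) % 10]
probs_collapsed = np.eye(10)[np.zeros(100, dtype=int)]
is_diverse = inception_score(probs_diverse)       # close to 10, the number of classes
is_collapsed = inception_score(probs_collapsed)   # close to 1
```

The maximum value equals the number of classes, which is reached by one confident prediction per class regardless of within-class variety; this is the weakness discussed next.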


The Inception Score solely relies on class labels, and thus does not measure overfitting or sample
diversity outside the predefined dataset classes. For example, a model which generates one perfect
example per class would get a perfect Inception Score, despite not capturing the variety of examples
inside a class, as shown in Figure 20.11a. To address this drawback, the Fréchet Inception Distance
or FID score [Heu+17b] measures the Fréchet distance between two Gaussian distributions on sets
of features of a pre-trained classifier. One Gaussian is obtained by passing model samples through a
pretrained classifier, and the other by passing samples from the dataset through the same classifier. If we
assume that the mean and covariance obtained from model features are $\mu_m$ and $\Sigma_m$, and those from
the data are $\mu_d$ and $\Sigma_d$, then the FID is

$$\mathrm{FID} = \|\mu_m - \mu_d\|_2^2 + \mathrm{tr}\left(\Sigma_d + \Sigma_m - 2(\Sigma_d \Sigma_m)^{1/2}\right) \qquad (20.10)$$

Since it uses features instead of class logits, the Fréchet distance captures more than the modes identified
by class labels, as shown in Figure 20.11b. Unlike the Inception Score, a lower score is better, since
we want the two distributions to be as close as possible.
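A NumPy sketch of (20.10), assuming feature matrices (one row per sample) have already been extracted with a pretrained network, which is not shown. To avoid a general matrix square root, it uses the fact that $\Sigma_d \Sigma_m$ has the same eigenvalues as the symmetric matrix $\Sigma_d^{1/2} \Sigma_m \Sigma_d^{1/2}$, so $\mathrm{tr}((\Sigma_d\Sigma_m)^{1/2})$ is the sum of the square roots of those eigenvalues.

```python
import numpy as np

def _sqrtm_psd(a):
    """Symmetric square root of a symmetric positive semi-definite matrix."""
    w, v = np.linalg.eigh(a)
    return (v * np.sqrt(np.clip(w, 0.0, None))) @ v.T

def fid_from_features(feats_model, feats_data):
    """(20.10) computed from two feature matrices of shape (num_samples, num_features)."""
    mu_m, mu_d = feats_model.mean(axis=0), feats_data.mean(axis=0)
    cov_m = np.cov(feats_model, rowvar=False)
    cov_d = np.cov(feats_data, rowvar=False)
    root_d = _sqrtm_psd(cov_d)
    inner = root_d @ cov_m @ root_d
    inner = 0.5 * (inner + inner.T)                       # symmetrize against round-off
    eigvals = np.clip(np.linalg.eigvalsh(inner), 0.0, None)
    tr_sqrt = np.sum(np.sqrt(eigvals))                    # tr((Sigma_d Sigma_m)^{1/2})
    return float(np.sum((mu_m - mu_d) ** 2) + np.trace(cov_d) + np.trace(cov_m) - 2.0 * tr_sqrt)

rng = np.random.default_rng(0)
feats_data = rng.normal(size=(2000, 4))
shift = np.array([1.0, 0.0, 0.0, 0.0])
fid_same = fid_from_features(feats_data, feats_data)           # identical sets: 0
fid_shift = fid_from_features(feats_data + shift, feats_data)  # same covariance: ||shift||^2 = 1
```

With identical covariances the trace term vanishes and only the squared mean difference remains, which is a convenient sanity check.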


Draft of “Probabilistic Machine Learning: Advanced Topics”. August 15, 2022 




Figure 20.11: (a) Model samples with good (high) inception score are visually realistic. (b) Model samples 
with good (low) FID score are visually realistic and diverse. 


Unfortunately, the Fréchet distance has been shown to have a high bias, with results varying 
widely based on the number of samples used to compute the score. To mitigate this issue, the 
Kernel Inception Distance has been introduced [Bin+18], which measures the squared MMD 
(Section 2.7.3) between the features obtained from the data and features obtained from model samples. 
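A sketch of the KID computation: an unbiased estimate of the squared MMD between two feature sets. As an assumption, we use the cubic polynomial kernel $k(x, y) = (x^\top y / d + 1)^3$, which to our knowledge is the kernel proposed for KID in [Bin+18]; features are again assumed to be precomputed.

```python
import numpy as np

def poly_kernel(a, b):
    """Cubic polynomial kernel k(x, y) = (x . y / d + 1)^3 between the rows of a and b."""
    d = a.shape[1]
    return (a @ b.T / d + 1.0) ** 3

def mmd2_unbiased(x, y, kernel=poly_kernel):
    """Unbiased estimate of the squared MMD between samples x (m, d) and y (n, d)."""
    m, n = len(x), len(y)
    k_xx, k_yy, k_xy = kernel(x, x), kernel(y, y), kernel(x, y)
    term_xx = (k_xx.sum() - np.trace(k_xx)) / (m * (m - 1))   # exclude the i == j terms
    term_yy = (k_yy.sum() - np.trace(k_yy)) / (n * (n - 1))
    return float(term_xx + term_yy - 2.0 * k_xy.mean())

rng = np.random.default_rng(0)
feats_data = rng.normal(size=(500, 8))
feats_good = rng.normal(size=(500, 8))          # same distribution as the data
feats_bad = rng.normal(size=(500, 8)) + 1.0     # shifted distribution
kid_good = mmd2_unbiased(feats_good, feats_data)
kid_bad = mmd2_unbiased(feats_bad, feats_data)
```

Because the estimator is unbiased, `kid_good` fluctuates around zero and can even be slightly negative, whereas the mismatched features give a clearly positive value.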


20.4.3 Precision and recall metrics 


Since the FID only measures the distance between the data and model distributions, it is difficult 
to use it as a diagnostic tool: a bad (high) FID can indicate that the model is not able to generate 
high quality data, or that it puts too much mass around the data distribution, or that the model 
only captures a subset of the data (e.g. in Figure 26.7). Disentangling these failure modes (poor sample quality
versus poor coverage of the data) has been the motivation to seek individual precision (sample quality) and recall (sample
diversity) metrics in the context of generative models [LPO17; Kyn+19]. (The diversity question is
especially important in the context of GANs, where mode collapse (Section 26.3.3) can be an issue.) 

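A common approach is to use nearest neighbors in the feature space of a pretrained classifier to
define precision and recall [Kyn+19]. To formalize this, let us define

$$f_k(\phi, \Phi) = \begin{cases} 1 & \text{if } \exists\, \phi' \in \Phi \text{ s.t. } \|\phi - \phi'\|_2^2 \le \|\phi' - \mathrm{NN}_k(\phi', \Phi)\|_2^2 \\ 0 & \text{otherwise} \end{cases} \qquad (20.11)$$

where $\Phi$ is a set of feature vectors and $\mathrm{NN}_k(\phi', \Phi)$ is a function returning the $k$-th nearest neighbor
of $\phi'$ in $\Phi$. We now define precision and recall as follows:

$$\mathrm{precision}(\Phi_{\mathrm{model}}, \Phi_{\mathrm{data}}) = \frac{1}{|\Phi_{\mathrm{model}}|} \sum_{\phi \in \Phi_{\mathrm{model}}} f_k(\phi, \Phi_{\mathrm{data}}) \qquad (20.12)$$
$$\mathrm{recall}(\Phi_{\mathrm{model}}, \Phi_{\mathrm{data}}) = \frac{1}{|\Phi_{\mathrm{data}}|} \sum_{\phi \in \Phi_{\mathrm{data}}} f_k(\phi, \Phi_{\mathrm{model}}) \qquad (20.13)$$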


Precision and recall are always between 0 and 1. Intuitively, the precision metric measures whether 
samples are as close to data as data is to other data examples, while recall measures whether data 
is as close to model samples as model samples are to other samples. The parameter k controls 
how lenient the metrics will be: the higher k, the higher both precision and recall will be. As in
classification, precision and recall in generative models can be used to construct a trade-off curve 
between different models which allows practitioners to make an informed decision regarding which 
model they want to use. 
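A minimal sketch of (20.11)-(20.13), assuming precomputed feature vectors and Euclidean distances. As an implementation assumption, $\mathrm{NN}_k(\phi', \Phi)$ does not count $\phi'$ itself. The example uses synthetic 2d "features" in which the model drops one of two data modes, which should show up as low recall but high precision.

```python
import numpy as np

def _pairwise_sq_dists(a, b):
    return np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)

def knn_radii_sq(phi, k):
    """Squared distance from each point of phi to its k-th nearest neighbor in phi (excluding itself)."""
    return np.sort(_pairwise_sq_dists(phi, phi), axis=1)[:, k]   # column 0 is the point itself

def f_k(queries, phi, k):
    """(20.11): 1 if a query lies inside the k-NN ball of at least one phi' in phi, else 0."""
    radii = knn_radii_sq(phi, k)
    return np.any(_pairwise_sq_dists(queries, phi) <= radii[None, :], axis=1).astype(float)

def precision_recall(phi_model, phi_data, k=3):
    precision = f_k(phi_model, phi_data, k).mean()   # (20.12)
    recall = f_k(phi_data, phi_model, k).mean()      # (20.13)
    return float(precision), float(recall)

rng = np.random.default_rng(0)
centers = np.array([[0.0, 0.0], [10.0, 10.0]])
phi_data = np.concatenate([c + 0.5 * rng.normal(size=(100, 2)) for c in centers])
phi_model = centers[0] + 0.5 * rng.normal(size=(100, 2))        # covers only the first mode
p, r = precision_recall(phi_model, phi_data, k=3)
```

The dropped mode contributes nothing to recall, so `r` is at most 0.5, while the samples that are generated look like data, so `p` stays high.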


20.4.4 Statistical tests 


Statistical tests have long been used to determine whether two sets of samples have been generated 
from the same distribution; these types of statistical tests are called two sample tests. Let us 
define the null hypothesis as the statement that both sets of samples are from the same distribution.
We then compute a statistic from the data and compare it to a threshold, and based on this we 
decide whether to reject the null hypothesis. In the context of evaluating implicit generative models 
such as GANs, statistics based on classifiers [Saj+18] and the MMD [Liu+20b] have been used. For 
use in scenarios with high dimensional input spaces, which are ubiquitous in the era of deep learning, 
two sample tests have been adapted to use learned features instead of raw data. 

Like all other evaluation metrics for generative models, statistical tests have their own advantages
and disadvantages: while users can specify the Type 1 error (the probability they are willing to accept of
wrongly rejecting the null hypothesis), statistical tests tend to be computationally expensive and thus
cannot be used to monitor progress in training; hence they are best used to compare fully trained
models.
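As a hedged sketch of a two-sample test, the code below uses a permutation test with an RBF-kernel MMD statistic. The bandwidth, the number of permutations and the biased MMD estimator are our choices for illustration, not a prescription from [Liu+20b]. The null hypothesis is rejected at level $\alpha$ (the Type 1 error) when the returned p-value is at most $\alpha$.

```python
import numpy as np

def rbf_mmd2_biased(x, y, bandwidth=1.0):
    """Biased (V-statistic) estimate of the squared MMD with a Gaussian RBF kernel."""
    z = np.concatenate([x, y], axis=0)
    k = np.exp(-np.sum((z[:, None, :] - z[None, :, :]) ** 2, axis=-1) / (2.0 * bandwidth ** 2))
    m = len(x)
    return k[:m, :m].mean() + k[m:, m:].mean() - 2.0 * k[:m, m:].mean()

def permutation_test(x, y, stat=rbf_mmd2_biased, num_perms=200, rng=None):
    """p-value for H0: x and y come from the same distribution, obtained by randomly relabelling samples."""
    rng = np.random.default_rng() if rng is None else rng
    observed = stat(x, y)
    pooled = np.concatenate([x, y], axis=0)
    m = len(x)
    count = 0
    for _ in range(num_perms):
        perm = rng.permutation(len(pooled))            # under H0 the labels are exchangeable
        if stat(pooled[perm[:m]], pooled[perm[m:]]) >= observed:
            count += 1
    return (count + 1) / (num_perms + 1)

rng = np.random.default_rng(0)
x = rng.normal(size=(100, 2))             # e.g. features of real data
y = rng.normal(size=(100, 2)) + 2.0       # e.g. features of model samples from a shifted distribution
p_value = permutation_test(x, y, num_perms=200, rng=rng)
```

Each permutation recomputes the full statistic, which illustrates why such tests are too expensive to run repeatedly during training.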


20.4.5 Challenges with using pretrained classifiers 


While popular and convenient, evaluation metrics that rely on pretrained classifiers (such as IS, FID, 
nearest neighbors in feature space, and statistical tests in feature space) have significant drawbacks. 
One might not have a pretrained classifier available for the dataset at hand, so classifiers trained on
other datasets are used. Given the well known challenges with neural network generalization (see
Section 17.4), the features of a classifier trained on images from one dataset might not be reliable
enough to provide a fine grained signal of quality for samples obtained from a model trained on a
different dataset. If the generative model is trained on the same dataset as the pre-trained classifier
but the model is not capturing the data distribution perfectly, we are presenting the pre-trained
classifier with out-of-distribution data and relying on its features to obtain scores to evaluate our
models. Far from being purely theoretical concerns, these issues have been studied extensively and
have been shown to affect evaluation in practice [RV19; BS18].


20.4.6 Using model samples to train classifiers


Instead of using pretrained classifiers to evaluate samples, one can train a classifier on samples from
conditional generative models, and then see how good these classifiers are at classifying real data. For
example, does adding synthetic (sampled) data to the real data help? This is closer to a reliable
evaluation of generative model samples, since ultimately, the performance of generative models is
dependent on the downstream task they are trained for. If used for semi-supervised learning, one
should assess how much adding samples to a classifier dataset helps with test accuracy. If used for
model-based reinforcement learning, one should assess how much the generative model helps with
agent performance. For examples of this approach, see e.g., [SSM18; SSA18; RV19; SS20b].
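The idea of training on model samples and testing on real data can be sketched with any classifier. Below, as assumptions for illustration, the classifier is a nearest-centroid rule, the data are synthetic 2d Gaussian classes, and two hypothetical conditional generators are compared: one that matches the class-conditional distributions and one whose conditioning is broken.

```python
import numpy as np

def fit_nearest_centroid(x, y, num_classes):
    """One centroid (mean feature vector) per class."""
    return np.stack([x[y == c].mean(axis=0) for c in range(num_classes)])

def predict_nearest_centroid(centroids, x):
    return np.argmin(np.sum((x[:, None, :] - centroids[None, :, :]) ** 2, axis=-1), axis=1)

def train_on_samples_test_on_real(sample_conditional, x_real, y_real, num_classes, n_per_class, rng):
    """Fit the classifier only on generated (x, y) pairs, then report its accuracy on real labeled data."""
    xs = [sample_conditional(c, n_per_class, rng) for c in range(num_classes)]
    ys = [np.full(n_per_class, c) for c in range(num_classes)]
    centroids = fit_nearest_centroid(np.concatenate(xs), np.concatenate(ys), num_classes)
    return float(np.mean(predict_nearest_centroid(centroids, x_real) == y_real))

means = np.array([[0.0, 0.0], [4.0, 0.0], [0.0, 4.0]])    # real class c is N(means[c], I)
rng = np.random.default_rng(0)
y_real = rng.integers(0, 3, size=300)
x_real = means[y_real] + rng.normal(size=(300, 2))

def good_generator(c, n, rng):       # samples the correct class-conditional distribution
    return means[c] + rng.normal(size=(n, 2))

def broken_generator(c, n, rng):     # asked for class c, it returns samples of class (c + 1) % 3
    return means[(c + 1) % 3] + rng.normal(size=(n, 2))

acc_good = train_on_samples_test_on_real(good_generator, x_real, y_real, 3, 200, rng)
acc_broken = train_on_samples_test_on_real(broken_generator, x_real, y_real, 3, 200, rng)
```

Both generators could produce individually realistic samples, but only the one whose samples carry the right label information yields a useful classifier.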
Figure 20.12: Illustration of nearest neighbors in feature space: in the top left we have the query sample
generated using BigGAN, and the rest of the images are its nearest neighbors from the dataset. The nearest
neighbor search is done in the feature space of a pretrained classifier. From Figure 13 of [BDS18]. Used with
kind permission of Andy Brock.


20.4.7 Assessing overfitting 


Many of the metrics discussed so far capture the sample quality and diversity, but do not capture 
overfitting to the training data. To capture overfitting, often a visual inspection is performed: a set 
of samples is generated from the model and for each sample its closest K nearest neighbors in the 
feature space of a pretrained classifier are obtained from the dataset. While this approach requires 
manually assessing samples, it is a simple way to test whether a model is simply memorizing the 
data. We show an example in Figure 20.12: since the model sample in the top left is quite different
from its neighbors from the dataset (remaining images), we can conclude the sample is not simply
memorised from the dataset. Similarly, sample diversity can be measured by approximating the
support of the learned distribution by looking for similar samples in a large sample pool (as in the
pigeonhole principle), but it is expensive and often requires manual human assessment [AZ17].
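The retrieval step of this visual check can be sketched as follows, assuming features have already been extracted with a pretrained classifier (random vectors stand in for them here). A generated sample whose nearest training neighbor is at distance (close to) zero is a candidate copy of a training example; in practice the retrieved images would then be inspected by a human, as in Figure 20.12.

```python
import numpy as np

def nearest_training_neighbors(sample_feats, train_feats, k=5):
    """Indices and Euclidean distances of the k nearest training examples for each generated sample."""
    d = np.sqrt(np.sum((sample_feats[:, None, :] - train_feats[None, :, :]) ** 2, axis=-1))
    idx = np.argsort(d, axis=1)[:, :k]
    return idx, np.take_along_axis(d, idx, axis=1)

rng = np.random.default_rng(0)
train_feats = rng.normal(size=(1000, 16))            # stand-in for features of the training set
sample_feats = np.vstack([train_feats[42:43],        # a "memorized" sample: exact copy of example 42
                          rng.normal(size=(3, 16))]) # three novel samples
idx, dist = nearest_training_neighbors(sample_feats, train_feats, k=5)
```

The first sample is flagged immediately (nearest index 42 at distance 0), whereas the novel samples have strictly positive nearest-neighbor distances.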

For likelihood-based models, such as variational autoencoders (Chapter 21), autoregressive
models (Chapter 22), and normalising flows (Chapter 23), we can assess memorisation by seeing how
much the log-likelihood of a model changes when a sample is included in the model’s training set or
not [BW21].


20.4.8 Human evaluation 


One approach to evaluate generative models is to use human evaluation, by presenting samples from
the model alongside samples from the data distribution, and asking human raters to compare the quality


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 




of the samples [Zho+19b]. Human evaluation is a suitable metric if the model is used to create art or 
other data for human display, or if reliable automated metrics are hard to obtain. However, human 
evaluation can be difficult to standardize, hard to automate and can be expensive or cumbersome to 
set up. 
21 Variational autoencoders


21.1 Introduction 


In this chapter, we discuss generative models of the form 


$$z \sim p_\theta(z) \qquad (21.1)$$
$$x|z \sim \mathrm{Expfam}(x|d_\theta(z)) \qquad (21.2)$$

where $p(z)$ is some kind of prior on the latent code $z$, $d_\theta(z)$ is a deep neural network, known as the
decoder, and $\mathrm{Expfam}(x|\eta)$ is an exponential family distribution, such as a Gaussian or product of
Bernoullis. This is called a deep latent variable model or DLVM. When the prior is Gaussian
(as is often the case), this model is called a deep latent Gaussian model or DLGM. 

Posterior inference (i.e., computing $p_\theta(z|x)$) is computationally intractable, as is computing the
marginal likelihood

$$p_\theta(x) = \int p_\theta(x|z) p_\theta(z)\, dz \qquad (21.3)$$

Hence we need to resort to approximate inference. For most of this chapter, we will use amortized
inference, which we discussed in Section 10.3.6. This trains another model, $q_\phi(z|x)$, called the
recognition network or inference network, simultaneously with the generative model to do
approximate posterior inference. This combination is called a variational autoencoder or VAE
[KW14; RMW14b; KW19a], since it can be thought of as a probabilistic version of a deterministic
autoencoder, discussed in Section 16.3.3.

In this chapter, we introduce the basic VAE, as well as some extensions. Note that the literature 
on VAE-like methods is vast¹, so we will only discuss a small subset of the ideas that have been
explored. 


21.2 VAE basics 


In this section, we discuss the basics of variational autoencoders. 


Figure 21.1: Schematic illustration of a VAE. From a figure in [Haf18]. Used with kind permission of Danijar 
Hafner. 


21.2.1 Modeling assumptions 


In the simplest setting, a VAE defines a generative model of the form 


$$p_\theta(z, x) = p_\theta(z) p_\theta(x|z) \qquad (21.4)$$

where $p_\theta(z)$ is usually a Gaussian, and $p_\theta(x|z)$ is usually a product of exponential family distributions
(e.g., Gaussians or Bernoullis), with parameters computed by a neural network decoder, $d_\theta(z)$. For
example, for binary observations, we can use

$$p_\theta(x|z) = \prod_{d=1}^{D} \mathrm{Ber}(x_d | \sigma(d_\theta(z)_d)) \qquad (21.5)$$

In addition, a VAE fits a recognition model

$$q_\phi(z|x) = q(z|e_\phi(x)) \approx p_\theta(z|x) \qquad (21.6)$$

to perform approximate posterior inference. Here $q_\phi(z|x)$ is usually a Gaussian, with parameters
computed by a neural network encoder $e_\phi(x)$:

$$q_\phi(z|x) = \mathcal{N}(z|\mu, \mathrm{diag}(\exp(\ell))) \qquad (21.7)$$
$$(\mu, \ell) = e_\phi(x) \qquad (21.8)$$


where $\ell = \log \sigma^2$, so that $\exp(\ell)$ is the vector of posterior variances. The model can be thought of as encoding the input $x$ into a stochastic latent
bottleneck $z$ and then decoding it to approximately reconstruct the input, as shown in Figure 21.1.
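To make (21.5)-(21.8) concrete, here is a minimal NumPy sketch of the encoder and the Bernoulli decoder. The single linear layers and their sizes are hypothetical stand-ins for the deep networks $e_\phi$ and $d_\theta$. The Bernoulli log likelihood is computed from logits with `np.logaddexp` for numerical stability, using $\log \mathrm{Ber}(x|\sigma(a)) = x a - \log(1 + e^a)$.

```python
import numpy as np

rng = np.random.default_rng(0)
D, K = 784, 2                                        # data and latent sizes (hypothetical)
W_enc = rng.normal(scale=0.01, size=(2 * K, D))      # stand-in for the encoder network e_phi
W_dec = rng.normal(scale=0.1, size=(D, K))           # stand-in for the decoder network d_theta

def encoder(x):
    """e_phi(x) -> (mu, ell), the parameters of q_phi(z|x) = N(z | mu, diag(exp(ell)))."""
    h = W_enc @ x
    return h[:K], h[K:]

def decoder_logits(z):
    """d_theta(z): one Bernoulli logit per observed dimension."""
    return W_dec @ z

def bernoulli_log_lik(x, logits):
    """log p_theta(x|z) = sum_d [x_d * a_d - log(1 + exp(a_d))], with a = d_theta(z)."""
    return float(np.sum(x * logits - np.logaddexp(0.0, logits)))

x = (rng.uniform(size=D) < 0.5).astype(float)        # a hypothetical binary input
mu, ell = encoder(x)
log_lik = bernoulli_log_lik(x, decoder_logits(mu))   # evaluated at the posterior mean z = mu
```

Working with logits rather than probabilities avoids computing $\log 0$ when the sigmoid saturates.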


The idea of training an inference network to “invert” a generative network, rather than running
an optimization algorithm to infer the latent code, is called amortized inference, and is discussed in
Section 10.3.6. This idea was first proposed in the Helmholtz machine [Day+95]. However, that
paper did not present a single unified objective function for inference and generation, but instead
used the wake-sleep (Section 10.6) method for training. By contrast, the VAE optimizes a variational
lower bound on the log-likelihood, which means that convergence to a locally optimal MLE of the
parameters is guaranteed.

We can use other approaches to fitting the DLGM (see e.g., [Hof17; DF19]). However, learning an
inference network to fit the DLGM is often faster and can have some regularization benefits (see e.g.,
[KP20]).²


1. For example, the website https://github.com/matthewvowels1/Awesome-VAEs lists over 900 papers.


2. Combining a generative model with an inference model in this way results in what has been called a “monference”


21.2.2 Evidence lower bound (ELBO) 


When fitting the model, our goal is to maximize the marginal likelihood

$$p_\theta(x) = \int p_\theta(x|z) p_\theta(z)\, dz \qquad (21.9)$$


Unfortunately, computing this quantity is intractable. However, we can use an inference network to
compute an approximate posterior, $q_\phi(z|x)$, and hence a lower bound to the marginal likelihood. This
idea is discussed in Section 10.1.2, but we repeat the details here, using slightly different notation.
First note that we have the following decomposition: 


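$$\log p_\theta(x) = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x)] \qquad (21.10)$$
$$= \mathbb{E}_{q_\phi(z|x)}\left[\log \frac{p_\theta(x, z)}{p_\theta(z|x)}\right] \qquad (21.11)$$
$$= \mathbb{E}_{q_\phi(z|x)}\left[\log \left(\frac{p_\theta(x, z)}{q_\phi(z|x)} \cdot \frac{q_\phi(z|x)}{p_\theta(z|x)}\right)\right] \qquad (21.12)$$
$$= \underbrace{\mathbb{E}_{q_\phi(z|x)}\left[\log \frac{p_\theta(x, z)}{q_\phi(z|x)}\right]}_{\mathcal{L}_{\theta,\phi}(x)} + \underbrace{\mathbb{E}_{q_\phi(z|x)}\left[\log \frac{q_\phi(z|x)}{p_\theta(z|x)}\right]}_{D_{\mathrm{KL}}(q_\phi(z|x) \,\|\, p_\theta(z|x))} \qquad (21.13)$$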


The second term in Equation (21.13) is non-negative, and hence 


$$\mathcal{L}_{\theta,\phi}(x) \le \log p_\theta(x) \qquad (21.14)$$


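The quantity $\log p_\theta(x)$ is the log marginal likelihood, also called the evidence. Hence $\mathcal{L}_{\theta,\phi}(x)$ is
called the evidence lower bound or ELBO.
We can rewrite the ELBO in 3 equivalent ways as follows:

$$\mathcal{L}_{\theta,\phi}(x) = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x, z) - \log q_\phi(z|x)] \qquad (21.15)$$
$$= \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z) + \log p_\theta(z)] + \mathbb{H}(q_\phi(z|x)) \qquad (21.16)$$
$$= \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - D_{\mathrm{KL}}(q_\phi(z|x) \,\|\, p_\theta(z)) \qquad (21.17)$$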


We can interpret this last objective as the expected log likelihood plus a regularization term, that 
ensures the (per-sample) posterior is “well behaved” (does not deviate too far from the prior in terms 
of KL divergence). 

The tightness of this lower bound is controlled by the variational gap, which is given by
$D_{\mathrm{KL}}(q_\phi(z|x) \,\|\, p_\theta(z|x))$. A better approximate posterior results in a tighter bound. When the
KL goes to zero, the posterior is exact, so any improvements to the ELBO directly translate to
improvements in the likelihood of the data, as in the EM algorithm (see Section 6.6.3).
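The bound and the role of the variational gap can be checked exactly on a toy model where everything is tractable. As an illustrative assumption, take $p(z) = \mathcal{N}(0,1)$ and $p(x|z) = \mathcal{N}(x|z,1)$, so that $p(x) = \mathcal{N}(x|0,2)$ and $p(z|x) = \mathcal{N}(z|x/2, 1/2)$. With a Gaussian $q(z|x) = \mathcal{N}(m, s^2)$, the ELBO in the form (21.16) is available in closed form.

```python
import numpy as np

def elbo_toy(x, m, s2):
    """Closed-form ELBO (21.16) for p(z) = N(0,1), p(x|z) = N(x|z,1), q(z|x) = N(m, s2)."""
    e_log_lik = -0.5 * np.log(2 * np.pi) - 0.5 * ((x - m) ** 2 + s2)   # E_q[log p(x|z)]
    e_log_prior = -0.5 * np.log(2 * np.pi) - 0.5 * (m ** 2 + s2)       # E_q[log p(z)]
    entropy = 0.5 * np.log(2 * np.pi * np.e * s2)                       # H(q)
    return e_log_lik + e_log_prior + entropy

def log_evidence_toy(x):
    """log p(x) = log N(x | 0, 2)."""
    return -0.5 * np.log(2 * np.pi * 2.0) - x ** 2 / 4.0

x = 1.3
gap_exact = log_evidence_toy(x) - elbo_toy(x, x / 2, 0.5)   # q = p(z|x): the gap vanishes
gap_prior = log_evidence_toy(x) - elbo_toy(x, 0.0, 1.0)     # q = p(z): gap = KL(q || p(z|x)) > 0
```

For $q = p(z)$ the gap equals $\frac{1}{2}\log\frac{1}{2} + \frac{1}{2} + x^2/4$, which is exactly the KL divergence from $\mathcal{N}(0,1)$ to the true posterior $\mathcal{N}(x/2, 1/2)$.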

An alternative to maximizing the ELBO is to minimize the negative ELBO, also called the 
variational free energy, given by 


$$\mathcal{L}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z|x)}[-\log p_\theta(x|z)] + D_{\mathrm{KL}}(q_\phi(z|x) \,\|\, p_\theta(z)) \qquad (21.18)$$


i.e., model-inference hybrid. See the blog by Jacob Andreas, http://blog.jacobandreas.net/monference.html, for
further discussion.
21.2.3 Evaluating the ELBO


We can approximate the ELBO by sampling from the posterior, $z_s \sim q_\phi(z|x)$, and then evaluating
the expectation term using a Monte Carlo estimate,

$$\mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x, z)] \approx \frac{1}{S} \sum_{s=1}^{S} \log p_\theta(x, z_s) \qquad (21.19)$$


We can often compute the (differential) entropy term $\mathbb{H}(q_\phi(z|x))$ analytically, depending on the
choice of variational distribution. For example, if we use a Gaussian posterior, $q(z|x) = \mathcal{N}(z|\mu, \Sigma)$,
we can use Equation (5.95) to compute the entropy:

$$\mathbb{H}(q_\phi(z|x)) = \frac{1}{2} \ln |\Sigma| + \text{const} \qquad (21.20)$$


Similarly, we can often compute the KL term $D_{\mathbb{KL}}(q_\phi(z|x) \,\|\, p_\theta(z))$ analytically. For example, if
we assume a diagonal Gaussian prior, $p(z) = \mathcal{N}(z|0, I)$, and a diagonal Gaussian posterior,
$q(z|x) = \mathcal{N}(z|\mu, \text{diag}(\sigma^2))$, we can use Equation (5.81) to compute the KL in closed form:

$$D_{\mathbb{KL}}(q \,\|\, p) = -\frac{1}{2}\sum_{k=1}^{K}\left[\log \sigma_k^2 - \sigma_k^2 - \mu_k^2 + 1\right] \tag{21.21}$$

where $K$ is the number of latent dimensions.
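A direct transcription of (21.21) is shown below. It is a sketch that assumes the encoder outputs $\log \sigma$ rather than $\sigma$, a common way to keep $\sigma$ positive. The function name is hypothetical.

```python
import numpy as np

def kl_diag_gaussian_std_normal(mu, log_sigma):
    """D_KL(N(mu, diag(sigma^2)) || N(0, I)) in closed form, as in (21.21).

    mu, log_sigma: arrays of shape (K,); sigma = exp(log_sigma).
    """
    log_var = 2.0 * log_sigma                       # log sigma_k^2
    return -0.5 * np.sum(log_var - np.exp(log_var) - mu ** 2 + 1.0)
```

The KL is zero exactly when $\mu = 0$ and $\sigma = 1$, i.e., when the posterior equals the prior.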
In cases where we cannot evaluate the entropy or KL in closed form, we can just use Monte
Carlo to approximate all terms:

$$\mathcal{L}_{\theta,\phi}(x) \approx \frac{1}{S}\sum_{s=1}^{S}\left[\log p_\theta(x|z_s) + \log p_\theta(z_s) - \log q_\phi(z_s|x)\right] \tag{21.22}$$

We often use a single MC sample, so $S = 1$.
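To illustrate (21.22), the sketch below reuses the same hypothetical 1-D toy model ($p(z) = \mathcal{N}(0,1)$, $p(x|z) = \mathcal{N}(x|z,1)$, $q(z|x) = \mathcal{N}(m, s^2)$). It implements the $S$-sample estimator, which is unbiased for any $S$, and its variance shrinks roughly as $1/S$.

```python
import numpy as np

LOG_2PI = np.log(2 * np.pi)

def mc_elbo(x, m, s, num_samples, rng):
    """S-sample Monte Carlo ELBO (21.22) for the toy model described above.

    All three terms are evaluated at the same samples z_s ~ q(z|x).
    """
    z = m + s * rng.standard_normal(num_samples)
    log_lik = -0.5 * LOG_2PI - 0.5 * (x - z) ** 2                 # log p(x|z_s)
    log_prior = -0.5 * LOG_2PI - 0.5 * z ** 2                     # log p(z_s)
    log_q = -0.5 * LOG_2PI - np.log(s) - 0.5 * ((z - m) / s) ** 2 # log q(z_s|x)
    return np.mean(log_lik + log_prior - log_q)

estimate = mc_elbo(1.3, 0.2, 0.9, num_samples=1, rng=np.random.default_rng(0))
```

A single sample ($S = 1$) gives a noisy but unbiased estimate. In SGD, this noise is averaged out across many update steps.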


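21.2.4 Optimizing the ELBO

The ELBO for a single datapoint $x$ is given in Equation (21.17). The ELBO for the whole dataset,
scaled by $1/N$, where $N = |\mathcal{D}|$ is the number of examples, is given by

$$\mathcal{L}_{\theta,\phi}(\mathcal{D}) = \frac{1}{N}\sum_{x_n \in \mathcal{D}} \mathcal{L}_{\theta,\phi}(x_n) = \frac{1}{N}\sum_{x_n \in \mathcal{D}} \mathbb{E}_{q_\phi(z|x_n)}\left[\log p_\theta(x_n|z) + \log p_\theta(z) - \log q_\phi(z|x_n)\right] \tag{21.23}$$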


Our goal is to maximize this wrt $\theta$ and $\phi$: the former fits the model to the data, the latter reduces
the KL gap between the approximate and true posterior.


We can create an unbiased minibatch approximation of this objective by sampling examples $x$
and then computing the objective for a given $x$. So now we focus on a fixed $x$, for brevity, and drop
the sum over $x_n$.
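The minibatch approximation can be sketched as follows. Averaging a per-example objective over uniformly sampled indices gives an estimate whose expectation is the full-data average in (21.23). Here `elbo_fn` is a hypothetical stand-in for the single-datapoint computation, and we sample indices with replacement for simplicity. Sampling without replacement is also unbiased.

```python
import numpy as np

def minibatch_elbo(data, elbo_fn, batch_size, rng):
    """Unbiased minibatch estimate of the dataset-averaged ELBO in (21.23).

    data: array of examples; elbo_fn: per-example ELBO estimate (a placeholder
    for the single-datapoint computations described in this section).
    """
    idx = rng.integers(0, len(data), size=batch_size)   # uniform, with replacement
    return np.mean([elbo_fn(data[i]) for i in idx])

# Hypothetical per-example objective, used only to check unbiasedness
data = np.linspace(-2.0, 2.0, 101)
elbo_fn = lambda x: -x ** 2
rng = np.random.default_rng(0)
estimates = [minibatch_elbo(data, elbo_fn, 16, rng) for _ in range(5000)]
full = np.mean([elbo_fn(x) for x in data])
```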
21.2. VAE BASICS


[Figure 21.2 panels: Original form (left) and Reparameterized form (right). Legend: deterministic node; random node; evaluation of $f$; differentiation of $f$ (backprop).]

Figure 21.2: Illustration of the reparameterization trick. The objective $f$ depends on the variational parameters
$\phi$, the observed data $x$, and the latent random variable $z \sim q_\phi(z|x)$. On the left, we show the standard form
of the computation graph. On the right, we show a reparameterized form, in which we move the stochasticity
into the noise source $\epsilon$, and compute $z$ deterministically, $z = g(\phi, x, \epsilon)$. The rest of the graph is deterministic,
so we can backpropagate the gradient of the scalar $f$ wrt $\phi$ through $z$ and into $\phi$. From Figure 2.3 of [KW19a].
Used with kind permission of Durk Kingma.


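The gradient wrt the generative parameters $\theta$ is easy to compute, since we can push gradients
inside the expectation, and use a single Monte Carlo sample:

$$\nabla_\theta \mathcal{L}_{\theta,\phi}(x) = \nabla_\theta \mathbb{E}_{q_\phi(z|x)}\left[\log p_\theta(x, z) - \log q_\phi(z|x)\right] \tag{21.24}$$
$$= \mathbb{E}_{q_\phi(z|x)}\left[\nabla_\theta \left\{\log p_\theta(x, z) - \log q_\phi(z|x)\right\}\right] \tag{21.25}$$
$$\approx \nabla_\theta \log p_\theta(x, z^s) \tag{21.26}$$

where $z^s \sim q_\phi(z|x)$. This is an unbiased estimate of the gradient, so it can be used with SGD.
The gradient wrt the inference parameters $\phi$ is harder to compute. The distribution over which we take the
expectation itself depends on $\phi$, so we cannot simply push the gradient inside:

$$\nabla_\phi \mathcal{L}_{\theta,\phi}(x) = \nabla_\phi \mathbb{E}_{q_\phi(z|x)}\left[\log p_\theta(x, z) - \log q_\phi(z|x)\right] \tag{21.27}$$
$$\neq \mathbb{E}_{q_\phi(z|x)}\left[\nabla_\phi \left\{\log p_\theta(x, z) - \log q_\phi(z|x)\right\}\right] \tag{21.28}$$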


However, we can often use the reparameterization trick, which we discuss in Section 21.2.5. (If not, 
we can use black box VI, which we discuss in Section 10.3.2.) 


21.2.5 Using the reparameterization trick to compute ELBO gradients 


In this section, we discuss the reparameterization trick for taking gradients wrt distributions over
continuous latent variables $z \sim q_\phi(z|x)$. We explain this in detail in Section 6.5.4, but we summarize
the basic idea here.
The key trick is to rewrite the random variable $z \sim q_\phi(z|x)$ as some differentiable (and invertible)
transformation $r$ of another random variable $\epsilon \sim p(\epsilon)$, which does not depend on $\phi$, i.e., we assume
we can write

$$z = r(\epsilon, \phi, x) \tag{21.29}$$
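For example,

$$z \sim \mathcal{N}(\mu, \text{diag}(\sigma^2)) \implies z = \mu + \sigma \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I) \tag{21.30}$$

Using this, we have

$$\mathbb{E}_{q_\phi(z|x)}\left[f(z)\right] = \mathbb{E}_{p(\epsilon)}\left[f(z)\right] \quad \text{s.t.} \quad z = r(\epsilon, \phi, x) \tag{21.31}$$

where we define

$$f(z) = \log p_\theta(x, z) - \log q_\phi(z|x) \tag{21.32}$$

Hence

$$\nabla_\phi \mathbb{E}_{q_\phi(z|x)}\left[f(z)\right] = \nabla_\phi \mathbb{E}_{p(\epsilon)}\left[f(z)\right] = \mathbb{E}_{p(\epsilon)}\left[\nabla_\phi f(z)\right] \tag{21.33}$$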


which we can approximate with a single Monte Carlo sample. This lets us propagate gradients back
through the $f$ function and then into the DNN transformation function $r$ that is used to compute
$z = r(\epsilon, \phi, x)$. See Figure 21.2 for an illustration.
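The pathwise estimator in (21.33) can be sketched without an autodiff library for a 1-D Gaussian $q = \mathcal{N}(\mu, \sigma^2)$, by writing the chain rule by hand. Since $z = \mu + \sigma\epsilon$, we have $\partial f/\partial\mu = f'(z)$ and $\partial f/\partial\sigma = f'(z)\,\epsilon$. The test function $f(z) = z^2$ is a hypothetical choice with a known answer: $\mathbb{E}[z^2] = \mu^2 + \sigma^2$, so the exact gradients are $2\mu$ and $2\sigma$.

```python
import numpy as np

def reparam_grad(mu, sigma, df_dz, num_samples=1, rng=None):
    """Pathwise estimate of d/d(mu, sigma) E_{N(mu, sigma^2)}[f(z)], as in (21.33).

    Uses z = mu + sigma * eps with eps ~ N(0, 1); df_dz is the derivative f'(z).
    Returns (grad_mu, grad_sigma) averaged over num_samples draws of eps.
    """
    rng = np.random.default_rng() if rng is None else rng
    eps = rng.standard_normal(num_samples)   # noise does not depend on (mu, sigma)
    z = mu + sigma * eps                     # z is a deterministic function of eps
    g = df_dz(z)
    return np.mean(g), np.mean(g * eps)      # chain rule: dz/dmu = 1, dz/dsigma = eps

g_mu, g_sigma = reparam_grad(0.5, 1.5, lambda z: 2 * z,
                             num_samples=200_000, rng=np.random.default_rng(0))
```

In a VAE the same idea is applied with automatic differentiation through the encoder and decoder networks, rather than with a hand-written derivative.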

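Since we are now working with the random variable $\epsilon$, we need to use the change of variables
formula to compute

$$\log q_\phi(z|x) = \log p(\epsilon) - \log\left|\det\left(\frac{\partial z}{\partial \epsilon}\right)\right| \tag{21.34}$$

where $\frac{\partial z}{\partial \epsilon}$ is the Jacobian:

$$\frac{\partial z}{\partial \epsilon} = \begin{pmatrix} \frac{\partial z_1}{\partial \epsilon_1} & \cdots & \frac{\partial z_1}{\partial \epsilon_K} \\ \vdots & \ddots & \vdots \\ \frac{\partial z_K}{\partial \epsilon_1} & \cdots & \frac{\partial z_K}{\partial \epsilon_K} \end{pmatrix} \tag{21.35}$$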


We design the transformation z = r(e) such that this Jacobian is tractable to compute. We give 
some examples below. 


21.2.5.1 Fully factorized Gaussian 


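Suppose we have a fully factorized Gaussian posterior:

$$\epsilon \sim \mathcal{N}(0, I) \tag{21.36}$$
$$z = \mu + \sigma \odot \epsilon \tag{21.37}$$
$$(\mu, \log \sigma) = e_\phi(x) \tag{21.38}$$

Then the Jacobian is $\frac{\partial z}{\partial \epsilon} = \text{diag}(\sigma)$, so

$$\log q_\phi(z|x) = \sum_{k=1}^{K}\left[\log \mathcal{N}(\epsilon_k|0, 1) - \log \sigma_k\right] = -\sum_{k=1}^{K}\left[\frac{1}{2}\log(2\pi) + \frac{1}{2}\epsilon_k^2 + \log \sigma_k\right] \tag{21.39}$$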
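Equation (21.39) can be checked numerically. The sketch below computes $\log q_\phi(z|x)$ from the noise $\epsilon$ and $\log\sigma$ alone, via the change of variables. It then compares the result with the ordinary diagonal-Gaussian log density evaluated at $z = \mu + \sigma \odot \epsilon$. Both function names are illustrative.

```python
import numpy as np

def log_q_reparam(eps, log_sigma):
    """log q(z|x) for z = mu + sigma * eps, via the change of variables (21.39)."""
    return np.sum(-0.5 * np.log(2 * np.pi) - 0.5 * eps ** 2 - log_sigma)

def log_normal_diag(z, mu, log_sigma):
    """Direct log density of N(z | mu, diag(sigma^2)), used only as a check."""
    sigma = np.exp(log_sigma)
    return np.sum(-0.5 * np.log(2 * np.pi) - log_sigma - 0.5 * ((z - mu) / sigma) ** 2)

rng = np.random.default_rng(0)
mu, log_sigma = rng.normal(size=4), 0.3 * rng.normal(size=4)
eps = rng.standard_normal(4)
z = mu + np.exp(log_sigma) * eps
```

The reparameterized form never needs to invert the transformation, because $\epsilon$ is already available from the sampling step.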
21.2.5.2 Full covariance Gaussian

Now consider a full covariance Gaussian posterior:

$$\epsilon \sim \mathcal{N}(0, I) \tag{21.40}$$
$$z = \mu + L\epsilon \tag{21.41}$$

where $L$ is a lower triangular matrix with non-zero entries on the diagonal, which satisfies $\Sigma = LL^\top$.
The Jacobian of this affine transformation is $\frac{\partial z}{\partial \epsilon} = L$. Since $L$ is a triangular matrix, its determinant
is the product of its main diagonal, so

$$\log\left|\det \frac{\partial z}{\partial \epsilon}\right| = \sum_{k=1}^{K}\log|L_{kk}| \tag{21.42}$$


We need to make the parameters of the transformation $r$ be a function of the inputs $x$. One way
to do this is to define

$$(\mu, \log \sigma, L') = e_\phi(x) \tag{21.43}$$
$$L = M \odot L' + \text{diag}(\sigma) \tag{21.44}$$

where $M$ is a masking matrix with 0s on and above the diagonal, and 1s below the diagonal. With
this construction, the diagonal entries of $L$ are given by $\sigma$, so

$$\log\left|\det \frac{\partial z}{\partial \epsilon}\right| = \sum_{k=1}^{K}\log|L_{kk}| = \sum_{k=1}^{K}\log \sigma_k \tag{21.45}$$

See Algorithm 37 for the corresponding pseudo code for computing the reparameterized ELBO.
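Algorithm 37: Computing a single sample unbiased estimate of the reparameterized ELBO
for a VAE with full covariance Gaussian posterior, and factorized Bernoulli likelihood. Based
on Algorithm 2 of [KW19a].
1 $(\mu, \log \sigma, L') = e_\phi(x)$
2 $M$ = np.tril(np.ones((K, K)), -1)
3 $L = M \odot L' + \text{diag}(\sigma)$
4 $\epsilon \sim \mathcal{N}(0, I)$
5 $z = L\epsilon + \mu$
6 $p = d_\theta(z)$
7 $\mathcal{L}_{\log qz} = -\sum_{k=1}^{K}\left[\frac{1}{2}\epsilon_k^2 + \frac{1}{2}\log(2\pi) + \log \sigma_k\right]$ // from $q_\phi(z|x)$
8 $\mathcal{L}_{\log pz} = -\sum_{k=1}^{K}\left[\frac{1}{2}z_k^2 + \frac{1}{2}\log(2\pi)\right]$ // from $p_\theta(z)$
9 $\mathcal{L}_{\log px} = \sum_{d=1}^{D}\left[x_d \log p_d + (1 - x_d)\log(1 - p_d)\right]$ // from $p_\theta(x|z)$
10 $\mathcal{L} = \mathcal{L}_{\log px} + \mathcal{L}_{\log pz} - \mathcal{L}_{\log qz}$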
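A runnable NumPy version of Algorithm 37 is sketched below. The encoder $e_\phi$ and decoder $d_\theta$ are replaced by hypothetical linear maps with random weights (not networks from the text), and the decoder output passes through a sigmoid to give Bernoulli means. The usage code checks that step 7 agrees with the full-covariance Gaussian log density $\log \mathcal{N}(z|\mu, LL^\top)$, as implied by (21.34) and (21.45).

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def reparam_elbo_full_cov(x, encoder, decoder, rng):
    """Single-sample ELBO estimate following the steps of Algorithm 37.

    encoder(x) must return (mu, log_sigma, L_prime) and decoder(z) the Bernoulli
    means p; both are stand-ins for the networks e_phi and d_theta.
    """
    mu, log_sigma, L_prime = encoder(x)                                   # step 1
    K = mu.shape[0]
    M = np.tril(np.ones((K, K)), -1)               # step 2: 1s strictly below diagonal
    L = M * L_prime + np.diag(np.exp(log_sigma))   # step 3: diag(L) = sigma
    eps = rng.standard_normal(K)                                          # step 4
    z = L @ eps + mu                                                      # step 5
    p = decoder(z)                                                        # step 6
    log_qz = -np.sum(0.5 * eps ** 2 + 0.5 * np.log(2 * np.pi) + log_sigma)  # step 7
    log_pz = -np.sum(0.5 * z ** 2 + 0.5 * np.log(2 * np.pi))               # step 8
    log_px = np.sum(x * np.log(p) + (1 - x) * np.log(1 - p))               # step 9
    return log_px + log_pz - log_qz, (z, mu, L, log_qz)                    # step 10

# Hypothetical linear "networks" with random weights, for illustration only
D, K = 6, 3
rng = np.random.default_rng(0)
W_mu = 0.1 * rng.normal(size=(K, D))
W_ls = 0.1 * rng.normal(size=(K, D))
W_L = 0.1 * rng.normal(size=(K, K, D))
W_dec = rng.normal(size=(D, K))

def encoder(x):
    return W_mu @ x, W_ls @ x, W_L @ x      # W_L @ x has shape (K, K)

def decoder(z):
    return sigmoid(W_dec @ z)

x = (rng.random(D) < 0.5).astype(float)
elbo, (z, mu, L, log_qz) = reparam_elbo_full_cov(x, encoder, decoder, rng)
```

The mask uses `np.tril(..., -1)` so that only entries strictly below the diagonal come from $L'$. The diagonal is then set to $\sigma$, which keeps $L$ invertible and makes the log-determinant simply $\sum_k \log \sigma_k$.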
Figure 21.3: Illustration of unconditional image generation using (V)AEs trained on CelebA. Row 1:
Deterministic autoencoder. Row 2: β-VAE with $\beta = 0.5$. Row 3: VAE (with $\beta = 1$). Generated by
celeba_vae_ae_comparison.ipynb.


21.2.5.3 Inverse autoregressive flows 


In Section 10.4.3, we discuss how to use inverse autoregressive flows to learn more expressive posteriors
$q_\phi(z|x)$, leveraging the tractability of the Jacobian of this nonlinear transformation.


21.2.6 Comparison of VAEs and autoencoders


VAEs are very similar to deterministic autoencoders (AE). There are 2 main differences: in the AE,
the objective is the log likelihood of the reconstruction without any KL term; and in addition, the
encoding is deterministic, so the encoder network just needs to compute $\mathbb{E}[z|x]$ and not $\mathbb{V}[z|x]$. In
view of these similarities, one can use the same codebase to implement both methods. However, it
is natural to wonder what the benefits and potential drawbacks of the VAE are compared to the
deterministic AE.
We shall answer this question by fitting both models to the CelebA dataset. Both models have the
same convolutional structure with the following number of hidden channels per convolutional layer in
the encoder: (32, 64, 128, 256, 512). The spatial size of each layer is as follows: (32, 16, 8, 4, 2). The
final 2 x 2 x 512 convolutional layer then gets reshaped and passed through a linear layer to generate
the mean and (marginal) variance of the stochastic latent vector, which has size 256. The structure
of the decoder is the mirror image of the encoder. Each model is trained for 5 epochs with a batch
size of 256, which takes about 20 minutes on a GPU.


The main advantage of a VAE over a deterministic autoencoder is that it defines a proper generative
model that can create sensible-looking novel images by decoding prior samples $z \sim \mathcal{N}(0, I)$. By
contrast, an autoencoder only knows how to decode latent codes derived from the training set, so it
does poorly when fed random inputs. This is illustrated in Figure 21.3.

Figure 21.4: Illustration of image reconstruction using (V)AEs trained and applied to CelebA. Row 1:
Original images. Row 2: Deterministic autoencoder. Row 3: β-VAE with $\beta = 0.5$. Row 4: VAE (with $\beta = 1$).
Generated by celeba_vae_ae_comparison.ipynb.

We can also use both models to reconstruct a given input image. In Figure 21.4, we see that both 
AE and VAE can reconstruct the input images reasonably well, although the VAE reconstructions are 
somewhat blurry, for reasons we discuss in Section 21.3.1. We can reduce the amount of blurriness
by scaling down the KL penalty term by a factor of $\beta$; this is known as the β-VAE, and is discussed
in more detail in Section 21.3.1.


21.2.7 VAEs optimize in an augmented space 


In this section, we derive several alternative expressions for the ELBO which shed light on how VAEs 
work. 
First, let us define the joint generative distribution 


$$p_\theta(x, z) = p_\theta(z)\, p_\theta(x|z) \tag{21.46}$$
from which we can derive the generative data marginal

$$p_\theta(x) = \int_z p_\theta(x, z)\, dz \tag{21.47}$$

and the generative posterior

$$p_\theta(z|x) = p_\theta(x, z) / p_\theta(x) \tag{21.48}$$


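Let us also define the joint inference distribution

$$q_{\mathcal{D},\phi}(x, z) = p_{\mathcal{D}}(x)\, q_\phi(z|x) \tag{21.49}$$

where

$$p_{\mathcal{D}}(x) = \frac{1}{N}\sum_{n=1}^{N}\delta(x_n - x) \tag{21.50}$$

is the empirical distribution. From this we can derive the inference latent marginal, also called the
aggregated posterior:

$$q_{\mathcal{D},\phi}(z) = \int_x q_{\mathcal{D},\phi}(x, z)\, dx \tag{21.51}$$

and the inference likelihood

$$q_{\mathcal{D},\phi}(x|z) = q_{\mathcal{D},\phi}(x, z) / q_{\mathcal{D},\phi}(z) \tag{21.52}$$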


See Figure 21.5 for a visual illustration. 
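Having defined our terms, we can now derive various alternative versions of the ELBO, following
[ZSE19]. First note that the ELBO averaged over all the data is given by

$$\mathcal{L}_{\theta,\phi} = \mathbb{E}_{p_{\mathcal{D}}(x)}\left[\mathbb{E}_{q_\phi(z|x)}\left[\log p_\theta(x|z)\right]\right] - \mathbb{E}_{p_{\mathcal{D}}(x)}\left[D_{\mathbb{KL}}(q_\phi(z|x) \,\|\, p_\theta(z))\right] \tag{21.53}$$
$$= \mathbb{E}_{q_{\mathcal{D},\phi}(x,z)}\left[\log p_\theta(x|z) + \log p_\theta(z) - \log q_\phi(z|x)\right] \tag{21.54}$$
$$= \mathbb{E}_{q_{\mathcal{D},\phi}(x,z)}\left[\log \frac{p_\theta(x, z)}{q_{\mathcal{D},\phi}(x, z)} + \log p_{\mathcal{D}}(x)\right] \tag{21.55}$$
$$= -D_{\mathbb{KL}}(q_{\mathcal{D},\phi}(x, z) \,\|\, p_\theta(x, z)) + \mathbb{E}_{p_{\mathcal{D}}(x)}\left[\log p_{\mathcal{D}}(x)\right] \tag{21.56}$$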


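If we define $\doteq$ to mean equal up to additive constants, we can rewrite the above as

$$\mathcal{L}_{\theta,\phi} \doteq -D_{\mathbb{KL}}(q_{\mathcal{D},\phi}(x, z) \,\|\, p_\theta(x, z)) \tag{21.57}$$
$$= -D_{\mathbb{KL}}(p_{\mathcal{D}}(x) \,\|\, p_\theta(x)) - \mathbb{E}_{p_{\mathcal{D}}(x)}\left[D_{\mathbb{KL}}(q_\phi(z|x) \,\|\, p_\theta(z|x))\right] \tag{21.58}$$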


Thus maximizing the ELBO requires minimizing the two KL terms. The first KL term is minimized
by MLE, and the second KL term is minimized by fitting the true posterior. Thus if the posterior
family is limited, there may be a conflict between these objectives.
[Figure 21.5 panel labels: X-space, with data distribution $q_{\mathcal{D}}(x)$ and marginal $p_\theta(x)$; Z-space, with prior distribution $p_\theta(z)$ and marginal $q_{\mathcal{D},\phi}(z)$; encoder $q_\phi(z|x)$ and decoder $p_\theta(x|z)$; joints $q_{\mathcal{D},\phi}(x,z) = q_{\mathcal{D}}(x)\, q_\phi(z|x)$ and $p_\theta(x,z) = p_\theta(z)\, p_\theta(x|z)$.]

Figure 21.5: The maximum likelihood (ML) objective can be viewed as the minimization of $D_{\mathbb{KL}}(p_{\mathcal{D}}(x) \,\|\, p_\theta(x))$.
(Note: in the figure, $p_{\mathcal{D}}(x)$ is denoted by $q_{\mathcal{D}}(x)$.) The ELBO objective is minimization of
$D_{\mathbb{KL}}(q_{\mathcal{D},\phi}(x, z) \,\|\, p_\theta(x, z))$, which upper bounds $D_{\mathbb{KL}}(q_{\mathcal{D}}(x) \,\|\, p_\theta(x))$. From Figure 2.4 of [KW19a]. Used
with kind permission of Durk Kingma.


Finally, we note that the ELBO can also be written as 


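$$\mathcal{L}_{\theta,\phi} \doteq -D_{\mathbb{KL}}(q_{\mathcal{D},\phi}(z) \,\|\, p_\theta(z)) - \mathbb{E}_{q_{\mathcal{D},\phi}(z)}\left[D_{\mathbb{KL}}(q_{\mathcal{D},\phi}(x|z) \,\|\, p_\theta(x|z))\right] \tag{21.59}$$


We see from Equation (21.59) that VAEs are trying to minimize the difference between the inference
marginal and generative prior, $D_{\mathbb{KL}}(q_\phi(z) \,\|\, p_\theta(z))$, while simultaneously minimizing reconstruction
error, $D_{\mathbb{KL}}(q_\phi(x|z) \,\|\, p_\theta(x|z))$. Since $x$ is typically of much higher dimensionality than $z$, the latter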
term usually dominates. Consequently, if there is a conflict between these two objectives (e.g., due to 
limited modeling power), the VAE will favor reconstruction accuracy over posterior inference. Thus 
the learned posterior may not be a very good approximation to the true posterior (see [ZSE19] for 
further discussion). 


21.3 VAE generalizations 
In this section, we discuss some variants of the basic VAE model. 
21.3.1 β-VAE


It is often the case that VAEs generate somewhat blurry images, as illustrated in Figure 21.4, 
Figure 21.3 and Figure 20.9. This is not the case for models that optimize the exact likelihood, such 
as pixelCNNs (Section 22.3.2) and flow models (Chapter 23). To see why VAEs are different, consider 
the common case where the decoder is a Gaussian with fixed variance, so 


$$\log p_\theta(x|z) = -\frac{1}{2\sigma^2}\|x - d_\theta(z)\|_2^2 + \text{const} \tag{21.60}$$


Let $e_\phi(x) = \mathbb{E}[q_\phi(z|x)]$ be the encoding of $x$, and $\mathcal{X}(z) = \{x : e_\phi(x) = z\}$ be the set of inputs that
get mapped to $z$. For a fixed inference network, the optimal setting of the generator parameters,
when using squared reconstruction loss, is to ensure $d_\theta(z) = \mathbb{E}[x : x \in \mathcal{X}(z)]$. Thus the decoder
should predict the average of all inputs $x$ that map to that $z$, resulting in blurry images.
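The averaging argument can be seen in a tiny example. Among all constant predictions $d$, the sum of squared errors $\sum_i \|x_i - d\|^2$ is minimized by the mean of the $x_i$. If two distinct inputs share a code $z$, the optimal decoder output is therefore a blend of the two. The two 4-pixel "images" below are hypothetical.

```python
import numpy as np

def best_constant_prediction(inputs):
    """Minimizer over d of sum_i ||x_i - d||^2: the mean of the inputs."""
    return np.mean(inputs, axis=0)

def squared_loss(inputs, d):
    return np.sum((inputs - d) ** 2)

# Two distinct, hypothetical "images" that an encoder maps to the same code z
inputs = np.array([[1.0, 0.0, 1.0, 0.0],
                   [0.0, 1.0, 0.0, 1.0]])
d = best_constant_prediction(inputs)   # a 50/50 blend of the two inputs
```

The optimal output `d` is uniformly 0.5: a gray image that matches neither input. This is the blurring effect described above.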

We can solve this problem by increasing the expressive power of the posterior approximation 
(avoiding the merging of distinct inputs into the same latent code), or of the generator (by adding 
back information that is missing from the latent code), or both. However, an even simpler solution is 
to reduce the penalty on the KL term, making the model closer to a deterministic autoencoder: 

$$\mathcal{L}_\beta(\theta, \phi|x) = \underbrace{-\mathbb{E}_{q_\phi(z|x)}\left[\log p_\theta(x|z)\right]}_{\mathcal{L}_E} + \beta\, \underbrace{D_{\mathbb{KL}}(q_\phi(z|x) \,\|\, p_\theta(z))}_{\mathcal{L}_R} \tag{21.61}$$

where $\mathcal{L}_E$ is the reconstruction error (negative log likelihood), and $\mathcal{L}_R$ is the KL regularizer. This is
called the β-VAE objective [Hig+17a]. If we set $\beta = 1$, we recover the objective used in standard
VAEs; if we set $\beta = 0$, we recover the objective used in standard autoencoders.
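A minimal single-sample sketch of (21.61) follows. It assumes a Gaussian decoder with a fixed, hypothetical standard deviation `obs_sigma`, where `x_recon` is $d_\theta(z)$ for one sample $z \sim q_\phi(z|x)$, and a diagonal Gaussian posterior. The KL uses the closed form (21.21). The three values are returned separately so that $\mathcal{L}_E$ and $\mathcal{L}_R$ can be monitored.

```python
import numpy as np

def beta_vae_loss(x, x_recon, mu, log_sigma, beta, obs_sigma=1.0):
    """Single-sample beta-VAE loss (21.61), returned with its two parts.

    Assumes p(x|z) = N(x | x_recon, obs_sigma^2 I), with x_recon = d_theta(z) for
    one sample z ~ q(z|x), and q(z|x) = N(mu, diag(exp(log_sigma)^2)).
    """
    D = x.shape[-1]
    # L_E: negative Gaussian log likelihood, cf. (21.60) with the constant kept
    rec = (0.5 * np.sum((x - x_recon) ** 2) / obs_sigma ** 2
           + 0.5 * D * np.log(2 * np.pi * obs_sigma ** 2))
    # L_R: closed-form KL to the N(0, I) prior, cf. (21.21)
    kl = -0.5 * np.sum(2 * log_sigma - np.exp(2 * log_sigma) - mu ** 2 + 1.0)
    return rec + beta * kl, rec, kl
```

With `beta=0.0` only the reconstruction term remains, as in an autoencoder. With `beta=1.0` the loss is the negative single-sample ELBO.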

By varying $\beta$ from 0 to infinity, we can reach different points on the rate distortion curve, as
discussed in Section 5.4.2. These points make different tradeoffs between reconstruction error (distortion)
and how much information is stored in the latents about the input (rate of the corresponding
code). By using $\beta < 1$, we store more bits about each input, and hence can reconstruct images in a
less blurry way. If we use $\beta > 1$, we get a more compressed representation.


21.3.1.1 Disentangled representations


One advantage of using $\beta > 1$ is that it encourages the learning of a latent representation that is
"disentangled". Intuitively this means that each latent dimension represents a different factor of
variation in the input. This is often formalized in terms of the total correlation (Section 5.3.5.1),
which is defined as follows:

$$\text{TC}(z) = \sum_k \mathbb{H}(z_k) - \mathbb{H}(z) = D_{\mathbb{KL}}\Big(p(z) \,\Big\|\, \prod_k p_k(z_k)\Big) \tag{21.62}$$

This is zero iff the components of $z$ are all mutually independent, and hence disentangled. In [AS18],
they prove that using $\beta > 1$ will decrease the TC.
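For a Gaussian $z \sim \mathcal{N}(\mu, \Sigma)$, (21.62) has a closed form. Substituting the Gaussian entropies, the $2\pi e$ factors cancel and leave $\frac{1}{2}\left(\sum_k \log \Sigma_{kk} - \log|\Sigma|\right)$. The sketch below (function name ours) computes this quantity. It is zero for a diagonal covariance and positive once the components are correlated.

```python
import numpy as np

def gaussian_total_correlation(cov):
    """TC(z) of (21.62) for z ~ N(mu, cov), in nats.

    Uses H(z_k) = 0.5 * log(2*pi*e*cov_kk) and H(z) = 0.5 * log det(2*pi*e*cov);
    the (2*pi*e) factors cancel, leaving 0.5 * (sum_k log cov_kk - log det cov).
    """
    _, logdet = np.linalg.slogdet(cov)
    return 0.5 * (np.sum(np.log(np.diag(cov))) - logdet)
```

For two unit-variance components with correlation $\rho$, this reduces to $-\frac{1}{2}\log(1 - \rho^2)$.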


Unfortunately, in [Loc+18] they prove that nonlinear latent variable models are unidentifiable, and
therefore for any disentangled representation, there is an equivalent fully entangled representation
with exactly the same likelihood. Thus it is not possible to recover the correct latent representation
without choosing the appropriate inductive bias, via the encoder, decoder, prior, dataset, or learning
algorithm, i.e., merely adjusting $\beta$ is not sufficient. See Section 32.4.1 for more discussion.
21.3.1.2 Connection with information bottleneck


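In this section, we show that the β-VAE is an unsupervised version of the information bottleneck
(IB) objective from Section 5.6. If the input is $x$, the hidden bottleneck is $z$, and the target outputs
are $\tilde{x}$, then the unsupervised IB objective becomes

$$\mathcal{L}_{\text{UIB}} = \beta\, \mathbb{I}(z; x) - \mathbb{I}(z; \tilde{x}) \tag{21.63}$$
$$= \beta\, \mathbb{E}_{p(x,z)}\left[\log \frac{p(x, z)}{p(x)\, p(z)}\right] - \mathbb{E}_{p(z,\tilde{x})}\left[\log \frac{p(z, \tilde{x})}{p(z)\, p(\tilde{x})}\right] \tag{21.64}$$

where

$$p(x, z) = p_{\mathcal{D}}(x)\, p(z|x) \tag{21.65}$$
$$p(z, \tilde{x}) = \int p_{\mathcal{D}}(x)\, p(z|x)\, p(\tilde{x}|z)\, dx \tag{21.66}$$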


Intuitively, the objective in Equation (21.63) means we should pick a representation $z$ that can
predict $\tilde{x}$ reliably, while not memorizing too much information about the input $x$. The tradeoff
parameter is controlled by $\beta$.

From Equation (5.180), we have the following variational upper bound on this unsupervised 
objective: 


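$$\mathcal{L}_{\text{UVIB}} = -\mathbb{E}_{q_{\mathcal{D},\phi}(z,x)}\left[\log p_\theta(x|z)\right] + \beta\, \mathbb{E}_{p_{\mathcal{D}}(x)}\left[D_{\mathbb{KL}}(q_\phi(z|x) \,\|\, p_\theta(z))\right] \tag{21.67}$$


which matches Equation (21.61) when averaged over $x$.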


21.3.2 InfoVAE 


In Section 21.2.7, we discussed some drawbacks of the standard ELBO objective for training VAEs, 
namely the tendency to ignore the latent code when the decoder is powerful (Section 21.4), and the 
tendency to learn a poor posterior approximation due to the mismatch between the KL terms in 
data space and latent space (Section 21.2.7). We can fix these problems to some degree by using a 
generalized objective of the following form: 


$$\mathcal{L}(\theta, \phi|x) = -\lambda D_{\mathbb{KL}}(q_\phi(z) \,\|\, p_\theta(z)) - \mathbb{E}_{q_\phi(z)}\left[D_{\mathbb{KL}}(q_\phi(x|z) \,\|\, p_\theta(x|z))\right] + \alpha\, \mathbb{I}_q(x; z) \tag{21.68}$$


where $\alpha > 0$ controls how much we weight the mutual information $\mathbb{I}_q(x; z)$ between $x$ and $z$, and
$\lambda > 0$ controls the tradeoff between $z$-space KL and $x$-space KL. This is called the InfoVAE objective
[ZSE19]. If we set $\alpha = 0$ and $\lambda = 1$, we recover the standard ELBO, as shown in Equation (21.59).

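Unfortunately, the objective in Equation (21.68) cannot be computed as written, because of the
intractable MI term:

$$\mathbb{I}_q(x; z) = \mathbb{E}_{q_\phi(x,z)}\left[\log \frac{q_\phi(x, z)}{q_\phi(x)\, q_\phi(z)}\right] = -\mathbb{E}_{q_\phi(x,z)}\left[\log \frac{q_\phi(z)}{q_\phi(z|x)}\right] \tag{21.69}$$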
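However, using the fact that $q_\phi(x|z) = p_{\mathcal{D}}(x)\, q_\phi(z|x) / q_\phi(z)$, we can rewrite the objective as follows:

$$\mathcal{L} = \mathbb{E}_{q_\phi(x,z)}\left[-\lambda \log \frac{q_\phi(z)}{p_\theta(z)} - \log \frac{q_\phi(x|z)}{p_\theta(x|z)} - \alpha \log \frac{q_\phi(z)}{q_\phi(z|x)}\right] \tag{21.70}$$
$$= \mathbb{E}_{q_\phi(x,z)}\left[\log p_\theta(x|z) - \log \frac{q_\phi(z)^{\lambda+\alpha-1}\, p_{\mathcal{D}}(x)}{p_\theta(z)^{\lambda}\, q_\phi(z|x)^{\alpha-1}}\right] \tag{21.71}$$
$$= \mathbb{E}_{p_{\mathcal{D}}(x)}\left[\mathbb{E}_{q_\phi(z|x)}\left[\log p_\theta(x|z)\right]\right] - (1-\alpha)\, \mathbb{E}_{p_{\mathcal{D}}(x)}\left[D_{\mathbb{KL}}(q_\phi(z|x) \,\|\, p_\theta(z))\right]$$
$$\quad - (\alpha+\lambda-1)\, D_{\mathbb{KL}}(q_\phi(z) \,\|\, p_\theta(z)) - \mathbb{E}_{p_{\mathcal{D}}(x)}\left[\log p_{\mathcal{D}}(x)\right] \tag{21.72}$$

where the last term is a constant we can ignore. The first two terms can be optimized using the
reparameterization trick. Unfortunately, the third term requires computing $q_\phi(z) = \int_x q_\phi(x, z)\, dx$,
which is intractable. Fortunately, we can easily sample from this distribution, by sampling $x \sim p_{\mathcal{D}}(x)$
and $z \sim q_\phi(z|x)$. Thus $q_\phi(z)$ is an implicit probability model, similar to a GAN (see Chapter 26).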

Suppose we replace the KL between $q_\phi(z)$ and $p_\theta(z)$ by a strict divergence $D$, meaning $D(q, p) = 0$ iff $q = p$.
Then one can show that this does not affect the optimality of the procedure. In particular, proposition 2 of [ZSE19] tells us the
following:


Theorem 1. Let $\mathcal{X}$ and $\mathcal{Z}$ be continuous spaces, and $\alpha < 1$ (to bound the MI) and $\lambda > 0$. For any
fixed value of $\mathbb{I}_q(x; z)$, the approximate InfoVAE loss, with any strict divergence $D(q_\phi(z), p_\theta(z))$, is
globally optimized if $p_\theta(x) = p_{\mathcal{D}}(x)$ and $q_\phi(z|x) = p_\theta(z|x)$.

21.3.2.1 Connection with MMD VAE 


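If we set $\alpha = 1$, the InfoVAE objective simplifies to

$$\mathcal{L} = \mathbb{E}_{p_{\mathcal{D}}(x)}\left[\mathbb{E}_{q_\phi(z|x)}\left[\log p_\theta(x|z)\right]\right] - \lambda D_{\mathbb{KL}}(q_\phi(z) \,\|\, p_\theta(z)) \tag{21.73}$$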


The MMD VAE³ replaces the KL divergence in the above term with the (squared) maximum mean
discrepancy or MMD divergence defined in Section 2.7.3. (This is valid based on the above theorem.)


The advantage of this approach over standard InfoVAE is that the resulting objective is tractable. In
particular, if we set $\lambda = 1$ and swap the sign we get

$$\mathcal{L} = \mathbb{E}_{p_{\mathcal{D}}(x)}\left[\mathbb{E}_{q_\phi(z|x)}\left[-\log p_\theta(x|z)\right]\right] + \text{MMD}(q_\phi(z), p_\theta(z)) \tag{21.74}$$


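As we discuss in Section 2.7.3, we can compute the MMD as follows:

$$\text{MMD}(p, q) = \mathbb{E}_{p(z), p(z')}\left[\mathcal{K}(z, z')\right] + \mathbb{E}_{q(z), q(z')}\left[\mathcal{K}(z, z')\right] - 2\, \mathbb{E}_{p(z), q(z')}\left[\mathcal{K}(z, z')\right] \tag{21.75}$$

where $\mathcal{K}$ is a kernel function, such as the RBF kernel $\mathcal{K}(z, z') = \exp\left(-\frac{1}{2\sigma^2}\|z - z'\|_2^2\right)$.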


In practice, we can implement the MMD objective by using the posterior predicted mean $z_n = e_\phi(x_n)$
for all $B$ samples in the current minibatch, and comparing this to $B$ random samples from
the $\mathcal{N}(0, I)$ prior.
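The minibatch recipe above can be sketched with an RBF kernel. This is a plug-in (biased, V-statistic) estimate of (21.75) that averages the full kernel matrices, diagonal included. An unbiased U-statistic, which drops the diagonal terms, is a common alternative. The encoded means here are simulated stand-ins, not outputs of a real encoder, and the bandwidth of 1 is an arbitrary choice.

```python
import numpy as np

def rbf_kernel(a, b, bandwidth=1.0):
    """K(z, z') = exp(-||z - z'||^2 / (2 * bandwidth^2)) for all pairs of rows."""
    sq_dists = np.sum(a ** 2, 1)[:, None] + np.sum(b ** 2, 1)[None, :] - 2 * a @ b.T
    return np.exp(-np.maximum(sq_dists, 0.0) / (2 * bandwidth ** 2))

def mmd2(z_q, z_p, bandwidth=1.0):
    """Plug-in (biased) estimate of the squared MMD in (21.75)."""
    k_qq = rbf_kernel(z_q, z_q, bandwidth).mean()
    k_pp = rbf_kernel(z_p, z_p, bandwidth).mean()
    k_qp = rbf_kernel(z_q, z_p, bandwidth).mean()
    return k_pp + k_qq - 2 * k_qp

rng = np.random.default_rng(0)
B, K = 256, 2
z_prior = rng.standard_normal((B, K))            # B samples from N(0, I)
z_matched = rng.standard_normal((B, K))          # stand-in for well-matched encoder means
z_shifted = rng.standard_normal((B, K)) + 3.0    # stand-in for badly-matched encoder means
```

In training, `z_q` would be the encoder means $e_\phi(x_n)$ for the current minibatch, and the MMD term would be added to the reconstruction loss as in (21.74).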


3. Proposed in https://ermongroup.github.io/blog/a-tutorial-on-mmd-variational-autoencoders/.
Figure 21.6: Illustration of multi-modal VAE. (a) The generative model with $N = 2$ modalities. (b) The
product of experts (PoE) inference network is derived from $N$ individual Gaussian experts $E_i$. $\mu_0$ and $\sigma_0$ are
parameters of the prior. (c) If a modality is missing, we omit its contribution to the posterior. From Figure 1
of [WG18]. Used with kind permission of Mike Wu.


If we use a Gaussian decoder with fixed variance, the negative log likelihood is just a squared error 
term: 


$$-\log p_\theta(x|z) = \|x - d_\theta(z)\|_2^2 \tag{21.76}$$


Thus the entire model is deterministic, and just predicts the means in latent space and visible space. 


21.3.2.2 Connection with β-VAEs

If we set $\alpha = 0$ and $\lambda = 1$, we get back the original ELBO. If $\lambda > 0$ is freely chosen, but we use
$\alpha = 1 - \lambda$, we get the β-VAE.
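These special cases can be checked by reading off the coefficients of the non-constant terms in (21.72). The tiny helper below is our own bookkeeping device, not part of [ZSE19]. It returns the weight on the reconstruction term, on the per-example KL, and on the aggregated-posterior KL.

```python
def infovae_weights(alpha, lam):
    """Weights of the three non-constant terms of (21.72).

    Returns (w_rec, w_kl_per_example, w_kl_aggregate) such that the objective is
    w_rec * E[log p(x|z)] - w_kl_per_example * E_x[KL(q(z|x) || p(z))]
                          - w_kl_aggregate * KL(q(z) || p(z)).
    """
    return 1.0, 1.0 - alpha, alpha + lam - 1.0

standard_elbo = infovae_weights(0.0, 1.0)      # (1, 1, 0): the usual ELBO
beta_vae = infovae_weights(1.0 - 4.0, 4.0)     # (1, 4, 0): beta-VAE with beta = 4
aggregate_only = infovae_weights(1.0, 1.0)     # (1, 0, 1): only KL(q(z) || p(z)) remains
```

With $\alpha = 1 - \lambda$, the aggregated-posterior term vanishes and the per-example KL is weighted by $\lambda$, which plays the role of $\beta$.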

21.3.2.3 Connection with adversarial autoencoders 


If we set $\alpha = 1$ and $\lambda = 1$, and $D$ is chosen to be the Jensen-Shannon divergence (which can be
minimized by training a binary discriminator, as explained in Section 26.2.2), then we get a model
known as an adversarial autoencoder [Mak+15a].


21.3.3 Multi-modal VAEs 


It is possible to extend VAEs to create joint distributions over different kinds of variables, such as 
images and text. This is sometimes called a multimodal VAE or MVAE. Let us assume there are 
M modalities. We assume they are conditionally independent given the latent code, and hence the 
generative model has the form 


$$p_\theta(x_1, \ldots, x_M, z) = p(z) \prod_{m=1}^{M} p_\theta(x_m|z) \tag{21.77}$$


where we treat p(z) as a fixed prior. See Figure 21.6(a) for an illustration. 
The standard ELBO is given by

$$\mathcal{L}_{\theta,\phi}(X) = \mathbb{E}_{q_\phi(z|X)}\left[\sum_m \log p_\theta(x_m|z)\right] - D_{\mathbb{KL}}(q_\phi(z|X) \,\|\, p(z)) \tag{21.78}$$

where $X = (x_1, \ldots, x_M)$ is the observed data. However, the different likelihood terms $p(x_m|z)$ may
have different dynamic ranges (e.g., Gaussian pdf for pixels, and categorical pmf for text), so we
introduce weight terms $\lambda_m > 0$ for each likelihood. In addition, let $\beta > 0$ control the amount of KL
regularization. This gives us a weighted version of the ELBO, as follows:

$$\mathcal{L}_{\theta,\phi}(X) = \mathbb{E}_{q_\phi(z|X)}\left[\sum_m \lambda_m \log p_\theta(x_m|z)\right] - \beta\, D_{\mathbb{KL}}(q_\phi(z|X) \,\|\, p(z)) \tag{21.79}$$


Often we don’t have a lot of paired (aligned) data from all M modalities. For example, we may 
have a lot of images (modality 1), and a lot of text (modality 2), but very few (image, text) pairs. 
So it is useful to generalize the loss so it fits the marginal distributions of subsets of the features. Let 
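$o_m = 1$ if modality $m$ is observed (i.e., $x_m$ is known), and let $o_m = 0$ if it is missing or unobserved.
Let $X = \{x_m : o_m = 1\}$ be the visible features. We now use the following objective:

$$\mathcal{L}_{\theta,\phi}(X) = \mathbb{E}_{q_\phi(z|X)}\left[\sum_{m: o_m = 1} \lambda_m \log p_\theta(x_m|z)\right] - \beta\, D_{\mathbb{KL}}(q_\phi(z|X) \,\|\, p(z)) \tag{21.80}$$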


The key problem is how to compute the posterior $q_\phi(z|X)$ given different subsets of features. In
general this can be hard, since the inference network is a discriminative model that assumes all
inputs are available. For example, if it is trained on (image, text) pairs, $q_\phi(z|x_1, x_2)$, how can we
compute the posterior just given an image, $q_\phi(z|x_1)$, or just given text, $q_\phi(z|x_2)$? (This issue arises
in general with VAEs when we have missing inputs; we discuss the general case in Section 21.3.4.)

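Fortunately, based on our conditional independence assumption between the modalities, we can
compute the optimal form for $q_\phi(z|X)$ given a set of inputs by computing the exact posterior under
the model, which is given by

$$p(z|X) = \frac{p(z)\, p(x_1, \ldots, x_M|z)}{p(x_1, \ldots, x_M)} = \frac{p(z)}{p(x_1, \ldots, x_M)} \prod_{m=1}^{M} p(x_m|z) \tag{21.81}$$
$$= \frac{p(z)}{p(x_1, \ldots, x_M)} \prod_{m=1}^{M} \frac{p(z|x_m)\, p(x_m)}{p(z)} \tag{21.82}$$
$$\propto p(z) \prod_{m=1}^{M} \frac{p(z|x_m)}{p(z)} \approx p(z) \prod_{m=1}^{M} \tilde{q}(z|x_m) \tag{21.83}$$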


This can be viewed as a product of experts (Section 24.1.1), where each $\tilde{q}(z|x_m)$ is an "expert" for
the $m$'th modality, and $p(z)$ is the prior. We can compute the above posterior for any subset of
modalities for which we have data by modifying the product over $m$. If we use Gaussian distributions
for the prior $p(z) = \mathcal{N}(z|\mu_0, \Lambda_0^{-1})$ and marginal posterior ratio $\tilde{q}(z|x_m) = \mathcal{N}(z|\mu_m, \Lambda_m^{-1})$, then we
can compute the product of Gaussians using the result from Equation (2.115):

$$\prod_{m=0}^{M} \mathcal{N}(z|\mu_m, \Lambda_m^{-1}) \propto \mathcal{N}(z|\mu, \Sigma), \quad \Sigma = \Big(\sum_m \Lambda_m\Big)^{-1}, \quad \mu = \Sigma \Big(\sum_m \Lambda_m \mu_m\Big) \tag{21.84}$$

Thus the overall posterior precision is the sum of individual expert posterior precisions, and the
overall posterior mean is the precision weighted average of the individual expert posterior means.
See Figure 21.6(b) for an illustration. For a linear Gaussian (factor analysis) model, we can ensure
$q(z|x_m) = p(z|x_m)$, in which case the above solution is the exact posterior [WN18], but in general it
will be an approximation.

[Figure 21.7 panels: (a) Missingness as latents; (b) MCAR corruption process; (c) MNAR corruption process.]

Figure 21.7: Illustration of different VAE variants for handling missing data. From Figure 1 of [CNW20].
Used with kind permission of Mark Collier.
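The precision-weighted combination in (21.84) can be sketched as follows. The expert means and precisions are hypothetical inputs, standing in for what the per-modality recognition networks would output. A missing modality is handled, as in Figure 21.6(c), by leaving its expert out of the lists.

```python
import numpy as np

def product_of_gaussian_experts(mus, precisions):
    """Normalized product of Gaussian experts N(z | mu_m, inv(Lambda_m)), as in (21.84).

    mus: list of (K,) means; precisions: list of (K, K) precision matrices.
    Index 0 is taken to be the prior expert; a missing modality is handled by
    leaving its expert out of both lists.
    """
    total_precision = sum(precisions)                          # sum_m Lambda_m
    cov = np.linalg.inv(total_precision)                       # Sigma
    mean = cov @ sum(P @ m for P, m in zip(precisions, mus))   # Sigma sum_m Lambda_m mu_m
    return mean, cov

prior = (np.zeros(1), np.eye(1))
image_expert = (np.array([2.0]), 3.0 * np.eye(1))   # hypothetical expert output
mean, cov = product_of_gaussian_experts([prior[0], image_expert[0]],
                                        [prior[1], image_expert[1]])
```

In this 1-D example the precisions 1 and 3 add to 4, so the variance is 0.25. The mean is $0.25 \times (1 \cdot 0 + 3 \cdot 2) = 1.5$, pulled toward the more confident expert.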

We need to train the individual expert recognition models $q(z|x_m)$ as well as the joint model
$q(z|X)$, so the model knows what to do with fully observed as well as partially observed inputs at
test time. In [Ved+18], they propose a somewhat complex "triple ELBO" objective. In [WG18], they
propose the simpler approach of optimizing the ELBO for the fully observed feature vector, all the
marginals, and a set of $J$ randomly chosen joint modalities:

$$\mathcal{L}_{\theta,\phi}(X) = \mathcal{L}_{\theta,\phi}(x_1, \ldots, x_M) + \sum_{m=1}^{M} \mathcal{L}_{\theta,\phi}(x_m) + \sum_{j \in \mathcal{J}} \mathcal{L}_{\theta,\phi}(X_j) \tag{21.85}$$

This generalizes nicely to the semi-supervised setting, in which we only have a few aligned 
(“labeled”) examples from the joint, but have many unaligned (“unlabeled”) examples from the 
individual marginals. See Figure 21.6(c) for an illustration. 

Note that the above scheme can only handle the case of a fixed number of missingness patterns;
we generalize to allow for arbitrary missingness in Section 21.3.4.


21.3.4 VAEs with missing data 


Sometimes we may have missing data, in which parts of the data vector $x \in \mathbb{R}^D$ may be unknown.
In Section 21.3.3 we saw a special case of this when we discussed multimodal VAEs. In this section
we allow for arbitrary patterns of missingness.

To model the missing data, let $m \in \{0,1\}^D$ be a binary vector where $m_j = 1$ if $x_j$ is missing, and
$m_j = 0$ otherwise. Let $X = \{x^{(n)}\}$ and $M = \{m^{(n)}\}$ be $N \times D$ matrices. Furthermore, let $X_o$ be
the observed parts of $X$ and $X_h$ be the hidden parts. If we assume $p(M|X_o, X_h) = p(M)$, we say the
data is missing completely at random or MCAR, since the missingness does not depend on the
hidden or observed features. If we assume $p(M|X_o, X_h) = p(M|X_o)$, we say the data is missing at
random or MAR, since the missingness does not depend on the hidden features, but may depend
on the visible features. If neither of these assumptions hold, we say the data is not missing at
random or NMAR.

In the MCAR and MAR cases, we can ignore the missingness mechanism, since it tells us nothing 
about the hidden features. However, in the NMAR case, we need to model the missing data 
mechanism, since the lack of information may be informative. For example, the fact that someone 
did not fill out an answer to a sensitive question on a survey (e.g., “Do you have COVID?”) could be 
informative about the underlying value. See e.g., [LR87; Mar08] for more information on missing 
data models. 

In the context of VAEs, we can model the MCAR scenario by treating the missing values as latent 
variables. This is illustrated in Figure 21.7(a). Since missing leaf nodes in a directed graphical model 
do not affect their parents, we can simply ignore them when computing the posterior $p(z|x_o^{(i)})$,
where $x_o^{(i)}$ are the observed parts of example $i$. However, when using an amortized inference network,
it can be difficult to handle missing inputs, since the model is usually trained to compute $p(z|x^{(i)})$.
One solution to this is to use the product of experts approach discussed in the context of multi-modal
VAEs in Section 21.3.3. However, this is designed for the case where whole blocks (corresponding to
different modalities) are missing, and will not work well if there are arbitrary missing patterns (e.g.,
pixels that get dropped out due to occlusion or scratches on the lens). In addition, this method will
not work for the NMAR case.


An alternative approach, proposed in [CNW20], is to explicitly include the missingness indicators
into the model, as shown in Figure 21.7(b). We assume the model always generates each $x_j$ for
$j = 1:D$, but we only get to see the "corrupted" versions $\tilde{x}_j$. If $m_j = 0$ then $\tilde{x}_j = x_j$, but if $m_j = 1$,
then $\tilde{x}_j$ is a special value, such as 0, unrelated to $x_j$. We can model any correlation between the
missingness elements (components of $m$) by using another latent variable $z_m$. This model can easily
be extended to the NMAR case by letting $m$ depend on the latent factors for the observed data, $z$,
as well as the usual missingness latent factors $z_m$, as shown in Figure 21.7(c).
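The MCAR corruption process just described can be simulated as follows. Here the mask is drawn independently of $x$ with a fixed, hypothetical per-entry probability `missing_prob`, which is the simplest MCAR mechanism, and missing entries are replaced by the special value 0. A model with correlated missingness would instead sample $m$ from its own latent variable $z_m$.

```python
import numpy as np

def mcar_corrupt(x, missing_prob, rng, fill_value=0.0):
    """Sample an MCAR mask m (1 = missing) independently of x, and build x_tilde.

    Each entry is missing with probability missing_prob, so p(m | x) = p(m).
    """
    m = (rng.random(x.shape) < missing_prob).astype(int)
    x_tilde = np.where(m == 1, fill_value, x)   # x_tilde_j = x_j iff m_j = 0
    return x_tilde, m

rng = np.random.default_rng(0)
x = rng.normal(size=(1000, 8))
x_tilde, m = mcar_corrupt(x, 0.3, rng)
```

The encoder then receives both `x_tilde` and `m`, so it can tell a genuine 0 apart from a filled-in placeholder.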


We modify the VAE to be conditional on the missingness pattern, so the VAE decoder has the
form $p(x_o|z, m)$, and the encoder has the form $q(z|\tilde{x}, m)$. However, we assume the prior is $p(z)$
as usual, independent of $m$. We can compute a lower bound on the log marginal likelihood of the
observed data, given the missingness, as follows:

$$\log p(x_o|m) = \log \iint p(x_o, x_h|z, m)\, p(z)\, dx_h\, dz \tag{21.86}$$
$$= \log \int p(x_o|z, m)\, p(z)\, dz \tag{21.87}$$
$$= \log \int p(x_o|z, m)\, p(z)\, \frac{q(z|\tilde{x}, m)}{q(z|\tilde{x}, m)}\, dz \tag{21.88}$$
$$= \log \mathbb{E}_{q(z|\tilde{x}, m)}\left[p(x_o|z, m)\, \frac{p(z)}{q(z|\tilde{x}, m)}\right] \tag{21.89}$$
$$\geq \mathbb{E}_{q(z|\tilde{x}, m)}\left[\log p(x_o|z, m)\right] - D_{\mathbb{KL}}(q(z|\tilde{x}, m) \,\|\, p(z)) \tag{21.90}$$

We can fit this model in the usual way. See Figure 21.8 for an example.

[Figure 21.8 panel labels: Original ($x$); Mask ($m$); Corrupted ($\tilde{x}$); VAE.]

Figure 21.8: Imputing missing pixels given a masked out image using a VAE with an MCAR assumption.
From Figure 2 of [CNW20]. Used with kind permission of Mark Collier.


21.3.5 Semi-supervised VAEs 


In this section, we discuss how to extend VAEs to the semi-supervised learning setting in which
we have both labeled data, $\mathcal{D}_L = \{(x_n, y_n)\}$, and unlabeled data, $\mathcal{D}_U = \{(x_n)\}$. We focus on the
M2 model, proposed in [Kin+14a].


The generative model has the following form:

$$p_\theta(x, y) = p_\theta(y)\, p_\theta(x|y) = p_\theta(y) \int p_\theta(x|y, z)\, p_\theta(z)\, dz \tag{21.91}$$

where $z$ is a latent variable, $p_\theta(z) = \mathcal{N}(z|0, I)$ is the latent prior, $p_\theta(y) = \text{Cat}(y|\pi)$ the label prior,
and $p_\theta(x|y, z) = p(x|f_\theta(y, z))$ is the likelihood, such as a Gaussian, with parameters computed by $f$
(a deep neural network). The main innovation of this approach is to assume that data is generated
according to both a latent class variable $y$ as well as the continuous latent variable $z$. The class
variable $y$ is observed for labeled data and unobserved for unlabeled data.
To compute the likelihood for the labeled data, $p_\theta(x, y)$, we need to marginalize over $z$, which we
can do by using an inference network of the form

$$q_\phi(z|y, x) = \mathcal{N}(z|\mu_\phi(y, x), \text{diag}(\sigma_\phi^2(y, x))) \tag{21.92}$$

We then use the following variational lower bound

$$\log p_\theta(x, y) \geq \mathbb{E}_{q_\phi(z|x,y)}\left[\log p_\theta(x|y, z) + \log p_\theta(y) + \log p_\theta(z) - \log q_\phi(z|x, y)\right] = -\mathcal{L}(x, y) \tag{21.93}$$

as is standard for VAEs (see Section 21.2). The only difference is that we observe two kinds of data:
$x$ and $y$.

To compute the likelihood for the unlabeled data, $p_\theta(x)$, we need to marginalize over $z$ and $y$,
which we can do by using an inference network of the form

$$q_\phi(z, y|x) = q_\phi(z|x)\, q_\phi(y|x) \tag{21.94}$$
$$q_\phi(z|x) = \mathcal{N}(z|\mu_\phi(x), \text{diag}(\sigma_\phi^2(x))) \tag{21.95}$$
$$q_\phi(y|x) = \text{Cat}(y|\pi_\phi(x)) \tag{21.96}$$

Note that $q_\phi(y|x)$ acts like a discriminative classifier that imputes the missing labels. We then use
the following variational lower bound:

$$\log p_\theta(x) \geq \mathbb{E}_{q_\phi(z,y|x)}\left[\log p_\theta(x|y, z) + \log p_\theta(y) + \log p_\theta(z) - \log q_\phi(z, y|x)\right] \tag{21.97}$$
$$= -\sum_y q_\phi(y|x)\, \mathcal{L}(x, y) + \mathbb{H}(q_\phi(y|x)) = -\mathcal{U}(x) \tag{21.98}$$


Note that the discriminative classifier $q_\phi(y|x)$ is only used to compute the log-likelihood of the
unlabeled data, which is undesirable. We can therefore add an extra classification loss on the
supervised data, to get the following overall objective function:


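$$ \mathcal{L}(\theta) = \mathbb{E}_{(x,y)\sim\mathcal{D}_L}[\mathcal{L}(x, y)] + \mathbb{E}_{x\sim\mathcal{D}_U}[\mathcal{U}(x)] + \alpha\, \mathbb{E}_{(x,y)\sim\mathcal{D}_L}[-\log q_\phi(y|x)] $$ (21.99)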


where $\mathcal{D}_L$ is the labeled data, $\mathcal{D}_U$ is the unlabeled data, and $\alpha$ is a hyperparameter that controls the
relative weight of generative and discriminative learning.
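To make the bookkeeping in Eqs. (21.98) and (21.99) concrete, here is a minimal pure-Python sketch. It assumes the per-class labeled losses L(x, y) and the classifier probabilities q_φ(y|x) have already been computed by networks that are not shown; the function names `unlabeled_loss` and `m2_objective` are hypothetical, and minibatch averages stand in for the expectations.

```python
import math


def unlabeled_loss(class_probs, per_class_losses):
    """U(x) = sum_y q(y|x) L(x, y) - H(q(y|x)), i.e. the negative of the bound in Eq. (21.98).

    class_probs: q_phi(y|x) for y = 1..C (assumed to sum to 1).
    per_class_losses: L(x, y), the negative labeled bound of Eq. (21.93), for each y.
    """
    expected_loss = sum(q * l for q, l in zip(class_probs, per_class_losses))
    entropy = -sum(q * math.log(q) for q in class_probs if q > 0)
    return expected_loss - entropy


def m2_objective(labeled_losses, observed_label_probs, unlabeled_losses, alpha):
    """Minibatch estimate of Eq. (21.99), to be minimized.

    labeled_losses: L(x_n, y_n) for each labeled example.
    observed_label_probs: q_phi(y_n|x_n) for the observed label of each labeled example.
    unlabeled_losses: U(x_m) for each unlabeled example.
    alpha: weight of the extra classification loss.
    """
    gen_labeled = sum(labeled_losses) / len(labeled_losses)
    gen_unlabeled = sum(unlabeled_losses) / len(unlabeled_losses)
    classification = sum(-math.log(p) for p in observed_label_probs) / len(observed_label_probs)
    return gen_labeled + gen_unlabeled + alpha * classification
```

In a real implementation these quantities would be arrays produced by the encoder, decoder and classifier, and the objective would be differentiated with respect to θ and φ; the sketch only shows how the three terms combine.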


21.3.6 VAEs with sequential encoders/decoders


In this section, we discuss VAEs for sequential data, such as text and biosequences, in which the
data $x$ is a variable-length sequence, but we have a fixed-sized latent variable $z \in \mathbb{R}^K$. (We consider
the more general case in which $z$ is a variable-length sequence of latents, known as a sequential
VAE or dynamic VAE, in Section 29.13.) All we have to do is modify the decoder $p(x|z)$ and
encoder $q(z|x)$ to work with sequences.


21.3.6.1 Models


If we use an RNN for the encoder and decoder of a VAE, we get a model which is called a VAE-RNN,
as proposed in [Bow+16a]. In more detail, the generative model is $p(z, x_{1:T}) = p(z)\,\mathrm{RNN}(x_{1:T}|z)$,
where $z$ can be injected as the initial state of the RNN, or as an input to every time step. The
[Figure 21.9 panel labels: Normal(0, 1) sample; output sequence X′ conditionally generated from z]
Figure 21.9: Illustration of a VAE with a bidirectional RNN encoder and a unidirectional RNN decoder. The
output generator can use a GMM and/or softmax distribution. From Figure 2 of [HE18]. Used with kind
permission of David Ha.


inference model is $q(z|x_{1:T}) = \mathcal{N}(z|\mu(h), \Sigma(h))$, where $h = [h_T^{\rightarrow}, h_1^{\leftarrow}]$ is the output of a bidirectional
RNN applied to $x_{1:T}$. See Figure 21.9 for an illustration.

More recently, people have tried to combine transformers with VAEs. For example, in the Optimus
model of [Li+20], they use a BERT model for the encoder. In more detail, the encoder $q(z|x)$ is
derived from the embedding vector associated with a dummy token corresponding to the “class label”,
which is added to the input sequence $x$. The decoder is a standard autoregressive model (similar
to GPT), with one additional input, namely the latent vector $z$. They consider two ways of injecting
the latent vector. The simplest approach is to add $z$ to the embedding layer of every token in the
decoding step, by defining $h_i' = h_i + W z$, where $h_i \in \mathbb{R}^H$ is the original embedding for the $i$th
token, and $W \in \mathbb{R}^{H \times K}$ is a decoding matrix, where $K$ is the size of the latent vector. However, they
get better results in their experiments by letting all the layers of the decoder attend to the latent
code $z$. An easy way to do this is to define the memory vector $h_m = W z$, where $W \in \mathbb{R}^{LH \times K}$,
where $L$ is the number of layers in the decoder, and then to append $h_m \in \mathbb{R}^{L \times H}$ (one $H$-dimensional
vector per layer) to all the other embeddings at each layer.
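The two injection schemes can be sketched in pure Python as follows. This is an illustration under stated assumptions, not the Optimus code: matrices are lists of rows, `W` has shape H × K for embedding injection and `W_mem` has shape (L·H) × K for memory injection, and the helper names are hypothetical.

```python
def matvec(W, z):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * zk for w, zk in zip(row, z)) for row in W]


def inject_embedding(token_embeddings, W, z):
    """Embedding injection: h_i' = h_i + W z for every token i (W is H x K)."""
    wz = matvec(W, z)
    return [[h + d for h, d in zip(h_i, wz)] for h_i in token_embeddings]


def memory_vectors(W_mem, z, num_layers, hidden):
    """Memory injection: h_m = W z with W of shape (L*H) x K, reshaped to L x H.

    Row l of the result is the extra 'memory' embedding appended at decoder layer l.
    """
    hm = matvec(W_mem, z)
    assert len(hm) == num_layers * hidden
    return [hm[l * hidden:(l + 1) * hidden] for l in range(num_layers)]
```

The design difference is that embedding injection adds the same vector to every token, whereas the memory variant produces one extra H-dimensional vector per decoder layer, which the self-attention at that layer can attend to like an additional position.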

An alternative approach, known as transformer VAE, was proposed in [Gre20]. This model uses 
a funnel transformer [Dai+20b] as the encoder, and the T5 [Raf+20a] conditional transformer for 
the decoder. In addition, it uses an MMD VAE (Section 21.3.2.1) to avoid posterior collapse. 


21.3.6.2 Applications 


In this section, we discuss some applications of VAEs to sequence data. 


Text 


In [Bow+16b], they apply the VAE-RNN model to natural language sentences. (See also [MB16;
SSB17] for related work.) Although this does not improve performance in terms of the standard
perplexity measures (predicting the next word given the previous words), it does provide a way to
infer a semantic representation of the sentence. This can then be used for latent space interpolation,
as discussed in Section 20.3.5. The results of doing this with the VAE-RNN are illustrated in
(a) VAE:
he was silent for a long moment .
he was silent for a moment .
it was quiet for a moment .
it was dark and cold .
there was a pause .
it was my turn .

(b) Deterministic autoencoder:
i went to the store to buy some groceries .
i store to buy some groceries .
i were to buy any groceries .
horses are to buy any groceries .
horses are to buy any animal .
horses the favorite any animal .
horses the favorite favorite animal .
horses are my favorite animal .


Figure 21.10: (a) Samples from the latent space of a VAE text model, as we interpolate between two sentences 
(on first and last line). Note that the intermediate sentences are grammatical, and semantically related to 
their neighbors. From Table 8 of [Bow+16b]. (b) Same as (a), but now using a deterministic autoencoder 
(with the same RNN encoder and decoder). From Table 1 of [Bow+16b]. Used with kind permission of Sam 
Bowman. 


Figure 21.10a. (Similar results are shown in [Li+20], using a VAE-transformer.) By contrast, if 
we use a standard deterministic autoencoder, with the same RNN encoder and decoder networks, 
we learn a much less meaningful space, as illustrated in Figure 21.10b. The reason is that the 
deterministic autoencoder has “holes” in its latent space, which get decoded to nonsensical outputs. 

However, because RNNs (and transformers) are powerful decoders, we need to address the problem 
of posterior collapse, which we discuss in Section 21.4. One common way to avoid this problem is to 
use KL annealing, but a more effective method is to use the InfoVAE method of Section 21.3.2, which 
includes adversarial autoencoders (used in [She+20] with an RNN decoder) and MMD autoencoders 
(used in [Gre20] with a transformer decoder). 


Sketches 


In [HE18], they apply the VAE-RNN model to generate sketches (line drawings) of various animals
and hand-written characters. They call their model sketch-rnn. The training data records the
sequence of (x, y) pen positions, as well as whether the pen was touching the paper or not. The
emission model used a GMM for the real-valued location offsets, and a categorical softmax distribution
for the discrete state.


Figure 21.11 shows some samples from various class-conditional models. We vary the temperature
parameter $\tau$ of the emission model to control the stochasticity of the generator. (More precisely, we
multiply the GMM variances by $\tau$, and divide the logits (unnormalized log probabilities) of the discrete
distribution by $\tau$ before renormalizing.)
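A minimal sketch of temperature scaling, assuming the discrete distribution is parameterized by logits (unnormalized log probabilities) and that τ divides these logits before the softmax; dividing already-normalized probabilities by a constant would leave them unchanged after renormalizing. The function name is hypothetical.

```python
import math


def apply_temperature(logits, variances, tau):
    """Temperature-adjust sketch-rnn style emission parameters (a sketch).

    logits: unnormalized log probabilities of a categorical distribution
            (e.g., the pen-state softmax); variances: GMM variances.
    Returns (probs, scaled_variances).
    """
    scaled = [l / tau for l in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    return probs, [v * tau for v in variances]
```

Small τ sharpens the categorical distribution and shrinks the GMM variances, so samples stay close to the mode; τ = 1 recovers the trained model.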


When the temperature is low, the model tries to reconstruct the input as closely as possible. However,
when the input is atypical of the training set (e.g., a cat with three eyes, or a toothbrush), the
reconstruction is “regularized” towards a canonical cat with two eyes, while still keeping some features
of the input.


Molecular design


In [GB+18], they use VAE-RNNs to model molecular graph structure, represented as a string using
the SMILES representation.⁴ It is also possible to learn a mapping from the latent space to some

4. See https://en.wikipedia.org/wiki/Simplified_molecular-input_line-entry_system.
[Figure 21.11 panels: Human input; Reconstructions]

Figure 21.11: Conditional generation of cats from sketch-RNN model. We increase the temperature parameter 
from left to right. From Figure 5 of [HE18]. Used with kind permission of David Ha. 


scalar quantity of interest, such as the solubility or drug efficacy of a molecule. We can then perform 
gradient-based optimization in the continuous latent space to try to generate new graphs which 
maximize this quantity. See Figure 21.12 for a sketch of this approach. 
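The optimization step itself is ordinary gradient ascent on the learned predictor. The sketch below assumes a hypothetical, hand-written predictor f(z) = −‖z − z*‖² with a known gradient, standing in for the MLP of Figure 21.12; decoding the optimized latent back to a SMILES string is left out.

```python
def latent_gradient_ascent(z0, grad_f, step_size=0.1, num_steps=100):
    """Plain gradient ascent on a property predictor f(z) in latent space."""
    z = list(z0)
    for _ in range(num_steps):
        g = grad_f(z)
        z = [zi + step_size * gi for zi, gi in zip(z, g)]
    return z


# Hypothetical stand-in for a fitted property predictor:
# f(z) = -||z - z_star||^2, which is maximized at z_star.
z_star = [1.0, -2.0]


def grad_f(z):
    return [-2.0 * (zi - si) for zi, si in zip(z, z_star)]


z_opt = latent_gradient_ascent([0.0, 0.0], grad_f)
# In the setting of [GB+18], z_opt would then be passed to the VAE-RNN decoder.
```

With a real predictor the landscape is non-convex and the optimized latent may leave the region where the decoder produces valid molecules, which is the problem discussed next.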

The main problem is to ensure that points in latent space decode to valid strings/molecules. There
are various solutions to this, including using a grammar VAE, where the RNN decoder is replaced
by a stochastic context-free grammar. See [KPHL17] for details.


21.4 Avoiding posterior collapse 


If the decoder $p_\theta(x|z)$ is sufficiently powerful (e.g., a PixelCNN, or an RNN for text), then the VAE
does not need to use the latent code $z$ for anything. This is called posterior collapse or variational
overpruning (see e.g., [Che+17b; Ale+18; Hus17a; Phu+18; TT17; Yeu+17; Luc+19; DWW19;
WBC21]). To see why this happens, consider Equation (21.58). If there exists a parameter setting for
the generator $\theta^*$ such that $p_{\theta^*}(x|z) = p_D(x)$ for every $z$, then we can make $D_{\mathrm{KL}}(p_D(x) \,\|\, p_\theta(x)) = 0$.
Since the generator is independent of the latent code, we have $p_\theta(z|x) = p_\theta(z)$. The prior $p_\theta(z)$ is
usually a simple distribution, such as a Gaussian, so we can find a setting of the inference parameters
so that $q_{\phi^*}(z|x) = p_\theta(z)$, which ensures $D_{\mathrm{KL}}(q_\phi(z|x) \,\|\, p_\theta(z|x)) = 0$. Thus we have successfully
maximized the ELBO, but we have not learned any useful latent representation of the data, which is
one of the goals of latent variable modeling.⁵ We discuss some solutions to posterior collapse below.


5. Note that [Luc+19; DWW20] show that posterior collapse can also happen in linear VAE models, where the ELBO
corresponds to the exact marginal likelihood, so the problem is not only due to powerful (nonlinear) decoders, but is 
also related to spurious local maxima in the objective. 
Figure 21.12: Application of VAE-RNN to molecule design. (a) The VAE-RNN model is trained on a sequence
representation of molecules known as SMILES. We can fit an MLP to map from the latent space to properties
of the molecule, such as its “fitness” f(z). (b) We can perform gradient ascent on f(z) in the latent space, and then decode
the result to a new molecule with high fitness. From Figure 1 of [GB+18]. Used with kind permission of
Rafael Gomez-Bombarelli.


21.4.1 KL annealing 


A common approach to solving this problem, proposed in [Bow+16a], is to use KL annealing, in
which the KL penalty term in the ELBO is scaled by $\beta$, which is increased from 0.0 (corresponding
to an autoencoder) to 1.0 (which corresponds to standard MLE training). (Note that, by contrast,
the $\beta$-VAE model in Section 21.3.1 uses $\beta > 1$.)


KL annealing can work well, but requires tuning the schedule for $\beta$. A standard practice [Fu+19]
is to use cyclical annealing, which repeats the process of increasing $\beta$ multiple times. This encourages
the progressive learning of more meaningful latent codes, by leveraging good representations learned
in a previous cycle as a way to warm-start the optimization.
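A minimal sketch of a cyclical schedule, assuming the variant in which each cycle ramps β linearly from 0 to 1 over the first fraction `ramp` of the cycle and then holds β = 1 for the rest; `num_cycles` and `ramp` are illustrative defaults, not tuned values. Setting `num_cycles=1` gives ordinary monotonic KL annealing.

```python
def cyclical_beta(step, total_steps, num_cycles=4, ramp=0.5):
    """KL weight beta at a given training step under a cyclical schedule (a sketch).

    The run is split into num_cycles equal cycles; within each cycle, beta rises
    linearly from 0 to 1 over the first `ramp` fraction, then stays at 1.
    """
    cycle_len = total_steps / num_cycles
    pos = (step % cycle_len) / cycle_len  # position within the current cycle, in [0, 1)
    return min(1.0, pos / ramp)


# The ELBO loss at each step would then be: reconstruction + cyclical_beta(step, T) * KL.
```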


21.4.2 Lower bounding the rate


An alternative approach is to stick with the original unmodified ELBO objective, but to prevent the
rate (i.e., the $D_{\mathrm{KL}}(q \,\|\, p)$ term) from collapsing to 0, by limiting the flexibility of $q$. For example,
[XD18; Dav+18] use a von Mises-Fisher (Section 2.2.8.3) prior and posterior, instead of a Gaussian,
and they constrain the posterior to have a fixed concentration, $q(z|x) = \mathrm{vMF}(z|\mu(x), \kappa)$. Here
the parameter $\kappa$ controls the rate of the code. The $\delta$-VAE method [Oor+19] uses a Gaussian
autoregressive prior and a diagonal Gaussian posterior. We can ensure the rate is at least $\delta$ by
adjusting the regression parameter of the AR prior.
[Figure 21.13 panels: VAE; Skip-VAE]
Figure 21.13: (a) VAE. (b) Skip-VAE. From Figure 1 of [Die+19a]. Used with kind permission of Adji Dieng.


21.4.3 Free bits 


In this section, we discuss the method of free bits [Kin+16], which is another way of lower bounding 
the rate. To explain this, consider a fully factorized posterior in which the KL penalty has the form 


$$ \mathcal{L}_R = \sum_i D_{\mathrm{KL}}(q_\phi(z_i|x) \,\|\, p_\theta(z_i)) $$ (21.100)


where $z_i$ is the $i$th dimension of $z$. We can replace this with a hinge loss that will give up driving
down the KL for dimensions that are already beneath a target compression rate $\lambda$:


$$ \mathcal{L}'_R = \sum_i \max(\lambda, D_{\mathrm{KL}}(q_\phi(z_i|x) \,\|\, p_\theta(z_i))) $$ (21.101)


Thus the bits where the KL is sufficiently small “are free”, since the model does not have to “pay” to 
encode them according to the prior. 
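Eq. (21.101) can be sketched in a few lines of Python, assuming the per-dimension KL values for one example are already available. The helper `free_bits_grad_mask` (a hypothetical name) makes explicit which dimensions still receive a gradient from the rate term: where λ wins the max, the term is constant.

```python
def free_bits_penalty(kl_per_dim, lam):
    """Free-bits rate term of Eq. (21.101): sum_i max(lam, KL_i)."""
    return sum(max(lam, kl) for kl in kl_per_dim)


def free_bits_grad_mask(kl_per_dim, lam):
    """1.0 where dimension i still gets a KL gradient, 0.0 where its bits are 'free'."""
    return [1.0 if kl > lam else 0.0 for kl in kl_per_dim]
```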


21.4.4 Adding skip connections 


One reason for latent variable collapse is that the latent variables z are not sufficiently “connected to” 
the observed data x. One simple solution is to modify the architecture of the generative model by 
adding skip connections, similar to a residual network (Section 16.2.4), as shown in Figure 21.13. 
This is called a skip-VAE [Die+19a].


21.4.5 Improved variational inference 


The posterior collapse problem is caused in part by the poor approximation to the posterior. In 
[He+19], they proposed to keep the model and VAE objective unchanged, but to more aggressively 
update the inference network before each step of generative model fitting. This enables the inference 
network to capture the current true posterior more faithfully, which will encourage the generator to 
use the latent codes when it is useful to do so. 
Figure 21.14: Hierarchical VAEs with 3 stochastic layers. Left: Generative model. Right: Inference network.
Diamond is a residual network, ⊕ is feature combination (e.g., concatenation), and h is a trainable parameter.
We first do bottom-up inference, by propagating $x$ up to $z_3$ to compute $z_3 \sim q_\phi(z_3|x)$, and then we perform
top-down inference by computing $z_2 \sim q_\phi(z_2|x, z_3)$ and then $z_1 \sim q_\phi(z_1|x, z_{2:3})$. From Figure 2 of [VK20a].
Used with kind permission of Arash Vahdat.


However, this only addresses the part of posterior collapse that is due to the amortization gap
[CLD18], rather than the more fundamental problem of variational pruning, in which the KL term
penalizes the model if its posterior deviates too far from the prior, which is often too simple to match
the aggregated posterior.


Another way to ameliorate variational pruning is to use lower bounds that are tighter than the
vanilla ELBO (Section 10.5.1), or more accurate posterior approximations (Section 10.4), or more
accurate (hierarchical) generative models (Section 21.5).


21.4.6 Alternative objectives

An alternative to the above methods is to replace the ELBO objective with other objectives, such as
the InfoVAE objective discussed in Section 21.3.2, which includes adversarial autoencoders and MMD
autoencoders as special cases. The InfoVAE objective includes a term to explicitly enforce non-zero
mutual information between x and z, which effectively solves the problem of posterior collapse.


21.5 VAEs with hierarchical structure 


We define a hierarchical VAE or HVAE, with $L$ stochastic layers, to be the following generative
model:⁶

$$ p_\theta(x, z_{1:L}) = p_\theta(z_L) \left[\prod_{l=L-1}^{1} p_\theta(z_l|z_{l+1})\right] p_\theta(x|z_1) $$ (21.102)


We can improve on the above model by making it non-Markovian, i.e., letting each $z_l$ depend on all
the higher level stochastic variables, $z_{l+1:L}$, not just the preceding level, i.e.,

$$ p_\theta(x, z) = p_\theta(z_L) \left[\prod_{l=L-1}^{1} p_\theta(z_l|z_{l+1:L})\right] p_\theta(x|z_{1:L}) $$ (21.103)


Note that the likelihood is now $p_\theta(x|z_{1:L})$ instead of just $p_\theta(x|z_1)$. This is analogous to adding skip
connections from all preceding variables to all their children. It is easy to implement this by using
a deterministic “backbone” of residual connections that accumulates all stochastic decisions, and
propagates them down the chain, as illustrated in Figure 21.14(left). We discuss how to perform
inference and learning in such models below.


21.5.1 Bottom-up vs top-down inference 
To perform inference in a hierarchical VAE, we could use a bottom-up inference model of the
form

$$ q_\phi(z|x) = q_\phi(z_1|x) \prod_{l=2}^{L} q_\phi(z_l|x, z_{1:l-1}) $$ (21.104)


However, a better approach is to use a top-down inference model of the form

$$ q_\phi(z|x) = q_\phi(z_L|x) \prod_{l=L-1}^{1} q_\phi(z_l|x, z_{l+1:L}) $$ (21.105)


Inference for $z_l$ combines bottom-up information from $x$ with top-down information from higher
layers, $z_{>l} = z_{l+1:L}$. See Figure 21.14 (right) for an illustration.⁷
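With the above model, the ELBO can be written as follows (using the chain rule for KL):

$$ \mathcal{L}_{\theta,\phi}(x) = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - D_{\mathrm{KL}}(q_\phi(z_L|x) \,\|\, p_\theta(z_L)) $$ (21.106)
$$ \qquad - \sum_{l=L-1}^{1} \mathbb{E}_{q_\phi(z_{>l}|x)}\left[D_{\mathrm{KL}}(q_\phi(z_l|x, z_{>l}) \,\|\, p_\theta(z_l|z_{>l}))\right] $$ (21.107)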


6. There is a split in the literature about whether to label the top level as $z_L$ or $z_1$. We adopt the former convention,
since we view lower numbered layers, such as $z_1$, as being “closer to the data”, and higher numbered layers, such as $z_L$,
as being “more abstract”.

7. Note that it is also possible to have a stochastic bottom-up encoder and a stochastic top-down encoder, as discussed
in the BIVA paper [Maa+19]. (BIVA stands for “Bidirectional-Inference Variational Autoencoder”.)
where

$$ q_\phi(z_{>l}|x) = \prod_{i=l+1}^{L} q_\phi(z_i|x, z_{>i}) $$ (21.108)


is the approximate posterior above layer $l$ (i.e., the parents of $z_l$).
The reason the top-down inference model is better is that it more closely approximates the true 
posterior of a given layer, which is given by 


$$ p_\theta(z_l|x, z_{l+1:L}) \propto p_\theta(z_l|z_{l+1:L})\, p_\theta(x|z_l, z_{l+1:L}) $$ (21.109)


Thus the posterior combines the top-down prior term $p_\theta(z_l|z_{l+1:L})$ with the bottom-up likelihood
term $p_\theta(x|z_l, z_{l+1:L})$. We can approximate this posterior by defining


$$ q_\phi(z_l|x, z_{l+1:L}) \propto p_\theta(z_l|z_{l+1:L})\, \tilde{q}_\phi(z_l|x, z_{l+1:L}) $$ (21.110)


where $\tilde{q}_\phi(z_l|x, z_{l+1:L})$ is a learned Gaussian approximation to the bottom-up likelihood. If both prior
and likelihood are Gaussian, we can compute this product in closed form, as proposed in the ladder
network paper [Søn+16].⁸ A more flexible approach is to let $q_\phi(z_l|x, z_{l+1:L})$ be learned,
but to force it to share some of its parameters with the learned prior $p_\theta(z_l|z_{l+1:L})$, as proposed in
[Kin+16]. This reduces the number of parameters in the model, and ensures that the posterior and
prior remain somewhat close.
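For the Gaussian case, the closed-form product in Eq. (21.110) is a precision-weighted combination of the prior and the bottom-up term. Here is a per-dimension sketch, assuming both factors are diagonal Gaussians given by means and variances; the function name is hypothetical.

```python
def combine_gaussians(prior_mean, prior_var, lik_mean, lik_var):
    """Renormalized product of two diagonal Gaussians, one dimension at a time.

    Returns the mean and variance of N(prior) * N(bottom-up term): precisions add,
    and the mean is the precision-weighted average of the two means.
    """
    means, variances = [], []
    for m1, v1, m2, v2 in zip(prior_mean, prior_var, lik_mean, lik_var):
        precision = 1.0 / v1 + 1.0 / v2
        var = 1.0 / precision
        means.append(var * (m1 / v1 + m2 / v2))
        variances.append(var)
    return means, variances
```

The combined variance is never larger than either input variance, and the mean is pulled toward whichever factor is more confident.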


21.5.2 Example: Very deep VAE 


There have been many papers exploring different kinds of HVAE models (see e.g., [Kin+16; Søn+16;
Chi21a; VK20a; Maa+19]), and we do not have space to discuss them all. Here we focus on the
“very deep VAE” or VD-VAE model of [Chi21a], since it is simple but yields state-of-the-art
results (at the time of writing).

The architecture is a simple convolutional VAE with bidirectional inference, as shown in Figure 21.15.
For each layer, the prior and posterior are diagonal Gaussians. The author found that nearest-neighbor
upsampling (in the decoder) worked much better than transposed convolution, and avoided posterior
collapse. This enabled training with the vanilla VAE objective, without needing any of the tricks
discussed in Section 21.5.4.


The low-resolution latents (at the top of the hierarchy) capture a lot of the global structure of
each image; the remaining high-resolution latents are just used to fill in details that make the image
look more realistic, and improve the likelihood. This suggests the model could be useful for lossy 
compression, since a lot of the low-level details can be drawn from the prior (i.e., “hallucinated”), 
rather than having to be sent by the encoder. 

We can also use the model for unconditional sampling at multiple resolutions. This is illustrated
in Figure 21.16, using a model with 78 stochastic layers trained on the FFHQ-256 dataset.⁹


8. The term “ladder network” arises from the horizontal “rungs” in Figure 21.14(right). Note that a similar idea was
independently proposed in [Sal16].


9. This is a 256² version of the Flickr-Faces High Quality dataset from https://github.com/NVlabs/ffhq-dataset,
which has 70k images at 1024² resolution.
Figure 21.15: The top-down encoder used by the hierarchical VAE in [Chi21a]. Each convolution is preceded
by the GELU nonlinearity. The model uses average pooling and nearest-neighbor upsampling for the pool and
unpool layers. The posterior $q_\phi$ and prior $p_\theta$ are diagonal Gaussians. From Figure 3 of [Chi21a]. Used with
kind permission of Rewon Child.


[Figure 21.16 axis labels: high resolution; low resolution]


Figure 21.16: Samples from a VDVAE model (trained on FFHQ dataset) from different levels of the hierarchy. 
From Figure 1 of [Chi21a]. Used with kind permission of Rewon Child. 


21.5.3 Connection with autoregressive models 


Until recently, most hierarchical VAEs only had a small number of stochastic layers. Consequently,
the images they generated did not look as good, or have as high likelihoods, as images produced
by other models, such as the autoregressive PixelCNN model (see Section 22.3.2). However, by 
endowing VAEs with many more stochastic layers, it is possible to outperform AR models in terms of 
likelihood and sample quality, while using fewer parameters and much less computing power [Chi21a; 
VK20a; Maa+19]. 

To see why this is possible, note that we can represent any AR model as a degenerate VAE, as
shown in Figure 21.17(left). The idea is simple: the encoder copies the input into latent space
by setting $z_{1:D} = x_{1:D}$ (so $q_\phi(z_i = x_i|z_{>i}, x) = 1$), then the model learns an autoregressive prior
[Figure 21.17 panels: (left) Latent variables are identical to observed variables; (right) Latent variables allow for parallel generation]


Figure 21.17: Left: a hierarchical VAE which emulates an autoregressive model using an identity encoder,
autoregressive prior, and identity decoder. Right: a hierarchical VAE with a 2 layer hierarchical latent code.
The bottom hidden nodes (black) are conditionally independent given the top layer. From Figure 2 of [Chi21a].
Used with kind permission of Rewon Child.


$p_\theta(z_{1:D}) = \prod_d p_\theta(z_d|z_{1:d-1})$, and finally the likelihood function just copies the latent vector to output
space, so $p_\theta(x_i = z_i|z) = 1$. Since the encoder computes the exact (albeit degenerate) posterior, we
have $q_\phi(z|x) = p_\theta(z|x)$, so the ELBO is tight and reduces to the log likelihood,


$$ \log p_\theta(x) = \log p_\theta(z) = \sum_d \log p_\theta(x_d|x_{<d}) $$ (21.111)


Thus we can emulate any AR model with a VAE provided it has at least $D$ stochastic layers, where
$D$ is the dimensionality of the observed data.
In practice, data usually lives on a lower-dimensional manifold (see e.g., [DW19]), which can allow
for a much more compact latent code. For example, Figure 21.17(right) shows a hierarchical code
in which the latent factors at the lower level are conditionally independent given the higher level,
and hence can be generated in parallel. Such a tree-like structure can enable sample generation in
$O(\log D)$ time, whereas an autoregressive model always takes $O(D)$ time. (Recall that for an image,
$D$ is the number of pixel values, i.e., pixels times color channels, so it grows quadratically with image
resolution. For example, even a tiny $32 \times 32$ RGB image has $D = 32 \times 32 \times 3 = 3072$.)


In addition to speed, hierarchical models also require many fewer parameters than “flat” models.
The typical architecture used for generating images is a multi-scale approach: the model starts from
a small, spatially arranged set of latent variables, and at each subsequent layer, the spatial resolution
is increased (usually by a factor of 2). This allows the high level to capture global, long-range
correlations (e.g., the symmetry of a face, or overall skin tone), while letting lower levels capture
fine-grained details.


21.5.4 Variational pruning


A common problem with hierarchical VAEs is that the higher level latent layers are often ignored, so
the model does not learn interesting high level semantics. This is caused by variational pruning.
This problem is analogous to the issue of latent variable collapse, which we discussed in Section 21.4.


A common heuristic to mitigate this problem is to use KL balancing coefficients [Che+17b], to
ensure that an equal amount of information is encoded in each layer. That is, we use the following


penalty: 


$$ \sum_{l=1}^{L} \gamma_l\, \mathbb{E}_{q_\phi(z_{>l}|x)}\left[D_{\mathrm{KL}}(q_\phi(z_l|x, z_{>l}) \,\|\, p_\theta(z_l|z_{>l}))\right] $$ (21.112)


The balancing term $\gamma_l$ is set to a small value when the KL penalty is small (on the current minibatch),
to encourage use of that layer, and is set to a large value when the KL term is large. (This is only
done during the “warmup period”.) Concretely, [VK20a] proposes to set the coefficients $\gamma_l$ to be
proportional to the size of the layer, $s_l$, and the average KL loss:


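$$ \gamma_l \propto s_l\, \mathbb{E}_{x \sim \mathcal{B}}\left[\mathbb{E}_{q_\phi(z_{>l}|x)}\left[D_{\mathrm{KL}}(q_\phi(z_l|x, z_{>l}) \,\|\, p_\theta(z_l|z_{>l}))\right]\right] $$ (21.113)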


where B is the current minibatch. 
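A sketch of Eq. (21.113), assuming `batch_kls[l][b]` holds the KL term of layer l for example b of the minibatch. Eq. (21.113) only fixes the coefficients up to proportionality, so the normalization below (rescaling them to sum to the number of layers) is our own assumption, as is the function name.

```python
def kl_balancing_coeffs(layer_sizes, batch_kls):
    """gamma_l proportional to s_l * (average KL of layer l over the minibatch).

    layer_sizes: s_l for each layer.
    batch_kls: batch_kls[l][b] = KL of layer l for example b in the minibatch.
    Normalized (by assumption) so that the coefficients sum to the number of layers.
    """
    raw = [s * sum(kls) / len(kls) for s, kls in zip(layer_sizes, batch_kls)]
    total = sum(raw)
    num_layers = len(raw)
    return [num_layers * r / total for r in raw]
```

A layer whose KL has shrunk on the current minibatch gets a smaller coefficient, so it is penalized less for encoding information, which is the intended pressure against pruning that layer.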


21.5.5 Other optimization difficulties 


A common problem when training (hierarchical) VAEs is that the loss can become unstable. The
main reason for this is that the KL term is unbounded (can become infinitely large). In [Chi21a], they
tackle the problem in two ways. First, they scale the initial random weights of the final convolutional
layer in each residual bottleneck block by $1/\sqrt{L}$. Second, they skip an update step if the norm
of the gradient of the loss exceeds some threshold.
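The second trick can be sketched as an SGD step that is dropped when the global gradient norm is too large or not finite. Plain SGD and the particular `skip_threshold` are assumptions made for illustration; in practice the skip test would wrap whatever optimizer is in use, and the function name is hypothetical.

```python
import math


def maybe_update(params, grads, lr, skip_threshold):
    """SGD step that is skipped when the global gradient norm is too large.

    Returns (new_params, skipped).
    """
    grad_norm = math.sqrt(sum(g * g for g in grads))
    if not math.isfinite(grad_norm) or grad_norm > skip_threshold:
        return list(params), True  # leave the parameters unchanged
    return [p - lr * g for p, g in zip(params, grads)], False
```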

In the Nouveau VAE method of [VK20a], they use some more complicated measures to ensure
stability. First, they use batch normalization, but with various tweaks. Second, they use spectral
regularization for the encoder. Specifically, they add the penalty $\beta \sum_i \lambda_i$, where $\lambda_i$ is the largest
singular value of the $i$th convolutional layer (estimated using a single power iteration step), and
$\beta > 0$ is a tuning parameter. Third, they use inverse autoregressive flows (Section 23.2.4.3) in each
layer, instead of a diagonal Gaussian approximation. Fourth, they represent the posterior using a
residual representation. In particular, let us assume the prior for the $i$th variable in layer $l$ is


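$$ p_\theta(z_l^i|z_{>l}) = \mathcal{N}(z_l^i|\mu_i(z_{>l}), \sigma_i(z_{>l})) $$ (21.114)

They propose the following posterior approximation:

$$ q_\phi(z_l^i|x, z_{>l}) = \mathcal{N}(z_l^i|\mu_i(z_{>l}) + \Delta\mu_i(z_{>l}, x), \sigma_i(z_{>l}) \cdot \Delta\sigma_i(z_{>l}, x)) $$ (21.115)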


where the $\Delta$ terms are the relative changes computed by the encoder. The corresponding KL penalty
reduces to the following (dropping the $l$ subscript for brevity):

$$ D_{\mathrm{KL}}(q_\phi(z^i|x, z_{>l}) \,\|\, p_\theta(z^i|z_{>l})) = \frac{1}{2}\left(\frac{\Delta\mu_i^2}{\sigma_i^2} + \Delta\sigma_i^2 - \log \Delta\sigma_i^2 - 1\right) $$ (21.116)


So as long as $\sigma_i$ is bounded from below, the KL term can be easily controlled just by adjusting the
encoder parameters.
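A quick numerical check of Eq. (21.116), assuming σ denotes a standard deviation: compute the general KL between two univariate Gaussians and compare it with the residual form. The function names are hypothetical.

```python
import math


def gaussian_kl(mu_q, sigma_q, mu_p, sigma_p):
    """KL(N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2)) for scalars."""
    return (math.log(sigma_p / sigma_q)
            + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2 * sigma_p ** 2) - 0.5)


def residual_kl(delta_mu, delta_sigma, sigma):
    """Eq. (21.116): q = N(mu + delta_mu, (sigma * delta_sigma)^2), p = N(mu, sigma^2)."""
    return 0.5 * (delta_mu ** 2 / sigma ** 2 + delta_sigma ** 2
                  - math.log(delta_sigma ** 2) - 1.0)


mu, sigma, d_mu, d_sigma = 0.3, 0.7, 0.5, 1.4
general = gaussian_kl(mu + d_mu, sigma * d_sigma, mu, sigma)
residual = residual_kl(d_mu, d_sigma, sigma)  # matches `general` up to rounding
```

Note that the prior mean μ cancels out entirely, and the prior scale σ only enters through Δμ/σ, which is why bounding σ from below keeps the KL under the encoder's control.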


21.6 Vector quantization VAE 


In this section, we describe VQ-VAE, which stands for “vector quantized VAE” [OVK17; ROV19]. 
This is like a standard VAE except it uses a set of discrete latent variables. 
Figure 21.18: Autoencoder for MNIST using 256 binary latents. Top row: input images. Middle
row: reconstruction. Bottom row: latent code, reshaped to a 16 × 16 image. Generated by
quantized_autoencoder_mnist.ipynb.


21.6.1 Autoencoder with binary code 


The simplest approach to the problem is to construct a standard VAE, but to add a discretization
layer at the end of the encoder, $z_e(x) \in \{0, \ldots, S-1\}^K$, where $S$ is the number of states, and $K$
is the number of discrete latents. For example, we can binarize the latent vector (using $S = 2$) by
thresholding $z$ so that it lies in $\{0, 1\}^K$. This can be useful for data compression (see e.g., [BLS17]).

Suppose we assume the prior over the latent codes is uniform. Since the encoder is deterministic,
the KL divergence reduces to a constant, equal to the log of the number of possible codes, $\log S^K = K \log S$.
This avoids the problem with posterior collapse (Section 21.4). Unfortunately, the discontinuous quantization
operation of the encoder prohibits the direct use of gradient-based optimization. The solution proposed in [OVK17] is to use
the straight-through estimator, which we discuss in Section 6.5.8. We show a simple example of this
approach in Figure 21.18, where we use a Gaussian likelihood, so the loss function has the form


$$ \mathcal{L} = \|x - d(e(x))\|_2^2 $$ (21.117)


where $e(x) \in \{0, 1\}^K$ is the encoder, and $d(z) \in \mathbb{R}^{28 \times 28}$ is the decoder.


21.6.2 VQ-VAE model 


We can get a more expressive model by using a 3d tensor of discrete latents, $z \in \mathbb{R}^{H \times W \times K}$, where
$K$ is the number of discrete values per latent variable. Rather than just binarizing the continuous
vector $z_e(x)_{ij}$, we compare it to a codebook of embedding vectors, $\{e_k : k = 1:K, e_k \in \mathbb{R}^L\}$, and
then set $z_{ij}$ to the index of the nearest codebook entry:


$$ q(z_{ij} = k|x) = \begin{cases} 1 & \text{if } k = \arg\min_{k'} \|z_e(x)_{i,j,:} - e_{k'}\|_2 \\ 0 & \text{otherwise} \end{cases} $$ (21.118)


When reconstructing the input we replace each discrete code index by the corresponding real-valued
codebook vector:

$$ (z_q)_{ij} = e_k \text{ where } z_{ij} = k $$ (21.119)
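Eqs. (21.118) and (21.119) amount to a nearest-neighbor search followed by a table lookup. A pure-Python sketch, assuming the encoder output is an H × W grid of L-dimensional vectors (nested lists) and the codebook is a list of K such vectors; the names are hypothetical.

```python
def nearest_code(z_e, codebook):
    """Index k of the codebook vector closest to z_e (Eq. 21.118).

    Squared Euclidean distance has the same argmin as the distance itself.
    """
    def sq_dist(a, b):
        return sum((ai - bi) ** 2 for ai, bi in zip(a, b))
    return min(range(len(codebook)), key=lambda k: sq_dist(z_e, codebook[k]))


def quantize(z_e_grid, codebook):
    """Map an H x W grid of encoder outputs to code indices and to z_q (Eq. 21.119)."""
    indices = [[nearest_code(v, codebook) for v in row] for row in z_e_grid]
    z_q = [[codebook[k] for k in row] for row in indices]
    return indices, z_q


codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]]  # K = 3 codes in L = 2 dimensions
grid = [[[0.1, 0.2], [0.9, 1.2]],
        [[-0.8, 1.7], [0.6, 0.6]]]                # H = W = 2
indices, z_q = quantize(grid, codebook)
```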
Figure 21.19: VQ-VAE architecture. From Figure 1 of [OVK17]. Used with kind permission of Aaron van
den Oord.


These values are then passed to the decoder, $p(x|z_q)$, as usual. See Figure 21.19 for an illustration of
the overall architecture. Note that although $z_q$ is generated from a discrete combination of codebook
vectors, the use of a distributed code makes the model very expressive. For example, if we use a
grid of $32 \times 32$, with $K = 512$, then we can generate $512^{32 \times 32} = 2^{9216}$ distinct images, which is
astronomically large.
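The count can be checked exactly with Python's arbitrary-precision integers (a small sanity check of the arithmetic, not part of the model):

```python
grid_cells = 32 * 32            # number of latent positions
codebook_size = 512             # K
num_codes = codebook_size ** grid_cells
assert num_codes == 2 ** 9216   # since 512 = 2**9 and 9 * 1024 = 9216
num_digits = len(str(num_codes))  # number of decimal digits in the count
```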

To fit this model, we can minimize the negative log likelihood (reconstruction error) using the
straight-through estimator, as before. This amounts to passing the gradients from the decoder input
$z_q(x)$ to the encoder output $z_e(x)$, bypassing Equation (21.118), as shown by the red arrow in
Figure 21.19. Unfortunately this means that the codebook entries will not get any learning signal.
To solve this, the authors proposed to add an extra term to the loss, known as the codebook loss,
that encourages the codebook entries $e$ to match the output of the encoder. We treat the encoder
output $z_e(x)$ as a fixed target, by adding a stop gradient operator to it; this ensures $z_e$ is treated normally
in the forwards pass, but has zero gradient in the backwards pass. The modified loss (dropping the
spatial indices $i, j$) becomes


$$ \mathcal{L} = -\log p(x|z_q(x)) + \|\mathrm{sg}(z_e(x)) - e\|_2^2 $$ (21.120)


where $e$ refers to the codebook vector assigned to $z_e(x)$, and sg is the stop gradient operator.

An alternative way to update the codebook vectors is to use moving averages. To see how this
works, first consider the batch setting. Let $\{z_{i,1}, \ldots, z_{i,n_i}\}$ be the set of $n_i$ outputs from the encoder
that are closest to the dictionary item $e_i$. We can update $e_i$ to minimize the MSE


$$ \sum_{j=1}^{n_i} \|z_{i,j} - e_i\|_2^2 $$ (21.121)

which has the closed form update

$$ e_i = \frac{1}{n_i} \sum_{j=1}^{n_i} z_{i,j} $$ (21.122)
This is like the M step of the EM algorithm when fitting the mean vectors of a GMM. In the minibatch
setting, we replace the above operations with an exponential moving average, as follows:


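$$ N_i^{(t)} = \gamma N_i^{(t-1)} + (1-\gamma)\, n_i^{(t)} $$ (21.123)
$$ m_i^{(t)} = \gamma m_i^{(t-1)} + (1-\gamma) \sum_j z_{i,j}^{(t)} $$ (21.124)
$$ e_i^{(t)} = \frac{m_i^{(t)}}{N_i^{(t)}} $$ (21.125)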


The authors found $\gamma = 0.9$ to work well.
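A sketch of one EMA step of Eqs. (21.123) to (21.125), assuming the assignments of encoder outputs to codes have already been computed by the nearest-neighbor rule of Eq. (21.118). It assumes the running counts N_i stay positive (e.g., initialized to 1); implementations commonly add a small smoothing constant to avoid dividing by zero for unused codes, which this sketch omits. The function name is hypothetical.

```python
def ema_codebook_update(N, m, assignments, z_batch, gamma=0.9):
    """One exponential-moving-average codebook step.

    N: running counts N_i; m: running sums m_i (each a vector).
    assignments: index of the nearest code for each encoder output in z_batch.
    Returns the updated (N, m, codebook).
    """
    K, dim = len(N), len(m[0])
    counts = [0] * K                        # n_i for this minibatch
    sums = [[0.0] * dim for _ in range(K)]  # sum_j z_{i,j} for this minibatch
    for k, z in zip(assignments, z_batch):
        counts[k] += 1
        sums[k] = [s + zi for s, zi in zip(sums[k], z)]
    N_new = [gamma * N[i] + (1 - gamma) * counts[i] for i in range(K)]           # Eq. (21.123)
    m_new = [[gamma * a + (1 - gamma) * b for a, b in zip(m[i], sums[i])]
             for i in range(K)]                                                  # Eq. (21.124)
    codebook = [[mi / N_new[i] for mi in m_new[i]] for i in range(K)]            # Eq. (21.125)
    return N_new, m_new, codebook
```

With γ = 0 the update reduces to the batch mean of Eq. (21.122).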

The above procedure will learn to update the codebook vectors so that they match the output of the
encoder. However, it is also important to ensure the encoder does not “change its mind” too often
about what codebook value to use. To prevent this, the authors propose to add a third term to
the loss, known as the commitment loss, that encourages the encoder output to be close to the
codebook values. Thus we get the final loss:


$$ \mathcal{L} = -\log p(x|z_q(x)) + \|\mathrm{sg}(z_e(x)) - e\|_2^2 + \beta \|z_e(x) - \mathrm{sg}(e)\|_2^2 $$ (21.126)


The authors found $\beta = 0.25$ to work well, although of course the value depends on the scale of the
reconstruction loss (NLL) term. (A probabilistic interpretation of this loss can be found in [Hen+18].)
Overall, the decoder optimizes the first term only, the encoder optimizes the first and last terms, and
the embeddings optimize the middle term.
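Because sg(·) is the identity in the forward pass, the codebook and commitment terms have the same value up to the factor β; the stop-gradient only decides which argument receives the gradient. The sketch below writes those gradients out by hand for a single encoder output and its assigned codebook vector. The reconstruction term and its straight-through gradient are omitted, and the function name is hypothetical.

```python
def vq_aux_losses_and_grads(z_e, e, beta=0.25):
    """Codebook and commitment terms of Eq. (21.126) with hand-written gradients.

    Returns (codebook_loss, commitment_loss, grad_wrt_e, grad_wrt_z_e).
    """
    diff = [zi - ei for zi, ei in zip(z_e, e)]
    sq = sum(d * d for d in diff)
    codebook_loss = sq              # ||sg(z_e) - e||^2: gradient flows to e only
    commitment_loss = beta * sq     # beta * ||z_e - sg(e)||^2: gradient flows to z_e only
    grad_e = [-2.0 * d for d in diff]           # d(codebook_loss)/de
    grad_z_e = [2.0 * beta * d for d in diff]   # d(commitment_loss)/dz_e
    return codebook_loss, commitment_loss, grad_e, grad_z_e
```

In an autodiff framework the same split is obtained by wrapping the appropriate argument in a stop-gradient operation instead of writing the gradients by hand.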


21.6.3 Learning the prior 


After training the VQ-VAE model, it is possible to learn a better prior, to match the aggregated
posterior. To do this, we just apply the encoder to a set of data, $\{x_n\}$, thus converting them to
discrete sequences, $\{z_n\}$. We can then learn a joint distribution $p(z)$ using any kind of sequence
model. In the original VQ-VAE paper [OVK17], they used the causal convolutional PixelCNN model
(Section 22.3.2). More recent work has used transformer decoders (Section 22.4). Samples from this
prior can then be decoded using the decoder part of the VQ-VAE model. We give some examples of
this in the sections below.


21.6.4 Hierarchical extension (VQ-VAE-2)


In [ROV19], they extend the original VQ-VAE model by using a hierarchical latent code. The model
is illustrated in Figure 21.20. They applied this to images of size 256 x 256 x 3. The first latent layer
maps this to a quantized representation of size 64 x 64, and the second latent layer maps this to a
quantized representation of size 32 x 32. This hierarchical scheme allows the top level to focus on
high level semantics of the image, leaving fine visual details, such as texture, to the lower level. (See
Section 21.5 for more discussion of hierarchical VAEs.)

After fitting the VQ-VAE, they learn a prior over the top level code using a PixelCNN model
augmented with self-attention (Section 16.2.7) to capture long-range dependencies. (This hybrid
model is known as PixelSNAIL [Che+17c].) For the lower level prior, they just use standard PixelCNN,
since attention would be too expensive. Samples from the model can then be decoded using the
VQ-VAE decoder, as shown in Figure 21.20.
Figure 21.20: Hierarchical extension of VQ-VAE. (a) Encoder and decoder architecture. (b) Combining a
PixelCNN prior with the decoder. From Figure 2 of [ROV19]. Used with kind permission of Aaron van den
Oord.
Figure 21.21: Illustration of the Gumbel Softmax trick applied to K = 4 codebook vectors in L = 2 dimensions.
From https://ml.berkeley.edu/blog/posts/dalle2/. Used with kind permission of Charlie Snell.


21.6.5 Discrete VAE 


In VQ-VAE, we use a one-hot encoding for the latents, $q(z = k|x) = 1$ iff $k = \arg\min_{k'} \|z_e(x) - e_{k'}\|_2$,
and then set $z_q = e_k$. This does not capture any uncertainty in the latent code, and requires the use
of the straight-through estimator for training.

Various other approaches to fitting VAEs with discrete latent codes have been investigated. In the
DALL-E paper (Section 22.4.3), they use a fairly simple method, based on using the Gumbel-Softmax
relaxation for the discrete variables (see Section 6.5.6). In brief, let $q(z = k|x)$ be the probability
that the input $x$ is assigned to codebook entry $k$. We can draw an exact sample from this distribution
by computing $k^* = \arg\max_k [g_k + \log q(z = k|x)]$ and setting $w$ to the one-hot vector with $w_{k^*} = 1$,
where each $g_k$ is drawn independently from a standard Gumbel distribution. We
can now “relax” this by using a softmax with temperature $\tau > 0$ and computing


$$w_k = \frac{\exp\left(\frac{g_k + \log q(z=k|x)}{\tau}\right)}{\sum_{j=1}^{K} \exp\left(\frac{g_j + \log q(z=j|x)}{\tau}\right)} \tag{21.127}$$
Figure 21.22: Illustration of the VQ-GAN. From Figure 2 of [ERO21]. Used with kind permission of Patrick
Esser. 


We now set the latent code to be a weighted sum of the codebook vectors: 


$$z_q = \sum_{k=1}^{K} w_k e_k \tag{21.128}$$


In the limit that $\tau \to 0$, the distribution over weights $w$ converges to a one-hot distribution, in which
case $z_q$ becomes equal to one of the codebook entries. But for finite $\tau$, we “fill in” the space between
the vectors, as illustrated in Figure 21.21.
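The relaxation in Equations (21.127) and (21.128) can be sketched as follows. This is a minimal NumPy illustration (the function name is ours) in the K = 4, L = 2 setting of Figure 21.21, with a made-up distribution $q$ and codebook. A tiny temperature gives a nearly one-hot $w$, while $\tau = 1$ yields a point between codebook vectors.

```python
import numpy as np

def gumbel_softmax_code(log_q, codebook, tau, rng):
    """Relaxed codebook sample following Eqs. (21.127)-(21.128).

    log_q: (K,) log q(z=k|x); codebook: (K, L); tau > 0 is the temperature.
    """
    g = rng.gumbel(size=log_q.shape)        # i.i.d. standard Gumbel noise g_k
    logits = (g + log_q) / tau
    logits = logits - logits.max()          # stabilise the softmax
    w = np.exp(logits) / np.exp(logits).sum()
    return w, w @ codebook                  # weights w and z_q = sum_k w_k e_k

rng = np.random.default_rng(0)
q = np.array([0.1, 0.2, 0.3, 0.4])                                    # made-up q(z|x)
codebook = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # K=4, L=2
w_soft, z_soft = gumbel_softmax_code(np.log(q), codebook, tau=1.0, rng=rng)
w_hard, z_hard = gumbel_softmax_code(np.log(q), codebook, tau=1e-6, rng=rng)
```

Because the argmax of $w$ does not depend on $\tau$, taking it over many draws recovers exact samples from $q$; this is the Gumbel-max property the relaxation builds on.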

This allows us to express the ELBO in the usual differentiable way: 


$$\mathcal{L} = -\mathbb{E}_{q(z|x)}[\log p(x|z)] + \beta D_{\mathrm{KL}}(q(z|x) \,\|\, p(z)) \tag{21.129}$$


where $\beta > 0$ controls the amount of regularization. (Unlike VQ-VAE, the KL term is not a constant,
because the encoder is stochastic.) Furthermore, since the Gumbel noise variables are sampled from
a distribution that is independent of the encoder parameters, we can use the reparameterization trick
(Section 6.5.4) to optimize this.
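For a single categorical latent, the KL term in Equation (21.129) is a finite sum over the $K$ codes. The sketch below assumes, purely for illustration, a uniform prior $p(z = k) = 1/K$; under that assumption a one-hot $q$ pays the largest possible penalty, $\log K$, and $q = p$ pays nothing.

```python
import numpy as np

def kl_categorical(q, p):
    """D_KL(q || p) = sum_k q_k (log q_k - log p_k); terms with q_k = 0 contribute 0."""
    q = np.asarray(q, dtype=float)
    p = np.asarray(p, dtype=float)
    support = q > 0
    return float(np.sum(q[support] * (np.log(q[support]) - np.log(p[support]))))

K = 4
uniform = np.full(K, 1.0 / K)          # assumed uniform prior over codebook entries
print(kl_categorical([0.7, 0.1, 0.1, 0.1], uniform))
```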


21.6.6 VQ-GAN


One drawback of VQ-VAE is that it uses mean squared error in its reconstruction loss, which can
result in blurry samples. In the VQ-GAN paper [ERO21], they replace this with a (patch-wise)
GAN loss (see Chapter 26), together with a perceptual loss; this results in much higher visual fidelity.
In addition, they use a transformer (see Section 16.3.5) to model the prior on the latent codes. See
Figure 21.22 for a visualization of the overall model. In [Yu+21], they replace the CNN encoder and
decoder of the VQ-GAN model with transformers, yielding improved results; they call this VIM
(Vector-quantized Image Modeling).


Draft of “Probabilistic Machine Learning: Advanced Topics”. August 15, 2022 


22 Auto-regressive models 


22.1 Introduction 


By the chain rule of probability, we can write any joint distribution over T variables as follows: 


$$p(x_{1:T}) = p(x_1)\, p(x_2|x_1)\, p(x_3|x_2, x_1)\, p(x_4|x_3, x_2, x_1) \cdots = \prod_{t=1}^{T} p(x_t|x_{1:t-1}) \tag{22.1}$$


where $x_t \in \mathcal{X}$ is the $t$'th observation, and we define $p(x_1|x_{1:0}) = p(x_1)$ as the initial state distribution.
This is called an auto-regressive model. This corresponds to a fully connected DAG, in which 
each node depends on all its predecessors in the ordering, as shown in Figure 22.1. The models can 
also be conditioned on arbitrary inputs or context c, in order to define p(x|c), although we omit this 
for notational brevity. 

We could of course also factorize the joint distribution “backwards” in time, using 


$$p(x_{1:T}) = \prod_{t=T}^{1} p(x_t|x_{t+1:T}) \tag{22.2}$$


However, this “anti-causal” direction is often harder to learn (see e.g., [PJS17]). 

Although the decomposition in Equation (22.1) is general, each term in this expression (i.e., each 
conditional distribution $p(x_t|x_{1:t-1})$) becomes more and more complex, since it depends on an
increasing number of arguments, which makes the terms slow to compute, and makes estimating 
their parameters more data hungry (see Section 2.6.3.2). 

One approach to solving this intractability is to make the (first-order) Markov assumption, 
which gives rise to a Markov model $p(x_t|x_{1:t-1}) = p(x_t|x_{t-1})$, which we discuss in Section 2.6.
(This is also called an auto-regressive model of order 1.) Unfortunately, the Markov assumption is
very limiting. One way to relax it, and to make $x_t$ depend on all the past $x_{1:t-1}$ without explicitly
regressing on them, is to assume the past can be compressed into a hidden state $z_t$. If $z_t$ is a
deterministic function of the past observations $x_{1:t-1}$, the resulting model is known as a recurrent
neural network, discussed in Section 16.3.4. If $z_t$ is a stochastic function of the past hidden state,
$z_{t-1}$, the resulting model is known as a hidden Markov model, which we discuss in Section 29.2.

Another approach is to stay with the general AR model of Equation (22.1), but to use a restricted 
functional form, such as some kind of neural network, for the conditionals $p(x_t|x_{1:t-1})$. Thus rather
than making conditional independence assumptions, or explicitly compressing the past into a sufficient 
Figure 22.1: A fully-connected auto-regressive model.


statistic, we implicitly learn a compact mapping from the past to the future. In the sections below, 
we discuss different functional forms for these conditional distributions. 

The main advantage of such AR models is that it is easy to compute, and optimize, the exact 
likelihood of each sequence (data vector). The main disadvantage is that generating samples is 
inherently sequential, which can be slow. In addition, the method does not learn a compact latent 


representation of the data.
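Both points can be seen in a toy binary AR model. In the sketch below, the conditional `p_one` is a hypothetical hand-written logistic function standing in for a learned network. Scoring a sequence is a single pass over the factors of Equation (22.1), whereas sampling needs one model call per step because each $x_t$ depends on the values already drawn.

```python
import itertools
import numpy as np

def p_one(prefix):
    """Hypothetical conditional p(x_t = 1 | x_{1:t-1}) for binary sequences."""
    return 1.0 / (1.0 + np.exp(-(0.5 + sum(prefix) - 0.3 * len(prefix))))

def log_prob(x):
    """Exact log p(x_{1:T}) = sum_t log p(x_t | x_{1:t-1}), as in Eq. (22.1)."""
    total = 0.0
    for t in range(len(x)):
        p1 = p_one(x[:t])
        total += np.log(p1 if x[t] == 1 else 1.0 - p1)
    return total

def sample(T, rng):
    """Ancestral sampling: each step needs the already-sampled prefix, so it is sequential."""
    x = []
    for _ in range(T):
        x.append(int(rng.random() < p_one(x)))
    return x

# Any choice of conditionals yields a normalised joint: the 16 probabilities sum to one.
total = sum(np.exp(log_prob(seq)) for seq in itertools.product([0, 1], repeat=4))
print(total, sample(8, np.random.default_rng(0)))
```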


22.2 Neural autoregressive density estimators (NADE)


A simple way to represent each conditional probability distribution $p(x_t|x_{1:t-1})$ is to use a generalized
linear model, such as logistic regression, as proposed in [Fre98]. We can make the model more
powerful by using a neural network. The resulting model is called the neural auto-regressive
density estimator or NADE model [LM11].


If we let $p(x_t|x_{1:t-1})$ be a conditional mixture of Gaussians, we get a model known as RNADE
(“Real-valued Neural Autoregressive Density Estimator”) of [UML13]. More precisely, this has the
form

$$p(x_t|x_{1:t-1}) = \sum_{k=1}^{K} \pi_{t,k}\, \mathcal{N}(x_t|\mu_{t,k}, \sigma_{t,k}^2) \tag{22.3}$$

where the parameters are generated by a network, $(\mu_t, \sigma_t, \pi_t) = f_t(x_{1:t-1}; \theta_t)$.
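Given the mixture parameters, evaluating one RNADE conditional is a log-sum-exp over the $K$ components. The sketch below takes those parameters as plain arrays (in the model they would be produced by the network $f_t$); the numbers in the example are made up.

```python
import numpy as np

def mog_log_density(x, weights, mu, sigma):
    """log p(x_t | x_{1:t-1}) from Eq. (22.3) for a scalar x_t, given K-vectors of
    mixture weights, means and standard deviations (the outputs of f_t)."""
    weights, mu, sigma = (np.asarray(a, dtype=float) for a in (weights, mu, sigma))
    log_comp = (np.log(weights) - 0.5 * np.log(2.0 * np.pi * sigma ** 2)
                - 0.5 * ((x - mu) / sigma) ** 2)
    m = log_comp.max()
    return m + np.log(np.sum(np.exp(log_comp - m)))    # log-sum-exp over components

# Hypothetical network outputs for one time step, K = 2:
print(mog_log_density(0.4, weights=[0.3, 0.7], mu=[-2.0, 1.5], sigma=[0.5, 1.2]))
```

Working in log space with the max subtracted avoids underflow when $x_t$ is far from every component mean.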


Rather than using separate neural networks, $f_1, \ldots, f_T$, it is more efficient to create a single
network with $T$ inputs and $T$ outputs. This can be done using masking, resulting in a model called
the MADE (“Masked Autoencoder for Density Estimation”) model [Ger+15].
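The masking idea can be sketched for a single hidden layer, following the degree-based construction of [Ger+15]: variable $d$ gets degree $d$, each hidden unit gets a random degree in $\{1, \ldots, D-1\}$, and a connection is kept only if it cannot carry information from $x_d$ or later into output $d$. The tanh layer, parameter shapes and function names are illustrative choices of ours.

```python
import numpy as np

def made_masks(D, H, rng):
    """Masks for a one-hidden-layer MADE (D >= 2 inputs/outputs, H hidden units)."""
    in_deg = np.arange(1, D + 1)                   # variable d has degree d
    hid_deg = rng.integers(1, D, size=H)           # hidden degrees drawn from {1, ..., D-1}
    mask_in = (hid_deg[:, None] >= in_deg[None, :]).astype(float)   # (H, D)
    mask_out = (in_deg[:, None] > hid_deg[None, :]).astype(float)   # (D, H)
    return mask_in, mask_out

def made_forward(x, params, masks):
    """Masked MLP: output d is a function of x_{1:d-1} only."""
    W1, b1, W2, b2 = params
    mask_in, mask_out = masks
    h = np.tanh((W1 * mask_in) @ x + b1)
    return (W2 * mask_out) @ h + b2                # e.g. parameters of p(x_d | x_{<d})

rng = np.random.default_rng(0)
D, H = 5, 16
masks = made_masks(D, H, rng)
params = (rng.normal(size=(H, D)), rng.normal(size=H),
          rng.normal(size=(D, H)), rng.normal(size=D))
```

The product `mask_out @ mask_in` counts the paths from each input to each output; the autoregressive property holds exactly when this matrix is strictly lower triangular.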


One disadvantage of NADE-type models is that they assume the variables have a natural linear
ordering. This makes sense for temporal or sequential data, but not for more general data types,
such as images or graphs. An orderless extension to NADE was proposed in [UML14; Uri+16].


22.3 Causal CNNs


One approach to representing the distribution $p(x_t|x_{1:t-1})$ is to try to identify patterns in the past
history that might be predictive of the value of $x_t$. If we assume these patterns can occur in any
location, it makes sense to use a convolutional neural network to detect them. However, we need
Figure 22.2: Illustration of the wavenet model using dilated (atrous) convolutions, with dilation factors of 1,
2, 4 and 8. From Figure 3 of [oor+16]. Used with kind permission of Aaron van den Oord. 


to make sure we only apply the convolutional mask to past inputs, not future ones. This can be done 
using masked convolution, also called causal convolution. We discuss this in more detail below. 


22.3.1 1d causal CNN (Convolutional Markov models) 
Consider the following convolutional Markov model for 1d discrete sequences: 


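$$p(x_{1:T}) = \prod_{t=1}^{T} p(x_t|x_{1:t-1}; \theta) = \prod_{t=1}^{T} \mathrm{Cat}\left(x_t \,\middle|\, \mathrm{softmax}\left(\varphi\left(\sum_{\tau=1}^{t-k} w^\top x_{\tau:\tau+k-1}\right)\right)\right) \tag{22.4}$$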


where $w$ is the convolutional filter of size $k$, and we have assumed a single nonlinearity $\varphi$ and
categorical output, for notational simplicity. This is like regular 1d convolution except we “mask out”
future inputs, so that $x_t$ only depends on the past values. We can of course use deeper models, and
we can condition on input features $c$.
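A direct, loop-based sketch of a causal 1d convolution, assuming zero padding before the start of the sequence. It uses only strictly earlier inputs, so its output at step $t$ could parameterize $p(x_t|x_{1:t-1})$; the optional `dilation` argument (our naming) spaces the filter taps apart, which is the dilated variant discussed in the next paragraph.

```python
import numpy as np

def causal_conv1d(x, w, dilation=1):
    """y_t = sum_{i=1}^{k} w[i-1] * x[t - i*dilation], with zeros before the start.

    Only strictly past inputs are used, so y_t never sees x_t or later values.
    """
    k = len(w)
    pad = k * dilation
    xp = np.concatenate([np.zeros(pad), np.asarray(x, dtype=float)])
    y = np.zeros(len(x))
    for t in range(len(x)):
        for i in range(1, k + 1):
            y[t] += w[i - 1] * xp[pad + t - i * dilation]
    return y

y = causal_conv1d([1.0, 2.0, 3.0, 4.0], w=[1.0, 0.5])   # gives 0, 1, 2.5, 4
```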

In order to capture long-range dependencies, we can use dilated convolution (see [Mur22, Sec
14.4.1]). This approach has been successfully used to create a state-of-the-art text-to-speech (TTS)
synthesis system known as wavenet [oor+16]. See Figure 22.2 for an illustration.
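The benefit of dilation is easiest to see by counting the receptive field. For a stack of causal convolutions with kernel size $k$ and dilations $d_1, \ldots, d_L$, one output sees $1 + \sum_l (k-1) d_l$ inputs, so doubling dilations grow it exponentially with depth. The sketch below checks this by pushing an impulse through such a stack, assuming kernel size 2 with the dilations 1, 2, 4, 8 of Figure 22.2. These layers include the current step; the one-step shift that makes the whole model strictly causal is assumed to be applied once at the input.

```python
import numpy as np

def causal_dilated_layer(x, w, d):
    """Causal conv that includes the current step: y_t = sum_i w[i] * x[t - i*d]."""
    k, T = len(w), len(x)
    xp = np.concatenate([np.zeros((k - 1) * d), x])
    return sum(w[i] * xp[(k - 1 - i) * d:(k - 1 - i) * d + T] for i in range(k))

def receptive_field(kernel_size, dilations):
    """Number of input steps seen by one output of a stack of causal dilated convolutions."""
    return 1 + sum((kernel_size - 1) * d for d in dilations)

x = np.zeros(40)
x[20] = 1.0                                   # impulse
y = x
for d in [1, 2, 4, 8]:
    y = causal_dilated_layer(y, np.ones(2), d)
print(receptive_field(2, [1, 2, 4, 8]), np.count_nonzero(y))  # both 16
```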

The wavenet model is a conditional model, $p(x|c)$, where $c$ is a set of linguistic features derived
from an input sequence of words, and $x$ is raw audio. The tacotron system [Wan+17c] is a fully
end-to-end approach, where the input is words rather than linguistic features.

Although wavenet produces high quality speech, it is too slow for use in production systems.
However, it can be “distilled” into a parallel generative model [Oor+18], as we discuss in Section 23.2.4.3.


22.3.2 2d causal CNN (PixelCNN) 
We can extend causal convolutions to 2d, to get an autoregressive model of the form 
$$p(x|\theta) = \prod_{r=1}^{R} \prod_{c=1}^{C} p(x_{r,c}|f_\theta(x_{1:r-1,1:C}, x_{r,1:c-1})) \tag{22.5}$$


where R is the number of rows, C is the number of columns, and we condition on all previously 
generated pixels in a raster scan order, as illustrated in Figure 22.3. This is called the pixelCNN 
Figure 22.3: Illustration of causal 2d convolution in the PixelCNN model. The red histogram shows the
empirical distribution over discretized values for a single pixel of a single RGB channel. The red and green
5 x 5 array shows the binary mask, which selects the top left context, in order to ensure the convolution is
causal. The diagrams on the right illustrate how we can avoid blind spots by using a vertical context stack,
that contains all previous rows, and a horizontal context stack, that just contains values from the current row.
From Figure 1 of [Oor+16]. Used with kind permission of Aaron van den Oord.


model [Oor+16]. Naive sampling (generation) from this model takes $O(N)$ time, where $N = RC$ is
the number of pixels, but [Ree+17] shows how to use a multiscale approach to reduce the complexity
to $O(\log N)$.
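The binary mask in Figure 22.3 can be generated as follows. This is a single-channel sketch that ignores the ordering among the R, G and B channels, which the full model also masks. The first layer must hide the current pixel, while later layers may include it; in the PixelCNN literature these are usually called masks A and B.

```python
import numpy as np

def pixelcnn_mask(k, include_center):
    """k x k raster-scan mask: 1 above the centre row, or left of the centre in its row.

    include_center=False hides the current pixel (first layer);
    include_center=True also keeps it (later layers).
    """
    c = k // 2
    mask = np.zeros((k, k))
    mask[:c, :] = 1.0          # every row above the centre
    mask[c, :c] = 1.0          # pixels to the left of the centre in its row
    if include_center:
        mask[c, c] = 1.0
    return mask

first_layer = pixelcnn_mask(5, include_center=False)   # 12 visible positions
later_layer = pixelcnn_mask(5, include_center=True)    # 13 visible positions

# A masked filter is just an elementwise product with the mask.
rng = np.random.default_rng(0)
W = rng.normal(size=(5, 5)) * first_layer
```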

Various extensions of this model have been proposed. The PixelCNN++ model of [Sal+17c]
improved the quality by using a mixture of logistic distributions, to capture the multimodality
of $p(x_t|x_{1:t-1})$. The PixelRNN of [OKK16] combined masked convolution with an RNN to get
even longer range contextual dependencies. The Subscale Pixel Network of [MK19] proposed to
generate the pixels such that the higher order bits are sampled before lower order bits, which allows
high resolution details to be sampled conditioned on low resolution versions of the whole image,
rather than just the top left corner.


22.4 Transformer decoders


We introduced transformers in Section 16.3.5. They can be used for encoding sequences (as in BERT),
or for decoding (generating) sequences. In this section, we focus on the latter case.


The basic idea is as follows. At each step $t$, the model applies masked (causal) self attention
(Section 16.2.7) to the first $t$ inputs, $y_{1:t}$, to compute a set of attention weights, $a_{1:t}$. From this it
computes an activation vector $z_t = \sum_{\tau=1}^{t} a_\tau y_\tau$. This is then passed through a feed-forward layer to
compute $h_t = \mathrm{MLP}(z_t)$. This process is repeated for each layer in the model. Finally the output is
used to predict the next element in the sequence, $y_{t+1} \sim \mathrm{Cat}(\mathrm{softmax}(W h_t))$. (In the conditional
generation setting, where we want to compute $p(y|x)$, we can just treat the tokens of $x$ as a prefix of the
output sequence. There is no need to use an encoder block.)
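A single-head version of masked self-attention can be sketched in NumPy. Unlike the simplified description above, this sketch assumes the usual query, key and value projections (the weight names are ours); with `Wv` set to the identity, row $t$ of the output is exactly $z_t = \sum_{\tau \le t} a_\tau y_\tau$. The lower-triangular mask is what allows training to compute all positions in one pass.

```python
import numpy as np

def causal_self_attention(Y, Wq, Wk, Wv):
    """Single-head masked self-attention over a (T, D) sequence Y.

    Row t of the output only mixes rows 1..t of Y, so position t cannot see the future.
    """
    T = Y.shape[0]
    Q, K, V = Y @ Wq, Y @ Wk, Y @ Wv
    scores = Q @ K.T / np.sqrt(K.shape[1])           # (T, T)
    causal = np.tril(np.ones((T, T), dtype=bool))    # allow only tau <= t
    scores = np.where(causal, scores, -np.inf)
    scores = scores - scores.max(axis=1, keepdims=True)
    A = np.exp(scores)
    A = A / A.sum(axis=1, keepdims=True)             # attention weights a_{t, 1:t}
    return A @ V, A                                  # z_t = sum_tau a_{t,tau} v_tau

rng = np.random.default_rng(0)
T, D = 6, 4
Y = rng.normal(size=(T, D))
Wq, Wk, Wv = (rng.normal(size=(D, D)) for _ in range(3))
Z, A = causal_self_attention(Y, Wq, Wk, Wv)
```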


At training time, all predictions can happen in parallel, since the target generated sequence is
already available. That is, the $t$'th output $y_t$ can be predicted given inputs $y_{1:t-1}$, and this can be
done for all $t$ simultaneously. However, at test time, the model must be applied sequentially, so the
output generated at $t+1$ is fed back into the model to predict $t+2$, etc. Note that the running time
of transformers is $O(T^2)$, although a variety of more efficient versions have been developed (see e.g.,
[Mur22, Sec 15.6] for details).


PROMPT: 


In a shocking finding, scientist discovered a herd of unicorns living in a remote, 
previously unexplored valley, in the Andes Mountains. Even more surprising to the 
researchers was the fact that the unicorns spoke perfect English. 


RESPONSE: 


The scientist named the population, after their distinctive horn, Ovid’s Unicorn. 

These four-horned, silver-white unicorns were previously unknown to science. 

Now, after almost two centuries, the mystery of what sparked this odd phenomenon is finally solved.

Dr. Jorge Pérez, an evolutionary biologist from the University of La Paz,
and several companions, were exploring the Andes Mountains when they found a small valley,
with no other animals or humans....


Figure 22.4: Sample text generated by GPT-2 in response to an input prompt. From https://openai.com/
blog/better-language-models/.


Transformers are the basis of many popular (conditional) generative models for sequences. We 
give some examples below. 


22.4.1 Text generation (GPT) 


In [Rad+18], OpenAI proposed a model called GPT, which is short for “Generative Pre-training 
Transformer”. This is a decoder-only transformer model that uses causal (masked) attention. In 
[Rad+19], they propose GPT-2, which is a larger version of GPT (1.5 billion parameters, or 
6.5GB, for the XL version), trained on a large web corpus (8 million pages, or 40GB). They 
also simplify the training objective, and just train it using maximum likelihood. The fluency of 
text generated by GPT-2 is quite remarkable; see Figure 22.4 for an example. See also
https://demo.allennlp.org/next-token-lm, which lets you interact with the (medium sized) model, and
generates the K most likely sequences (computed using beam search) given some input context.

More recently, OpenAI released GPT-3 [Bro+20d], which is an even larger version of GPT-2 (175 
billion parameters), trained on even more data (300 billion words), but based on the same principles. 
(Training was estimated to take 355 GPU years and cost $4.6M.) Due to the large size of the data 
and model, GPT-3 shows even more remarkable abilities to generate novel text. In particular, the 
output can be (partially) controlled by just changing the conditioning prompt. This enables the 
model to perform tasks that it has never been trained on, just by giving it some examples in the 
prompt. This is called “in-context learning” (see Section 19.5.1.2). See Figure 22.5 for an example, 
and https://gpt3demo.com/apps/openai-gpt-3-playground for an interactive demo.


22.4.2 Music generation 


It is possible to modify transformer decoders so that they generate music instead of natural language, 
as shown by the music transformer paper [Hua+18a]. The key “trick” is to note that the MIDI
format for music can be represented as a sequence of parameterized tokens, as shown in Figure 22.6. 
To cope with the long sequence length, a relative attention mechanism was devised. See Figure 22.7 
for a visualization. To best appreciate the quality of the generated output, please see the interactive 
A "whatpu" is a small, furry animal native to Tanzania. An example of a sentence that uses
the word whatpu is: 
We were traveling in Africa and we saw these very cute whatpus. 


To do a "farduddle" means to jump up and down really fast. An example of a sentence that uses 
the word farduddle is: 

One day when I was playing tag with my little sister, she got really excited and she 
started doing these crazy farduddles. 

A "yalubalu" is a type of vegetable that looks like a big pumpkin. An example of a sentence 
that uses the word yalubalu is: 

I was on a trip to Africa and I tried this yalubalu vegetable that was grown in a garden 
there. It was delicious. 


A "Burringo" is a car with very fast acceleration. An example of a sentence that uses the 
word Burringo is: 
In our garage we have a Burringo that my father drives to work every day. 


A "Gigamuru" is a type of Japanese musical instrument. An example of a sentence that uses the 
word Gigamuru is: 
I have a Gigamuru that my uncle gave me as a gift. I love to play it at home. 


To "screeg" something is to swing a sword at it. An example of a sentence that uses the word 
screeg is: 
We screeghed at each other for several minutes and then we went outside and ate ice cream. 


Figure 22.5: Illustration of few shot learning with GPT-3. The model is asked to create an example sentence 
using a new word whose meaning is provided in the prompt. Boldface is GPT-3’s completions, light gray is 
human input. From Figure 3.16 of [Bro+20d]. 


SET_VELOCITY<80>, NOTE_ON<60> 
TIME_SHIFT<500>, NOTE_ON<64> 

TIME_SHIFT<500>, NOTE_ON<67> 

TIME_SHIFT<1000>, NOTE_OFF<60>, NOTE_OFF<64>, 
NOTE_OFF<67> 

TIME_SHIFT<500>, SET_VELOCITY<100>, NOTE_ON<65> 
TIME_SHIFT<500>, NOTE_OFF<65>


Figure 22.6: A snippet of a piano performance visualized as a pianoroll (left) and encoded as performance
events (right, serialized from left to right and then down the rows). There are 128 discrete values each for
note on and note off, 32 values for velocity, and 100 for time shift, so the input is a sequence of one-hot
vectors of length 388. From Figure 7 of [Hua+18a]. Used with kind permission of Anna Huang.


demo at https://magenta.tensorflow.org/music-transformer.


22.4.3 Text-to-image generation (DALL-E)


The DALL-E model¹ from OpenAI [Ram+21a] can generate images of remarkable quality and
diversity given text prompts, as shown in Figure 22.8. The methodology is conceptually quite
straightforward, and most of the effort went into data collection (they scrape the web for 250 million


1. The name is derived from the artist Salvador Dalí and Pixar’s movie “WALL-E”.
Figure 22.7: Illustration of attention in the music transformer. Different colored lines correspond to the 6
attention heads. Line thickness corresponds to attention weights. From Figure 8 of [Hua+18a]. Used with 
kind permission of Anna Huang. 




Figure 22.8: Some images generated by the DALL-E model in response to a text prompt. (a) “An armchair in 
the shape of an avocado”. (b) “An illustration of a baby hedgehog in a christmas sweater walking a dog”. From 
https://openai.com/blog/dall-e. Used with kind permission of Aditya Ramesh.


image-text pairs) and scaling up the training (they fit a model with 12 billion parameters). Here we 
just focus on the algorithmic methods. 

The basic idea is to transform an image $x$ into a sequence of discrete tokens $z$ using a discrete
VAE model (Section 21.6.5), which defines a model of the form $p(x, z)$. We then fit a transformer to
the concatenation of the image tokens $z$ and text tokens $y$ to get a model of the form $p(z, y)$.

To sample an image $x$ given a text prompt $y$, we sample a latent code $z \sim p(z|y)$, and then we
feed $z$ into the VAE decoder to get $x \sim p(x|z)$. Multiple images are generated for each prompt, and
these are then ranked according to a pre-trained critic, which gives them scores depending on how
well the generated image matches the input text: $s_n = \mathrm{critic}(x_n, y)$. The critic they used was the
contrastive CLIP model (see Section 32.3.4.1). This discriminative reranking significantly improves
the results.

Some sample results are shown in Figure 22.8, and more can be found online at
https://openai.com/blog/dall-e/. The image on the right of Figure 22.8 is particularly interesting, since the
prompt — “An illustration of a baby hedgehog in a christmas sweater walking a dog” — arguably 
requires that the model solve the “variable binding problem”. This refers to the fact that the 
sentence implies the hedgehog should be wearing the sweater and not the dog. We see that the model
sometimes interprets this correctly, but not always: sometimes it draws both animals with Christmas
sweaters. In addition, sometimes it draws a hedgehog walking a smaller hedgehog. The quality of the
results can also be sensitive to the form of the prompt. Thus although impressive, this technology is
clearly not yet reliable.

Since DALL-E was released in January 2021, many alternatives have been proposed. For
example, Google released Parti [Yu+22], which is like DALL-E, but uses a ViT-VQ-GAN encoder
[Yu+21], instead of a VQ-VAE encoder; XMC-GAN [Zha+21b], which uses a GAN (Chapter 26) 
instead of a transformer; and Imagen [Sah+22], which uses a diffusion model (Chapter 25) instead 
of the transformer. OpenAI also released a diffusion-based model called GLIDE [Nic+21], and then 
later a better diffusion model called DALL-E 2 [Ram+22]. 
23 Normalizing Flows


This chapter was written by George Papamakarios and Balaji Lakshminarayanan. 


23.1 Introduction 


In this chapter we discuss normalizing flows, a class of flexible density models that can be 
easily sampled from and whose exact likelihood function is efficient to compute. Such models 
can be used for many tasks, such as density modeling, inference and generative modeling. We 
introduce the key principles of normalizing flows and refer to recent surveys by Papamakarios et al. 
[Pap+19] and Kobyzev, Prince, and Brubaker [KPB19] for readers interested in learning more. See 
also https://github.com/janosh/awesome-normalizing-flows for a list of papers and software 
packages. 


23.1.1 Preliminaries 


Normalizing flows create complex probability distributions $p(x)$ by passing random variables $u \in \mathbb{R}^D$,
drawn from a simple base distribution $p(u)$ through a nonlinear but invertible transformation
$f : \mathbb{R}^D \to \mathbb{R}^D$. That is, $p(x)$ is defined by the following process:


$$x = f(u) \quad \text{where} \quad u \sim p(u). \tag{23.1}$$


The base distribution is typically chosen to be simple, for example standard Gaussian or uniform, so 
that we can easily sample from it and compute the density p(u). A flexible enough transformation 
f can induce a complex distribution on the transformed variable x even if the base distribution is 
simple. 

Sampling from $p(x)$ is straightforward: we first sample $u$ from $p(u)$ and then compute $x = f(u)$.
To compute the density $p(x)$, we rely on the fact that $f$ is invertible. Let $g(x) = f^{-1}(x) = u$ be
the inverse mapping, which “normalizes” the data distribution by mapping it back to the base 
distribution (which is often a Normal distribution). Using the change-of-variables formula for random 
variables from Equation (2.286), we have 


$$p_x(x) = p_u(g(x))\, |\det J(g)(x)| = p_u(u)\, |\det J(f)(u)|^{-1}, \tag{23.2}$$


where $J(f)(u) = \frac{\partial f}{\partial u}\big|_{u}$ is the Jacobian matrix of $f$ evaluated at $u$. Taking logs of both sides of
Equation (23.2), we get 


$$\log p_x(x) = \log p_u(u) - \log |\det J(f)(u)|. \tag{23.3}$$
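As a sanity check of Equation (23.3), the sketch below builds a one-dimensional flow $x = f(u) = \exp(u)$ with a standard normal base; the result should be the log-normal density. The example is our own, chosen because every quantity has a closed form.

```python
import numpy as np

def log_std_normal(u):
    return -0.5 * (u ** 2 + np.log(2.0 * np.pi))

def flow_log_density(x):
    """log p_x(x) for x = f(u) = exp(u), u ~ N(0, 1), via Eq. (23.3)."""
    u = np.log(x)            # u = g(x) = f^{-1}(x)
    log_abs_det = u          # f'(u) = exp(u), so log|det J(f)(u)| = u
    return log_std_normal(u) - log_abs_det

rng = np.random.default_rng(0)
samples = np.exp(rng.normal(size=5))   # sampling: draw u from the base, then apply f
print(flow_log_density(samples))
```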
As discussed above, $p(u)$ is typically easy to evaluate. So, if one can use flexible invertible
transformations $f$ whose Jacobian determinant $\det J(f)(u)$ can be computed efficiently, then one can construct
complex densities $p(x)$ that allow exact sampling and efficient exact likelihood computation. This is
in contrast to latent variable models, which require methods like variational inference to lower-bound 
the likelihood. 

One might wonder how flexible the densities $p(x)$ are when they are obtained by transforming random
variables sampled from a simple $p(u)$. It turns out that we can use this method to approximate any smooth
distribution. To see this, consider the scenario where the base distribution $p(u)$ is a one-dimensional
uniform distribution. Recall that inverse transform sampling (Section 11.3.1) samples random 
variables from a uniform distribution and transforms them using the inverse Cumulative Distribution 
Function (CDF) to generate samples from the desired density. We can use this method to sample 
from any one-dimensional density as long as the transformation f is powerful enough to model the 
inverse CDF (which is a reasonable assumption for well-behaved densities whose CDF is invertible 
and differentiable). We can further extend this argument to multiple dimensions by first expressing 
the density p(x) as a product of one-dimensional conditionals using the chain rule of probability, 
and then applying inverse transform sampling to each one-dimensional conditional. The result is a 
normalizing flow that transforms a product of uniform distributions into any desired distribution 
p(x). We refer to [Pap+19] for a more detailed proof. 
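The one-dimensional argument can be seen with a distribution whose inverse CDF has a closed form. The sketch below (our choice of example) maps uniform samples through the inverse CDF of an exponential distribution, $F^{-1}(v) = -\log(1 - v)/\lambda$; this map plays the role of the transformation $f$.

```python
import numpy as np

def exponential_inverse_cdf(v, rate):
    """Inverse of the Exponential(rate) CDF F(x) = 1 - exp(-rate * x)."""
    return -np.log1p(-v) / rate

rng = np.random.default_rng(0)
v = rng.uniform(size=100_000)                 # base samples from Uniform(0, 1)
x = exponential_inverse_cdf(v, rate=2.0)      # transformed samples
print(x.mean())                               # should be close to 1 / rate = 0.5
```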

How do we define flexible invertible mappings whose Jacobian determinant is easy to compute? 
We discuss this topic in detail in Section 23.2, but in summary, there are two main ways. The first 
approach is to define a set of simple transformations that are invertible by design, and whose Jacobian 
determinant is easy to compute; for instance, if the Jacobian is a triangular matrix, its determinant 
can be computed efficiently. The second approach is to exploit the fact that a composition of invertible 
functions is also invertible, and the overall Jacobian determinant is just the product of the individual 
Jacobian determinants. More precisely, if $f = f_N \circ \cdots \circ f_1$ where each $f_i$ is invertible, then $f$ is also
invertible, with inverse $g = g_1 \circ \cdots \circ g_N$ and log Jacobian determinant given by

$$\log |\det J(g)(x)| = \sum_{i=1}^{N} \log |\det J(g_i)(u_i)| \tag{23.4}$$

where $u_i = f_i \circ \cdots \circ f_1(u)$ is the $i$'th intermediate output of the flow. This allows us to create
complex flows from simple components, just as graphical models allow us to create complex joint
distributions from simpler conditional distributions.
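The additivity of log Jacobian determinants can be checked numerically on a two-layer flow. The sketch below composes a hypothetical affine layer with an elementwise layer, computes the log Jacobian determinant of the composite by finite differences, and compares it with the sum of per-layer terms. It checks the forward form, $\log |\det J(f)(u)| = \sum_i \log |\det J(f_i)(u_{i-1})|$ with $u_0 = u$; since each $J(g_i)(u_i)$ is the inverse of $J(f_i)(u_{i-1})$, Equation (23.4) is the same identity with the sign flipped.

```python
import numpy as np

A = np.array([[2.0, 0.5], [0.0, 1.5]])   # invertible: det(A) = 3
b = np.array([0.1, -0.3])

def f1(u):
    """Affine bijection u -> A u + b."""
    return A @ u + b

def f2(v):
    """Elementwise bijection v -> v + 0.5 tanh(v); its derivative 1 + 0.5 sech^2(v) > 0."""
    return v + 0.5 * np.tanh(v)

def log_abs_det_f1(u):
    return np.log(abs(np.linalg.det(A)))

def log_abs_det_f2(v):
    return np.sum(np.log(1.0 + 0.5 * (1.0 - np.tanh(v) ** 2)))   # diagonal Jacobian

def numerical_jacobian(fn, u, h=1e-6):
    """Central finite-difference Jacobian of fn at u."""
    J = np.zeros((fn(u).size, u.size))
    for j in range(u.size):
        e = np.zeros(u.size)
        e[j] = h
        J[:, j] = (fn(u + e) - fn(u - e)) / (2.0 * h)
    return J

u0 = np.array([0.3, -1.2])
u1 = f1(u0)                                                      # intermediate output
lhs = np.log(abs(np.linalg.det(numerical_jacobian(lambda z: f2(f1(z)), u0))))
rhs = log_abs_det_f1(u0) + log_abs_det_f2(u1)                    # sum over the two layers
print(lhs, rhs)
```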


Finally, a note on terminology. An invertible transformation is also known as a bijection. A
bijection that is differentiable and has a differentiable inverse is known as a diffeomorphism. The
transformation $f$ of a flow model is a diffeomorphism, although in the rest of this chapter we will refer
to it as a “bijection” for simplicity, leaving the differentiability implicit. The density $p_x(x)$ of a flow
model is also known as the pushforward of the base distribution $p_u(u)$ through the transformation
$f$, and is sometimes denoted as $p_x = f_* p_u$. Finally, in mathematics the term “flow” refers to any
family of diffeomorphisms $f_t$ indexed by a real number $t$ such that $t = 0$ indexes the identity function,
and $t_1 + t_2$ indexes $f_{t_1} \circ f_{t_2}$ (in physics, $t$ often represents time). In machine learning we use the term
“flow” by analogy to the above meaning, to highlight the fact that we can create flexible invertible
transformations by composing simpler ones; in this sense, the index $t$ is analogous to the number $N$ of
transformations in $f_N \circ \cdots \circ f_1$.
23.1.2 How to train a flow model


There are two common applications of normalizing flows. The first one is density estimation of 
observed data, which is achieved by fitting $p_\theta(x)$ to the data and using it as an estimate of the
data density, potentially followed by generating new data from $p_\theta(x)$. The second one is variational
inference, which involves sampling from and evaluating a variational posterior $q_\theta(z|x)$ parameterized
by the flow model. As we will see below, these applications optimize different objectives and impose 
different computational constraints on the flow model. 


23.1.2.1 Density estimation 


Density estimation requires maximizing the likelihood function in Equation (23.2). This requires that 
we can efficiently evaluate the inverse flow $u = f^{-1}(x)$ and its Jacobian determinant $\det J(f^{-1})(x)$
for any given $x$. After optimizing the model, we can optionally use it to generate new data. To
sample new points, we require that the forward mapping f be tractable. 


23.1.2.2 Variational inference 


Normalizing flows are commonly used for variational inference to parametrize the approximate 
posterior distribution in latent variable models, as discussed in Section 10.4.3. Consider a latent
variable model with continuous latent variables $z$ and observable variables $x$. For simplicity, we
consider the model parameters to be fixed as we are interested in approximating the true posterior
$p^*(z|x)$ with a normalizing flow $q_\theta(z|x)$.¹ As discussed in Section 10.1.2, the variational parameters
are trained by maximizing the evidence lower bound (ELBO), given by 


$$\mathcal{L}(\theta) = \mathbb{E}_{q_\theta(z|x)}[\log p(x|z) + \log p(z) - \log q_\theta(z|x)] \tag{23.5}$$


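When viewing the ELBO as a function of $\theta$, it can be simplified as follows (note we drop the
dependency on $x$ for simplicity):

$$\mathcal{L}(\theta) = \mathbb{E}_{q_\theta(z)}[\ell_\theta(z)], \tag{23.6}$$

where $\ell_\theta(z) = \log p(x|z) + \log p(z) - \log q_\theta(z)$.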


Let $q_\theta(z)$ denote a normalizing flow with base distribution $q(u)$ and transformation $z = f_\theta(u)$. Then
the reparametrization trick (Section 6.5.4) allows us to optimize the parameters using stochastic 
gradients. To achieve this, we first write the expectation with respect to the base distribution: 


$$\mathcal{L}(\theta) = \mathbb{E}_{q_\theta(z)}[\ell_\theta(z)] = \mathbb{E}_{q(u)}[\ell_\theta(f_\theta(u))]. \tag{23.7}$$


Then, since the base distribution does not depend on $\theta$, we can obtain stochastic gradients as follows:


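$$\nabla_\theta \mathcal{L}(\theta) = \mathbb{E}_{q(u)}[\nabla_\theta \ell_\theta(f_\theta(u))] \approx \frac{1}{N} \sum_{n=1}^{N} \nabla_\theta \ell_\theta(f_\theta(u_n)), \tag{23.8}$$

where $\{u_n\}$ are samples from $q(u)$.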
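Equation (23.8) can be illustrated with the simplest possible flow, a one-layer affine map $z = f_\theta(u) = \mu + \sigma u$ with $u \sim \mathcal{N}(0, 1)$ and $\theta = (\mu, \sigma)$. The sketch assumes a hypothetical target whose log joint is $-(z-2)^2/2$ up to a constant, so the exact ELBO gradient is available for comparison: at $\mu = 0$, $\sigma = 0.5$ it is $(2.0, 1.5)$. The $1/\sigma$ term comes from differentiating $-\log q_\theta(f_\theta(u)) = -\log \mathcal{N}(u|0,1) + \log \sigma$.

```python
import numpy as np

def log_target(z):
    """Hypothetical unnormalised log p(x, z) as a function of z: N(z | 2, 1) up to a constant."""
    return -0.5 * (z - 2.0) ** 2

def elbo_grad_estimate(mu, sigma, n, rng):
    """Monte Carlo estimate of Eq. (23.8) for z = f_theta(u) = mu + sigma * u, u ~ N(0, 1),
    with ell_theta(z) = log_target(z) - log q_theta(z)."""
    u = rng.normal(size=n)
    z = mu + sigma * u
    dlog_target = -(z - 2.0)                        # d log_target / dz
    grad_mu = np.mean(dlog_target)                  # dz/dmu = 1
    grad_sigma = np.mean(dlog_target * u) + 1.0 / sigma   # dz/dsigma = u, plus entropy term
    return grad_mu, grad_sigma

rng = np.random.default_rng(0)
print(elbo_grad_estimate(0.0, 0.5, 200_000, rng))   # analytic gradient: (2.0, 1.5)
```

Because the noise $u$ does not depend on $\theta$, the gradient passes through the samples themselves, which is what makes this estimator low-variance compared with score-function alternatives.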


1. We denote the parameters of the variational posterior by $\theta$ here, which should not be confused with the model
parameters which are also typically denoted by $\theta$ elsewhere.

As we can see, in order to optimize this objective, we need to be able to efficiently sample from $q_\theta(z|x)$ and evaluate the probability density of these samples during optimization. (See Section 23.2.4.3 for details on how to do this.) This is in contrast to the MLE approach in Section 23.1.2.1, which requires that we be able to efficiently compute the density of arbitrary training datapoints, but does not require samples during optimization.
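
The following sketch illustrates Equations (23.7) and (23.8) under strong simplifying assumptions: a one-dimensional latent, a hypothetical target whose unnormalized log-density $\log p(x|z) + \log p(z)$ equals $\log \mathcal{N}(z|2, 1)$, and a single affine flow $z = f_\theta(u) = au + b$ with $\theta = (a, b)$ and $q(u) = \mathcal{N}(0, 1)$. The per-sample gradients of $\ell_\theta(f_\theta(u))$ are derived by hand here; in practice an autodiff framework such as JAX would compute them.

```python
import numpy as np

def log_normal(x, mean, std):
    return -0.5 * ((x - mean) / std) ** 2 - np.log(std) - 0.5 * np.log(2 * np.pi)

TARGET_MEAN = 2.0  # hypothetical target: log p(x|z) + log p(z) = log N(z | 2, 1)

def elbo_and_grads(a, b, u):
    z = a * u + b                                     # z = f_theta(u)
    log_q = log_normal(u, 0.0, 1.0) - np.log(abs(a))  # log q_theta(z) by change of variables
    ell = log_normal(z, TARGET_MEAN, 1.0) - log_q     # l_theta(f_theta(u))
    dlogp_dz = -(z - TARGET_MEAN)
    grad_a = dlogp_dz * u + 1.0 / a  # the 1/a term comes from -log q_theta = ... + log|a|
    grad_b = dlogp_dz
    return ell.mean(), grad_a.mean(), grad_b.mean()   # Monte Carlo averages over u_n

rng = np.random.default_rng(0)
a, b, lr = 0.5, 0.0, 0.05
for step in range(2000):
    u = rng.standard_normal(64)  # samples u_n from the base q(u), which does not depend on theta
    elbo, grad_a, grad_b = elbo_and_grads(a, b, u)
    a += lr * grad_a             # stochastic gradient ascent on the ELBO
    b += lr * grad_b
print(round(a, 2), round(b, 2), round(elbo, 3))
```

Because this target is itself Gaussian, the flow can match it exactly, so $(a, b)$ should drift towards $(1, 2)$ up to stochastic-gradient noise.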


23.2 Constructing Flows 


In this section, we discuss how to construct various kinds of flows that are invertible by design and have efficiently computable Jacobian determinants.


23.2.1 Affine flows 


A simple choice is to use an affine transformation $x = f(u) = Au + b$. This is a bijection if and only if $A$ is an invertible square matrix. The Jacobian determinant of $f$ is $\det A$, and its inverse is $u = f^{-1}(x) = A^{-1}(x - b)$. A flow consisting of affine bijections is called an affine flow, or a linear flow if we ignore $b$.

On their own, affine flows are limited in their expressive power. For example, suppose the base distribution is Gaussian, $p(u) = \mathcal{N}(u|\mu, \Sigma)$. Then the pushforward distribution after an affine bijection is still Gaussian, $p(x) = \mathcal{N}(x|A\mu + b, A\Sigma A^\top)$. However, affine bijections are useful building blocks when composed with the non-affine bijections we discuss later, as they encourage “mixing” of dimensions through the flow.
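
As a quick numerical check of the pushforward claim, the sketch below (NumPy, with arbitrary illustrative values for $A$, $b$, $\mu$ and $\Sigma$) evaluates $\log p(x)$ through the inverse flow and $\log|\det A|$, and compares it with the closed-form Gaussian $\mathcal{N}(x|A\mu + b, A\Sigma A^\top)$.

```python
import numpy as np

def gaussian_log_prob(x, mean, cov):
    d = x - mean
    _, logdet = np.linalg.slogdet(cov)
    return -0.5 * (d @ np.linalg.solve(cov, d) + logdet + len(x) * np.log(2 * np.pi))

rng = np.random.default_rng(0)
D = 3
mu = rng.standard_normal(D)
Sigma = np.eye(D) + 0.5 * np.ones((D, D))        # positive definite base covariance
A = rng.standard_normal((D, D)) + 3 * np.eye(D)  # illustrative, almost surely invertible
b = rng.standard_normal(D)

def affine_flow_log_prob(x):
    u = np.linalg.solve(A, x - b)          # u = A^{-1}(x - b)
    _, log_abs_det = np.linalg.slogdet(A)  # log|det A|
    return gaussian_log_prob(u, mu, Sigma) - log_abs_det

x = rng.standard_normal(D)
print(affine_flow_log_prob(x), gaussian_log_prob(x, A @ mu + b, A @ Sigma @ A.T))
```
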

For practical reasons, we need to ensure the Jacobian determinant and the inverse of the flow are fast to compute. In general, computing $\det A$ and $A^{-1}$ explicitly takes $O(D^3)$ time. To reduce the cost, we can add structure to $A$. If $A$ is diagonal, the cost becomes $O(D)$. If $A$ is triangular, the Jacobian determinant is the product of the diagonal elements, so it takes $O(D)$ time; inverting the flow requires solving the triangular system $Au = x - b$, which can be done with backsubstitution in $O(D^2)$ time.
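
For a lower-triangular $A$, the triangular solve is forward substitution (back substitution is the upper-triangular analogue). A sketch, assuming a randomly generated $A$ with a positive diagonal so that it is guaranteed to be invertible:

```python
import numpy as np

def forward_substitution(L, y):
    # Solve L u = y for lower-triangular L in O(D^2) time.
    u = np.zeros(len(y))
    for i in range(len(y)):
        u[i] = (y[i] - L[i, :i] @ u[:i]) / L[i, i]
    return u

rng = np.random.default_rng(1)
D = 4
# Lower-triangular A with positive diagonal entries.
A = np.tril(rng.standard_normal((D, D)), k=-1) + np.diag(np.exp(rng.standard_normal(D)))
b = rng.standard_normal(D)

x = rng.standard_normal(D)
u = forward_substitution(A, x - b)                # inverse flow u = A^{-1}(x - b)
log_abs_det = np.sum(np.log(np.abs(np.diag(A))))  # O(D): log of the product of diagonal entries
print(u, log_abs_det)
```
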


The result of a triangular transformation depends on the ordering of the dimensions. To reduce sensitivity to this, and to encourage “mixing” of dimensions, we can multiply $A$ with a permutation matrix, which has an absolute determinant of 1. We often use a permutation that reverses the indices at each layer or that randomly shuffles them. However, usually the permutation at each layer is fixed rather than learned.


For spatially structured data (such as images), we can define $A$ to be a convolution matrix. For example, GLOW [KD18b] uses $1 \times 1$ convolution; this is equivalent to a pointwise linear transformation across feature dimensions, but regular convolution across spatial dimensions. Two more general methods for modeling $d \times d$ convolutions are presented in [HBW19], one based on stacking autoregressive convolutions, and the other on carrying out the convolution in the Fourier domain.


23.2.2 Elementwise flows


Let $h : \mathbb{R} \to \mathbb{R}$ be a scalar-valued bijection. We can create a vector-valued bijection $f : \mathbb{R}^D \to \mathbb{R}^D$ by applying $h$ elementwise, that is, $f(u) = (h(u_1), \ldots, h(u_D))$. The function $f$ is invertible, and its Jacobian determinant is given by $\prod_{i=1}^{D} \frac{dh}{du_i}$. A flow composed of such bijections is known as an elementwise flow.
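
A small sketch of an elementwise flow, using $h = \sinh$ as an illustrative scalar bijection (our choice; any strictly monotonic $h$ onto $\mathbb{R}$ would do). The log Jacobian determinant is $\sum_i \log h'(u_i)$, and a finite-difference Jacobian confirms that the Jacobian is diagonal:

```python
import numpy as np

def elementwise_forward(u):
    # h = sinh is a bijection R -> R with derivative cosh > 0.
    x = np.sinh(u)
    log_abs_det = np.sum(np.log(np.cosh(u)))  # sum_i log h'(u_i)
    return x, log_abs_det

def elementwise_inverse(x):
    return np.arcsinh(x)

def numerical_jacobian(fn, u, eps=1e-6):
    # Central differences; column j holds d fn / d u_j.
    cols = [(fn(u + eps * e) - fn(u - eps * e)) / (2 * eps) for e in np.eye(len(u))]
    return np.stack(cols, axis=1)

u = np.array([-1.0, 0.2, 1.5])
x, log_abs_det = elementwise_forward(u)
J = numerical_jacobian(lambda v: elementwise_forward(v)[0], u)
print(np.round(J, 4))  # diagonal: each x_i depends only on u_i
```
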

Figure 23.1: Non-linear squared flow (NLSq). Left: an invertible mapping consisting of 4 NLSq layers.
Middle: red is the base distribution (Gaussian), blue is the distribution induced by the mapping on the left.
Right: density of a 5-layer autoregressive flow using NLSq transformations and a Gaussian base density,
trained on a mixture of 4 Gaussians. From Figure 5 of [ZR19b]. Used with kind permission of Zachary
Ziegler.


On their own, elementwise flows are limited, since they do not model dependencies between the 
elements. However, they are useful building blocks for more complex flows, such as coupling flows 
(Section 23.2.3) and autoregressive flows (Section 23.2.4), as we will see later. In this section, we 
discuss techniques for constructing scalar-valued bijections $h : \mathbb{R} \to \mathbb{R}$ for use in elementwise flows.


23.2.2.1 Affine scalar bijection 


An affine scalar bijection has the form $h(u; \theta) = au + b$, where $\theta = (a, b) \in \mathbb{R}^2$. (This is a scalar version of an affine flow.) Its derivative $\frac{dh}{du}$ is equal to $a$. It is invertible if and only if $a \neq 0$. In practice, we often parameterize $a$ to be positive, for example by making it the exponential or the softplus of an unconstrained parameter. When $a = 1$, $h(u; \theta) = u + b$ is often called an additive scalar bijection.
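
A sketch of the positivity trick mentioned above, using the softplus parameterization (the names `raw_a` and `softplus` are ours, chosen for illustration):

```python
import numpy as np

def softplus(r):
    # log(1 + exp(r)), computed stably; always positive.
    return np.logaddexp(0.0, r)

def affine_scalar_forward(u, raw_a, b):
    a = softplus(raw_a)          # positive scale, so a != 0 and h is increasing
    return a * u + b, np.log(a)  # h(u) and log|dh/du| = log a

def affine_scalar_inverse(x, raw_a, b):
    return (x - b) / softplus(raw_a)

x, log_deriv = affine_scalar_forward(1.5, raw_a=-3.0, b=0.7)
print(x, log_deriv, affine_scalar_inverse(x, raw_a=-3.0, b=0.7))
```
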


23.2.2.2 Higher-order perturbations 


The affine scalar bijection is simple to use, but limited. We can make it more flexible by adding higher-order perturbations, under the constraint that invertibility is preserved. For example, Ziegler and Rush [ZR19b] propose the following, which they term non-linear squared flow:

$$h(u; \theta) = au + b + \frac{c}{1 + (du + e)^2}, \tag{23.9}$$

where $\theta = (a, b, c, d, e) \in \mathbb{R}^5$. When $c = 0$, this reduces to the affine case. When $c \neq 0$, it adds an inverse-quadratic perturbation, which can induce multimodality as shown in Figure 23.1. Under the constraints $a > \frac{9}{8\sqrt{3}}|c|d$ and $d > 0$, the function becomes invertible, and its inverse can be computed analytically by solving a cubic polynomial.
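
The invertibility condition follows from $h'(u) = a - 2cd\,t/(1 + t^2)^2$ with $t = du + e$, since $|2t/(1 + t^2)^2|$ is at most $9/(8\sqrt{3})$. The sketch below checks that condition and inverts $h$ by multiplying $h(u) = x$ through by $1 + (du + e)^2$ and solving the resulting cubic with `numpy.roots`. The parameter values are illustrative, and this is one possible way to compute the inverse, not necessarily the one used in [ZR19b].

```python
import numpy as np

def nlsq(u, a, b, c, d, e):
    # h(u) = a u + b + c / (1 + (d u + e)^2), Equation (23.9)
    return a * u + b + c / (1.0 + (d * u + e) ** 2)

def nlsq_is_invertible(a, c, d):
    # h'(u) = a - 2cd t / (1 + t^2)^2 with t = du + e; |2t / (1 + t^2)^2| peaks at 9 / (8 sqrt(3)).
    return d > 0 and a > 9.0 * abs(c) * d / (8.0 * np.sqrt(3.0))

def nlsq_inverse(x, a, b, c, d, e):
    # Multiplying h(u) = x by 1 + (du + e)^2 > 0 gives the cubic
    # (a u + k)(q2 u^2 + q1 u + q0) + c = 0, with k = b - x.
    k = b - x
    q2, q1, q0 = d * d, 2.0 * d * e, 1.0 + e * e
    coeffs = [a * q2, a * q1 + k * q2, a * q0 + k * q1, k * q0 + c]
    roots = np.roots(coeffs)
    # A strictly increasing h gives exactly one real root; take the least imaginary one.
    return roots[np.argmin(np.abs(roots.imag))].real

params = dict(a=1.0, b=0.5, c=2.0, d=0.5, e=0.3)  # illustrative values
print(nlsq_is_invertible(params["a"], params["c"], params["d"]))
print(nlsq_inverse(nlsq(1.2, **params), **params))
```
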


23.2.2.3 Combinations of strictly monotonic scalar functions 


A strictly monotonic scalar function is one that is always increasing (has positive derivative everywhere) 
or always decreasing (has negative derivative everywhere). Such functions are invertible. Many 
activation functions, such as the logistic sigmoid $\sigma(u) = 1/(1 + \exp(-u))$, are strictly monotonic.

Using such activation functions as a starting point, we can build more flexible monotonic functions 
via conical combination (linear combination with positive coefficients) and function composition. 
Suppose $h_1, \ldots, h_K$ are strictly increasing; then the following are also strictly increasing:

• $a_1 h_1 + \cdots + a_K h_K + b$ with $a_k > 0$ (conical combination with a bias),
• $h_1 \circ \cdots \circ h_K$ (function composition).


By repeating the above two constructions, we can build arbitrarily complex increasing functions. For 
example, a composition of conical combinations of logistic sigmoids is just an MLP where all weights 
are positive [Hua+18b]. 

The derivative of such a scalar bijection can be computed by repeatedly applying the chain rule, 
and in practice can be done with automatic differentiation. However, the inverse is not typically 
computable in closed form. In practice we can compute the inverse using bisection search, since the 
function is monotonic. 
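
A sketch of a monotonic MLP of this kind, assuming a single hidden layer of sigmoids whose weights are made positive by exponentiation, plus a small positive multiple of the identity (itself an increasing function) so that the range of $h$ is all of $\mathbb{R}$ rather than a bounded interval. The derivative is written out with the chain rule, and the inverse uses bisection; the bracket `[-1e3, 1e3]` is an assumption that must contain the solution.

```python
import numpy as np

rng = np.random.default_rng(0)
H = 16
# Hypothetical 1-H-1 network; exponentiating raw parameters keeps every weight positive.
w1, b1 = np.exp(rng.standard_normal(H)), rng.standard_normal(H)
w2 = np.exp(rng.standard_normal(H)) / H
slope = 0.1  # weight on the identity term, so that h maps onto all of R

def sigmoid(z):
    return 0.5 * (1.0 + np.tanh(0.5 * z))  # equals 1 / (1 + exp(-z)), without overflow

def h(u):
    # Conical combination of the identity and of sigmoids composed with increasing affine maps.
    return slope * u + w2 @ sigmoid(w1 * u + b1)

def h_prime(u):
    # Chain rule by hand (automatic differentiation would give the same result).
    s = sigmoid(w1 * u + b1)
    return slope + w2 @ (w1 * s * (1.0 - s))

def h_inverse(x, lo=-1e3, hi=1e3, iters=100):
    # Bisection: h is strictly increasing, so comparing h(mid) with x tells us which half to keep.
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if h(mid) < x:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)

print(h_inverse(h(0.8)))
```
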


23.2.2.4 Scalar bijections from integration 


A simple way to ensure a scalar function is strictly monotonic is to constrain its derivative to be positive. Let $h' = \frac{dh}{du}$ be this derivative. Wehenkel and Louppe [WL19] directly parameterize $h'$ with a neural network whose output is made positive via an ELU activation function shifted up by 1. They then integrate the derivative numerically to get the bijection:

$$h(u) = \int_0^u h'(t)\,dt + b, \tag{23.10}$$

where $b$ is a bias. They call this approach unconstrained monotonic neural networks.
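
A sketch of Equation (23.10), assuming a hypothetical single-unit network for the derivative, $h'(t) = \mathrm{ELU}(wt + c) + 1$, and the trapezoidal rule for the numerical integral ([WL19] may use a different quadrature scheme):

```python
import numpy as np

# Hypothetical single-unit "network" for the derivative: h'(t) = ELU(w t + c) + 1.
w, c, bias = 1.5, -0.5, 0.2

def elu(z):
    return np.where(z > 0, z, np.expm1(np.minimum(z, 0.0)))

def h_prime(t):
    return elu(w * t + c) + 1.0  # ELU > -1, so the shifted output is strictly positive

def h(u, n=2001):
    # h(u) = int_0^u h'(t) dt + b, approximated with the trapezoidal rule on n points.
    t = np.linspace(0.0, u, n)
    y = h_prime(t)
    return np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(t)) + bias

grid = np.linspace(-3.0, 3.0, 13)
print([round(float(h(u)), 4) for u in grid])
```

For this particular $h'$ the integral also has a closed form (exponential for $wt + c \le 0$, quadratic beyond), which is handy for checking the quadrature; for a real network it generally does not.
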

The above integral is generally not computable in closed form. It can be, however, if $h'$ is constrained appropriately. For example, Jaini, Selby, and Yu [JSY19] take $h'$ to be a sum of $K$ squared polynomials of degree $L$:

$$h'(u) = \sum_{k=1}^{K} \left( \sum_{\ell=0}^{L} a_{k\ell}\, u^\ell \right)^2. \tag{23.11}$$

This makes $h'$ a non-negative polynomial of degree $2L$. The integral is analytically tractable, and makes $h$ an increasing polynomial of degree $2L + 1$. For $L = 0$, $h'$ is constant, so $h$ reduces to an affine scalar bijection.


In these approaches, the derivative of the bijection can just be read off. However, the inverse is not analytically computable in general. In practice, we can use bisection search to compute the inverse numerically.
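
A sketch of Equation (23.11) using NumPy's `Polynomial` class, with $K = 3$, $L = 2$ and random coefficients chosen only for illustration. The integral is computed exactly with `integ`, and the inverse uses bisection:

```python
import numpy as np
from numpy.polynomial import Polynomial

rng = np.random.default_rng(0)
K, L, b = 3, 2, 0.5
a = rng.standard_normal((K, L + 1))  # hypothetical coefficients a_{k,l}

# h'(u) = sum_k (sum_l a_{k,l} u^l)^2: a non-negative polynomial of degree 2L.
h_prime = sum((Polynomial(a_k) ** 2 for a_k in a), Polynomial([0.0]))
# Exact integration: h(u) = int_0^u h'(t) dt + b, a polynomial of degree 2L + 1.
h = h_prime.integ() + b

def h_inverse(x, lo=-100.0, hi=100.0, iters=200):
    # No closed-form inverse in general, so use bisection (h is increasing).
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if h(mid) < x:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)

print(h.degree(), h_inverse(h(0.7)))
```
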


23.2.2.5 Splines


Another way to construct monotonic scalar functions is using splines. These are piecewise-polynomial or piecewise-rational functions, parameterized in terms of $K + 1$ knots $(u_k, x_k)$ through which the spline passes. That is, we set $h(u_k) = x_k$, and define $h$ on the interval $(u_{k-1}, u_k)$ by interpolating from $x_{k-1}$ to $x_k$ with a polynomial or rational function (ratio of two polynomials). By increasing the number of knots we can create arbitrarily flexible monotonic functions.

Figure 23.2: Illustration of a coupling layer $x = f(u)$. A bijection, with parameters determined by $u^B$, is applied to $u^A$ to generate $x^A$; meanwhile $x^B = u^B$ is passed through unchanged, so the mapping can be inverted. From Figure 3 of [KPB19]. Used with kind permission of Ivan Kobyzev.

Different ways to interpolate between knots give different types of spline. A simple choice is to interpolate linearly [Mül+19a]; however, this makes the derivative discontinuous at the knots. Interpolating with quadratic polynomials [Mül+19a] gives enough flexibility to make the derivative continuous. Interpolating with cubic polynomials [Dur+19], ratios of linear polynomials [DEL20], or ratios of quadratic polynomials [DBP19] allows the derivatives at the knots to be arbitrary parameters.

The spline is strictly increasing if we take $u_{k-1} < u_k$, $x_{k-1} < x_k$, and make sure the interpolation between knots is itself increasing. Depending on the flexibility of the interpolating function, more than one interpolation may exist; in practice we choose one that is guaranteed to be always increasing (see references above for details).

An advantage of splines is that they can be inverted analytically if the interpolating functions only contain low-degree polynomials. In this case, we compute $u = h^{-1}(x)$ as follows: first, we use binary search to locate the interval $(x_{k-1}, x_k)$ in which $x$ lies; then, we analytically solve the resulting low-degree polynomial for $u$.
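
A sketch of the simplest case, a piecewise-linear spline with hypothetical knots. It is only defined on $[u_0, u_K]$ here; how to extend it outside that interval (for example with the identity) is left out. `numpy.searchsorted` performs the binary search over knots in both directions:

```python
import numpy as np

# Hypothetical knots (u_k, x_k), k = 0, ..., K with K = 4; both coordinates strictly increasing.
u_knots = np.array([-2.0, -0.5, 0.0, 1.0, 2.0])
x_knots = np.array([-2.0, -1.5, 0.5, 1.0, 2.0])

def _bin(knots, v):
    # Binary search: index k with knots[k-1] <= v < knots[k], clipped to the valid bins.
    return np.clip(np.searchsorted(knots, v, side="right"), 1, len(knots) - 1)

def spline_forward(u):
    k = _bin(u_knots, u)
    slope = (x_knots[k] - x_knots[k - 1]) / (u_knots[k] - u_knots[k - 1])
    return x_knots[k - 1] + slope * (u - u_knots[k - 1]), np.log(slope)  # x and log h'(u)

def spline_inverse(x):
    k = _bin(x_knots, x)  # locate (x_{k-1}, x_k), then solve the linear piece for u
    slope = (x_knots[k] - x_knots[k - 1]) / (u_knots[k] - u_knots[k - 1])
    return u_knots[k - 1] + (x - x_knots[k - 1]) / slope

u_grid = np.linspace(-2.0, 2.0, 41)
x_grid, log_deriv = spline_forward(u_grid)
print(np.allclose(spline_inverse(x_grid), u_grid))
```

The slope (and hence `log_deriv`) jumps at the knots, which is the derivative discontinuity of linear interpolation noted above.
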


23.2.3 Coupling flows 


In this section we describe coupling flows, which allow us to model dependencies between dimensions using arbitrary non-linear functions (such as deep neural networks). Consider a partition of the input $u \in \mathbb{R}^D$ into two subspaces, $(u^A, u^B) \in \mathbb{R}^d \times \mathbb{R}^{D-d}$, where $d$ is an integer between 1 and $D - 1$. Assume a bijection $\hat{f}(\cdot; \theta) : \mathbb{R}^d \to \mathbb{R}^d$ parameterized by $\theta$ and acting on the subspace $\mathbb{R}^d$. We define the function $f : \mathbb{R}^D \to \mathbb{R}^D$ given by $x = f(u)$ as follows:

$$x^A = \hat{f}(u^A; \Theta(u^B)), \tag{23.12}$$
$$x^B = u^B. \tag{23.13}$$

See Figure 23.2 for an illustration. The function $f$ is called a coupling layer [DKB15; DSDB17], because it “couples” $u^A$ and $u^B$ together through $\hat{f}$ and $\Theta$. We refer to flows consisting of coupling layers as coupling flows.

The parameters of $\hat{f}$ are computed by $\theta = \Theta(u^B)$, where $\Theta$ is an arbitrary function called the conditioner. Unlike affine flows, which mix dimensions linearly, and elementwise flows, which do not mix dimensions at all, coupling flows can mix dimensions with a flexible non-linear conditioner $\Theta$. In practice we often implement $\Theta$ as a deep neural network; any architecture can be used, including MLPs, CNNs, ResNets, etc.


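The coupling layer $f$ is invertible, and its inverse is given by $u = f^{-1}(x)$, where

$$u^A = \hat{f}^{-1}(x^A; \Theta(x^B)), \tag{23.14}$$
$$u^B = x^B. \tag{23.15}$$

That is, $f^{-1}$ is given by simply replacing $\hat{f}$ with $\hat{f}^{-1}$. Because $x^B$ does not depend on $u^A$, the Jacobian of $f$ is block triangular:

$$J(f) = \begin{pmatrix} \partial x^A / \partial u^A & \partial x^A / \partial u^B \\ \partial x^B / \partial u^A & \partial x^B / \partial u^B \end{pmatrix} = \begin{pmatrix} J(\hat{f}) & \partial x^A / \partial u^B \\ 0 & I \end{pmatrix}. \tag{23.16}$$

Thus, $\det J(f)$ is equal to $\det J(\hat{f})$.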


We often define $\hat{f}$ to be an elementwise bijection, so that $\hat{f}^{-1}$ and $\det J(\hat{f})$ are easy to compute. That is, we define:

$$\hat{f}(u^A; \theta) = \left(h(u^A_1; \theta_1), \ldots, h(u^A_d; \theta_d)\right), \tag{23.17}$$

where $h(\cdot; \theta_i)$ is a scalar bijection parameterized by $\theta_i$. Any of the scalar bijections described in Section 23.2.2 can be used here. For example, $h(\cdot; \theta_i)$ can be an affine bijection with $\theta_i$ its scale and shift parameters (Section 23.2.2.1); or it can be a monotonic MLP with $\theta_i$ its weights and biases (Section 23.2.2.3); or it can be a monotonic spline with $\theta_i$ its knot coordinates (Section 23.2.2.5).

There are many ways to define the partition of $u$ into $(u^A, u^B)$. A simple way is just to partition $u$ into two halves. We can also exploit spatial structure in the partitioning. For example, if $u$ is an image, we can partition its pixels using a “checkerboard” pattern, where pixels in “black squares” are in $u^A$ and pixels in “white squares” are in $u^B$ [DSDB17]. Since only part of the input is transformed by each coupling layer, in practice we typically employ different partitions along a coupling flow, to ensure all variables get transformed and are given the opportunity to interact.
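
A sketch of an affine coupling layer that splits $u$ into two halves, assuming a hypothetical one-hidden-layer conditioner that outputs a log-scale $\alpha_i$ and a shift $\mu_i$ for each element of $u^A$. Here $\hat{f}$ is elementwise affine, so $\log|\det J(f)| = \sum_i \alpha_i$, and inverting only requires re-running the conditioner on $x^B$:

```python
import numpy as np

rng = np.random.default_rng(0)
D, d = 4, 2
# Hypothetical conditioner Theta: one hidden layer from u^B to (log-scales, shifts) for u^A.
W1 = rng.standard_normal((8, D - d))
W2 = 0.1 * rng.standard_normal((2 * d, 8))

def conditioner(uB):
    out = W2 @ np.tanh(W1 @ uB)
    return out[:d], out[d:]  # alpha (log-scales) and mu (shifts)

def coupling_forward(u):
    uA, uB = u[:d], u[d:]
    alpha, mu = conditioner(uB)
    xA = uA * np.exp(alpha) + mu                    # elementwise affine f-hat
    return np.concatenate([xA, uB]), np.sum(alpha)  # x^B = u^B; log|det J(f)| = sum(alpha)

def coupling_inverse(x):
    xA, xB = x[:d], x[d:]
    alpha, mu = conditioner(xB)                     # x^B = u^B, so Theta sees the same input
    return np.concatenate([(xA - mu) * np.exp(-alpha), xB])

u = rng.standard_normal(D)
x, log_abs_det = coupling_forward(u)
print(np.allclose(coupling_inverse(x), u))
```
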


Finally, if $\hat{f}$ is an elementwise bijection, we can implement arbitrary partitions easily using a binary mask $b$ as follows:

$$x = b \odot u + (1 - b) \odot \hat{f}(u; \Theta(b \odot u)), \tag{23.18}$$

where $\odot$ denotes elementwise multiplication. A value of 0 in $b$ indicates that the corresponding element in $u$ is transformed (belongs to $u^A$); a value of 1 indicates that it remains unchanged (belongs to $u^B$).
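
The masked form in Equation (23.18) can be sketched as follows, with an affine $h$, an alternating mask (stored in `mask`, which plays the role of $b$), and a hypothetical linear conditioner. The conditioner produces outputs for every position, but only those where the mask is 0 are used:

```python
import numpy as np

rng = np.random.default_rng(1)
D = 6
mask = np.array([1, 0, 1, 0, 1, 0], dtype=float)      # 1: unchanged (u^B); 0: transformed (u^A)
W_alpha, W_mu = 0.3 * rng.standard_normal((2, D, D))  # hypothetical linear conditioner

def conditioner(masked_u):
    # Theta(b * u): per-element log-scales and shifts.
    return np.tanh(W_alpha @ masked_u), W_mu @ masked_u

def masked_coupling_forward(u):
    alpha, mu = conditioner(mask * u)
    x = mask * u + (1 - mask) * (u * np.exp(alpha) + mu)  # Equation (23.18) with affine h
    return x, np.sum((1 - mask) * alpha)                  # log|det J| over transformed entries

def masked_coupling_inverse(x):
    alpha, mu = conditioner(mask * x)  # mask * x equals mask * u, since those entries pass through
    return mask * x + (1 - mask) * (x - mu) * np.exp(-alpha)

u = rng.standard_normal(D)
x, log_abs_det = masked_coupling_forward(u)
print(np.allclose(masked_coupling_inverse(x), u))
```

For an image of shape `(H, W)`, a checkerboard mask could be built as `np.indices((H, W)).sum(axis=0) % 2`.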


As an example, we fit a masked coupling flow, created from piecewise rational quadratic splines, to the two moons dataset. Samples from each layer of the fitted model are shown in Figure 23.3.


23.2.4 Autoregressive flows


In this section we discuss autoregressive flows, which are flows composed of autoregressive bijections. 


Like coupling flows, autoregressive flows allow us to model dependencies between variables with arbitrary non-linear functions, such as deep neural networks.

[Figure: eight panels, (a) to (h), titled Layer 1/8 through Layer 8/8.]

Figure 23.3: (a) Two moons dataset. (b) Samples from a normalizing flow fit to this dataset. Generated by two_moons_nsf_normalizing_flow.ipynb.


Suppose the input $u$ contains $D$ scalar elements, that is, $u = (u_1, \ldots, u_D) \in \mathbb{R}^D$. We define an autoregressive bijection $f : \mathbb{R}^D \to \mathbb{R}^D$, its output denoted by $x = (x_1, \ldots, x_D) \in \mathbb{R}^D$, as follows:

$$x_i = h(u_i; \Theta_i(x_{1:i-1})), \quad i = 1, \ldots, D. \tag{23.19}$$

Each output $x_i$ depends on the corresponding input $u_i$ and all previous outputs $x_{1:i-1} = (x_1, \ldots, x_{i-1})$. The function $h(\cdot; \theta) : \mathbb{R} \to \mathbb{R}$ is a scalar bijection (for example, one of those described in Section 23.2.2), and is parameterized by $\theta$. The function $\Theta_i$ is a conditioner that outputs the parameters $\theta_i$ that yield $x_i$, given all previous outputs $x_{1:i-1}$. Like in coupling flows, $\Theta_i$ can be an arbitrary non-linear function, and is often parameterized as a deep neural network.

Because $h$ is invertible, $f$ is also invertible, and its inverse is given by:

$$u_i = h^{-1}(x_i; \Theta_i(x_{1:i-1})), \quad i = 1, \ldots, D. \tag{23.20}$$

An important property of $f$ is that each output $x_i$ depends on $u_{1:i} = (u_1, \ldots, u_i)$, but not on $u_{i+1:D} = (u_{i+1}, \ldots, u_D)$; as a result, the partial derivative $\partial x_i / \partial u_j$ is identically zero whenever $j > i$.


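Therefore, the Jacobian matrix $J(f)$ is triangular, and its determinant is simply the product of its diagonal entries:

$$\det J(f) = \prod_{i=1}^{D} \frac{\partial x_i}{\partial u_i} = \prod_{i=1}^{D} \frac{\partial h}{\partial u_i}\left(u_i; \Theta_i(x_{1:i-1})\right). \tag{23.21}$$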


In other words, the autoregressive structure of f leads to a Jacobian determinant that can be 
computed efficiently in O(D) time. 

Although invertible, autoregressive bijections are computationally asymmetric: evaluating $f$ is inherently sequential, whereas evaluating $f^{-1}$ is inherently parallel. That is because we need $x_{1:i-1}$ to compute $x_i$; therefore, computing the components of $x$ must be done sequentially, by first computing $x_1$, then using it to compute $x_2$, then using $x_1$ and $x_2$ to compute $x_3$, and so on. On the other hand, computing the inverse can be done in parallel for each $u_i$, since $u$ does not appear on the right-hand side of Equation (23.20). Hence, in practice it is often faster to compute $f^{-1}$ than to compute $f$, assuming $h$ and $h^{-1}$ have similar computational cost.

Figure 23.4: (a) Affine autoregressive flow with one layer. In this figure, $u$ is the input to the flow (sample from the base distribution) and $x$ is its output (sample from the transformed distribution). (b) Inverse of the above. From [Jan18]. Used with kind permission of Eric Jang.


23.2.4.1 Affine autoregressive flows 


For a concrete example, we can take $h$ to be an affine scalar bijection (Section 23.2.2.1) parameterized by a log scale $\alpha$ and a bias $\mu$. Such autoregressive flows are known as affine autoregressive flows. The parameters of the $i$th component, $\alpha_i$ and $\mu_i$, are functions of $x_{1:i-1}$, so $f$ takes the following form:

$$x_i = u_i \exp(\alpha_i(x_{1:i-1})) + \mu_i(x_{1:i-1}). \tag{23.22}$$

This is illustrated in Figure 23.4(a). We can invert this by

$$u_i = (x_i - \mu_i(x_{1:i-1})) \exp(-\alpha_i(x_{1:i-1})). \tag{23.23}$$

This is illustrated in Figure 23.4(b). Finally, we can calculate the log absolute Jacobian determinant by

$$\log|\det J(f)| = \log \prod_{i=1}^{D} \exp(\alpha_i(x_{1:i-1})) = \sum_{i=1}^{D} \alpha_i(x_{1:i-1}). \tag{23.24}$$
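
A sketch of Equations (23.22) to (23.24), assuming a hypothetical conditioner built from strictly lower-triangular weight matrices, so that $\alpha_i$ and $\mu_i$ only see $x_{1:i-1}$. Note how the forward pass needs a loop over $i$ (one conditioner call per element), while the inverse needs a single conditioner call:

```python
import numpy as np

rng = np.random.default_rng(0)
D = 4
# Hypothetical conditioner: strictly lower-triangular weights, so alpha_i, mu_i depend on x_{1:i-1} only.
Wa = 0.5 * np.tril(rng.standard_normal((D, D)), k=-1)
Wm = np.tril(rng.standard_normal((D, D)), k=-1)

def alpha_mu(x):
    return np.tanh(Wa @ x), Wm @ x

def ar_forward(u):
    # Sequential: x_i needs x_{1:i-1}, so x is filled in one element at a time.
    x = np.zeros(D)
    for i in range(D):
        alpha, mu = alpha_mu(x)  # x[i:] is still zero and has zero weight in row i anyway
        x[i] = u[i] * np.exp(alpha[i]) + mu[i]
    return x

def ar_inverse(x):
    # Parallel: one conditioner call gives every alpha_i and mu_i from the observed x.
    alpha, mu = alpha_mu(x)
    return (x - mu) * np.exp(-alpha), np.sum(alpha)  # u and log|det J(f)|

u = rng.standard_normal(D)
x = ar_forward(u)
u_rec, log_abs_det = ar_inverse(x)
eps = 1e-6
J = np.stack([(ar_forward(u + eps * e) - ar_forward(u - eps * e)) / (2 * eps) for e in np.eye(D)], axis=1)
print(np.round(J, 3))  # lower triangular
```
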


Let us look at an example of an affine autoregressive flow on a 2d density estimation problem. Consider an affine autoregressive flow $x = (x_1, x_2) = f(u)$, where $u \sim \mathcal{N}(0, I)$ and $f$ is a single autoregressive bijection. Since $x_1$ is an affine transformation of $u_1 \sim \mathcal{N}(0, 1)$, it is Gaussian with mean $\mu_1$ and standard deviation $\sigma_1 = \exp \alpha_1$. Similarly, if we consider $x_1$ fixed, $x_2$ is an affine transformation of $u_2 \sim \mathcal{N}(0, 1)$, so it is conditionally Gaussian with mean $\mu_2(x_1)$ and standard deviation $\sigma_2(x_1) = \exp \alpha_2(x_1)$. Thus, a single affine autoregressive bijection will always produce a distribution with Gaussian conditionals, that is, a distribution of the following form:

$$p(x_1, x_2) = p(x_1)\, p(x_2|x_1) = \mathcal{N}(x_1|\mu_1, \sigma_1^2)\, \mathcal{N}(x_2|\mu_2(x_1), \sigma_2(x_1)^2). \tag{23.25}$$

This result generalizes to an arbitrary number of dimensions $D$.

Figure 23.5: Density estimation with affine autoregressive flows, using a Gaussian base distribution. (a) True density. (b) Estimated density using a single autoregressive layer with ordering $(x_1, x_2)$. On the left (contour plot) we show $p(x)$. On the right (green dots) we show samples of $u = f^{-1}(x)$, where $x$ is sampled from the true density. (c) Same as (b), but using 5 autoregressive layers and reversing the variable ordering after each layer. Adapted from Figure 1 of [PPM17]. Used with kind permission of Iain Murray.

A single affine bijection is not very powerful, regardless of how flexible the functions $\alpha_2(x_1)$ and $\mu_2(x_1)$ are. For example, suppose we want to fit the cross-shaped density shown in Figure 23.5(a) with such a flow. The resulting maximum-likelihood fit is shown in Figure 23.5(b). The red contours show the predictive distribution, $p(x)$, which clearly fails to capture the true distribution. The green dots show transformed versions of the data samples, $p(u)$; we see that this is far from the Gaussian base distribution.

Fortunately, we can obtain a better fit by composing multiple autoregressive bijections (layers), 
and reversing the order of the variables after each layer. For example, Figure 23.5(c) shows the 
results of an affine autoregressive flow with 5 layers applied to the same problem. The red contours 
show that we have matched the empirical distribution, and the green dots show we have matched the 
Gaussian base distribution. 

Note that another way to obtain a better fit is to replace the affine bijection h with a more flexible 
one, such as a monotonic MLP (Section 23.2.2.3) or a monotonic spline (Section 23.2.2.5). 


23.2.4.2 Masked autoregressive flows 


As we have seen, the conditioners $\Theta_i$ can be arbitrary non-linear functions. The most straightforward way to parameterize them is separately for each $i$, for example by using $D$ separate neural networks. However, this can be parameter-inefficient for large $D$.

In practice, we often share parameters between conditioners by combining them into a single model $\Theta$ that takes in $x$ and outputs $(\theta_1, \ldots, \theta_D)$. For the bijection to remain autoregressive, we must constrain $\Theta$ so that $\theta_i$ depends only on $x_{1:i-1}$ and not on $x_{i:D}$. One way to achieve this is to start with an arbitrary neural network (an MLP, a CNN, a ResNet, etc.), and drop connections (for example, by zeroing out weights) until $\theta_i$ is only a function of $x_{1:i-1}$.

Figure 23.6: Inverse autoregressive flow that uses affine scalar bijections. In this figure, $u$ is the input to the flow (sample from the base distribution) and $x$ is its output (sample from the transformed distribution). From [Jan18]. Used with kind permission of Eric Jang.


An example of this approach is the masked autoregressive flow (MAF) model of [PPM17]. This model is an affine autoregressive flow combined with permutation layers, as we described in Section 23.2.4.1. MAF implements the combined conditioner $\Theta$ as follows: it starts with an MLP, and then multiplies (elementwise) the weight matrix of each layer with a binary mask of the same size (different masks are used for different layers). The masks are constructed using the method of [Ger+15]. This ensures that all computational paths from $x_j$ to $\theta_i$ are zeroed out whenever $j \geq i$, effectively making $\theta_i$ only a function of $x_{1:i-1}$. Still, evaluating the masked conditioner $\Theta$ has the same computational cost as evaluating the original (unmasked) MLP.
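
A sketch of degree-based masks in the style of [Ger+15] for a single hidden layer, with one output per dimension for simplicity (a real MAF conditioner outputs both $\alpha_i$ and $\mu_i$). The sizes, random degrees and `tanh` activation are illustrative choices. Each input $x_j$ gets degree $j$ and each hidden unit a degree in $\{1, \ldots, D-1\}$; a hidden unit may see inputs of degree at most its own, and output $i$ may only see hidden units of degree less than $i$. The finite-difference Jacobian then comes out strictly lower triangular:

```python
import numpy as np

rng = np.random.default_rng(0)
D, H = 5, 16
deg_in = np.arange(1, D + 1)          # input x_j has degree j
deg_hid = rng.integers(1, D, size=H)  # hidden degrees drawn from {1, ..., D-1}
# Hidden unit k may see x_j iff deg_hid[k] >= j; output i may see hidden k iff i > deg_hid[k].
mask1 = (deg_hid[:, None] >= deg_in[None, :]).astype(float)  # shape (H, D)
mask2 = (deg_in[:, None] > deg_hid[None, :]).astype(float)   # shape (D, H)
W1 = rng.standard_normal((H, D))
W2 = rng.standard_normal((D, H))

def masked_conditioner(x):
    # One pass of the masked MLP; output i depends on x_{1:i-1} only.
    return (W2 * mask2) @ np.tanh((W1 * mask1) @ x)

x = rng.standard_normal(D)
eps = 1e-6
J = np.stack([(masked_conditioner(x + eps * e) - masked_conditioner(x - eps * e)) / (2 * eps)
              for e in np.eye(D)], axis=1)
print(np.round(J, 3))  # strictly lower triangular
```
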

The key advantage of MAF (and of related models) is that, given $x$, all parameters $(\theta_1, \ldots, \theta_D)$ can be computed efficiently with one neural network evaluation, so the computation of the inverse $f^{-1}$ is fast. Thus, we can efficiently evaluate the probability density of the flow model for arbitrary datapoints. However, in order to compute $f$, the conditioner $\Theta$ must be called a total of $D$ times, since not all entries of $x$ are available to start with. Thus, generating new samples from the flow is $D$ times more expensive than evaluating its probability density function. This makes MAF suitable for density estimation, but less so for data generation.


23.2.4.3 Inverse autoregressive flows


As we have seen, the parameters $\theta_i$ that yield the $i$th output $x_i$ are functions of the previous outputs $x_{1:i-1}$. This ensures that the Jacobian $J(f)$ is triangular, and so its determinant is efficient to compute.

However, there is another possibility: we can make $\theta_i$ a function of the previous inputs instead, that is, a function of $u_{1:i-1}$. This leads to the following bijection, which is known as inverse autoregressive:

$$x_i = h(u_i; \Theta_i(u_{1:i-1})), \quad i = 1, \ldots, D. \tag{23.26}$$

Like its autoregressive counterpart, this bijection has a triangular Jacobian whose determinant is $\det J(f) = \prod_{i=1}^{D} \frac{\partial h}{\partial u_i}\left(u_i; \Theta_i(u_{1:i-1})\right)$. Figure 23.6 illustrates an inverse autoregressive flow, for the case where $h$ is an affine scalar bijection.

To see why this bijection is called “inverse autoregressive”, compare Equation (23.26) with Equation (23.20). The two formulas differ only notationally: we can get from one to the other by swapping $u$ with $x$ and $h$ with $h^{-1}$. In other words, the inverse autoregressive bijection corresponds to a direct parameterization of the inverse of an autoregressive bijection.

Since inverse autoregressive bijections swap the forward and inverse directions of their autoregressive 
counterparts, they also swap their computational properties. This means that the forward direction f 
of an inverse autoregressive flow is inherently parallel and therefore fast, whereas its inverse direction $f^{-1}$ is inherently sequential and therefore slow.

An example of an inverse autoregressive flow is its namesake, the IAF model of [Kin+16]. IAF uses affine scalar bijections, masked conditioners and permutation layers, so it is precisely the inverse of the MAF model described in Section 23.2.4.2. Using IAF, we can generate $u$ in parallel from the base distribution (using, for example, a diagonal Gaussian), and then sample each element of $x$ in parallel. However, evaluating $p(x)$ for an arbitrary datapoint $x$ is slow, because we have to evaluate each element of $u$ sequentially. Fortunately, evaluating the likelihood of samples generated from IAF (as opposed to externally provided samples) incurs no additional cost, since in this case the $u_i$ terms will already have been computed.
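
A sketch of IAF sampling with affine bijections and a hypothetical strictly lower-triangular conditioner acting on $u$. Sampling, and the log-density of the samples, take one vectorized pass (reusing the $u$ just drawn), whereas recovering $u$ from an externally provided $x$ needs a loop:

```python
import numpy as np

rng = np.random.default_rng(0)
D = 4
# Hypothetical conditioner on previous inputs u_{1:i-1}: strictly lower-triangular weights.
Wa = 0.5 * np.tril(rng.standard_normal((D, D)), k=-1)
Wm = np.tril(rng.standard_normal((D, D)), k=-1)

def iaf_forward(u):
    # Parallel: every alpha_i, mu_i depends only on u, which is fully known. Works on (D,) or (n, D).
    alpha, mu = np.tanh(u @ Wa.T), u @ Wm.T
    return u * np.exp(alpha) + mu, np.sum(alpha, axis=-1)

def iaf_sample_with_log_prob(n):
    u = rng.standard_normal((n, D))  # diagonal Gaussian base
    x, log_abs_det = iaf_forward(u)
    log_base = -0.5 * np.sum(u**2, axis=1) - 0.5 * D * np.log(2 * np.pi)
    return x, log_base - log_abs_det, u  # log q(x) reuses the u already computed

def iaf_inverse(x):
    # Sequential: u_i needs u_{1:i-1}, which must be recovered first.
    u = np.zeros(D)
    for i in range(D):
        alpha, mu = np.tanh(Wa @ u), Wm @ u
        u[i] = (x[i] - mu[i]) * np.exp(-alpha[i])
    return u

x, log_q, u = iaf_sample_with_log_prob(3)
print(log_q)
```
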

Although not so suitable for density estimation or maximum-likelihood training, IAFs are well-suited for parameterizing variational posteriors in variational inference. This is because in order to
estimate the variational lower bound (ELBO), we only need samples from the variational posterior 
and their associated probability densities, both of which are efficient to obtain. See Section 23.1.2.2 
for details. 

Another useful application of IAFs is training them to mimic models whose probability density is fast to evaluate but which are slow to sample from. A notable example is the parallel wavenet model of [Oor+18]. This model is an IAF $p_s$ that is trained to mimic a pretrained wavenet model $p_t$ by minimizing the KL divergence $D_{\mathrm{KL}}(p_s \,\|\, p_t)$. This KL can be easily estimated by first sampling from $p_s$ and then evaluating $\log p_s$ and $\log p_t$ at those samples, operations which are all efficient for these models. After training, we obtain an IAF that can generate audio of similar quality to the original wavenet, but can do so much faster.
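
The estimator described above can be sketched with one-dimensional Gaussians standing in for $p_s$ and $p_t$ (hypothetical stand-ins, chosen so that the exact KL is available for comparison):

```python
import numpy as np

def log_normal(x, mean, std):
    return -0.5 * ((x - mean) / std) ** 2 - np.log(std) - 0.5 * np.log(2 * np.pi)

# Hypothetical 1D stand-ins: student p_s (easy to sample) and teacher p_t (easy to evaluate).
mu_s, sd_s = 0.5, 1.2
mu_t, sd_t = 0.0, 1.0

rng = np.random.default_rng(0)
x = mu_s + sd_s * rng.standard_normal(200_000)  # samples from p_s
kl_mc = np.mean(log_normal(x, mu_s, sd_s) - log_normal(x, mu_t, sd_t))

# Closed form for two Gaussians, used only as a check.
kl_exact = np.log(sd_t / sd_s) + (sd_s**2 + (mu_s - mu_t) ** 2) / (2 * sd_t**2) - 0.5
print(kl_mc, kl_exact)
```
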


23.2.4.4 Connection with autoregressive models 


Autoregressive flows can be thought of as generalizing autoregressive models of continuous random 
variables, discussed in Section 22.1. Specifically, any continuous autoregressive model can be 
reparameterized as a one-layer autoregressive flow, as we describe below. 


Consider a general autoregressive model over a continuous random variable $x = (x_1, \ldots, x_D) \in \mathbb{R}^D$ written as

$$p(x) = \prod_{i=1}^{D} p_i(x_i|\theta_i) \quad \text{where} \quad \theta_i = \Theta_i(x_{1:i-1}). \tag{23.27}$$

In the above expression, $p_i(x_i|\theta_i)$ is the $i$-th conditional distribution of the autoregressive model, whose parameters $\theta_i$ are arbitrary functions of the previous variables $x_{1:i-1}$. For example, $p_i(x_i|\theta_i)$ can be a mixture of one-dimensional Gaussian distributions, with $\theta_i$ representing the collection of its means, variances and mixing coefficients.

Now consider sampling a vector $x$ from the autoregressive model, which can be done by sampling one element at a time as follows:

$$x_i \sim p_i(x_i|\Theta_i(x_{1:i-1})) \quad \text{for } i = 1, \ldots, D. \tag{23.28}$$


Each conditional can be sampled from using inverse transform sampling (Section 11.3.1). Let $U(0, 1)$ be the uniform distribution on the interval $[0, 1]$, and let $\mathrm{CDF}_i(x_i|\theta_i)$ be the cumulative distribution function of the $i$-th conditional. Sampling can be written as:

$$x_i = \mathrm{CDF}_i^{-1}(u_i|\Theta_i(x_{1:i-1})) \quad \text{where} \quad u_i \sim U(0, 1). \tag{23.29}$$


Comparing the above expression with the definition of an autoregressive bijection in Equation (23.19), we see that the autoregressive model has been expressed as a one-layer autoregressive flow whose base distribution is uniform on $[0, 1]^D$ and whose scalar bijections correspond to the inverse conditional
CDFs. Viewing autoregressive models as flows this way has an important advantage, namely that it 
allows us to increase the flexibility of an autoregressive model by composing multiple instances of it 
in a flow, without sacrificing the overall tractability. 
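
A sketch of Equation (23.29) using Python's `statistics.NormalDist`, assuming a hypothetical two-dimensional autoregressive model with $x_1 \sim \mathcal{N}(0, 1)$ and $x_2 | x_1 \sim \mathcal{N}(0.8 x_1, 0.6^2)$. The forward map turns uniform $u$ into $x$ by applying the inverse conditional CDFs; the inverse map applies the conditional CDFs themselves:

```python
import random
from statistics import NormalDist

def conditioner(x_prev):
    # Hypothetical Theta_i: (mean, std) of p_i(x_i | x_{1:i-1}).
    if not x_prev:
        return 0.0, 1.0
    return 0.8 * x_prev[-1], 0.6

def ar_flow_forward(u):
    # x_i = CDF_i^{-1}(u_i | Theta_i(x_{1:i-1})), computed sequentially.
    x = []
    for u_i in u:
        mean, std = conditioner(x)
        x.append(NormalDist(mean, std).inv_cdf(u_i))
    return x

def ar_flow_inverse(x):
    # u_i = CDF_i(x_i | Theta_i(x_{1:i-1})); every term uses only the observed x.
    return [NormalDist(*conditioner(x[:i])).cdf(x_i) for i, x_i in enumerate(x)]

random.seed(0)
u = [random.random() for _ in range(2)]  # base distribution: uniform on [0, 1]^2
x = ar_flow_forward(u)
print(x, ar_flow_inverse(x))
```
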


23.2.5 Residual flows 


A residual network is a composition of residual connections, which are functions of the form $f(u) = u + F(u)$. The function $F : \mathbb{R}^D \to \mathbb{R}^D$ is called the residual block, and it computes the difference between the output and the input, $f(u) - u$.

Under certain conditions on F, the residual connection f becomes invertible. We will refer to flows 
composed of invertible residual connections as residual flows. In the following, we describe two 
ways the residual block F can be constrained so that the residual connection f is invertible. 


23.2.5.1 Contractive residual blocks 


One way to ensure the residual connection is invertible is to choose the residual block to be a contraction. A contraction is a function $F$ whose Lipschitz constant is less than 1; that is, there exists $0 < L < 1$ such that for all $u_1$ and $u_2$ we have:

$$\|F(u_1) - F(u_2)\| \leq L \|u_1 - u_2\|. \tag{23.30}$$


The invertibility of $f(u) = u + F(u)$ can be shown as follows. Consider the mapping $g(u) = x - F(u)$. Because $F$ is a contraction, $g$ is also a contraction. So, by Banach’s fixed-point theorem, $g$ has a unique fixed point $u_*$. Hence we have

$$u_* = x - F(u_*) \tag{23.31}$$
$$\Rightarrow\; u_* + F(u_*) = x \tag{23.32}$$
$$\Rightarrow\; f(u_*) = x. \tag{23.33}$$

Because $u_*$ is unique, it follows that $u_* = f^{-1}(x)$.

An example of a residual flow with contractive residual blocks is the iResNet model of [Beh+19]. The residual blocks of iResNet are convolutional neural networks, that is, compositions of convolutional layers with non-linear activation functions. Because the Lipschitz constant of a composition is less than or equal to the product of the Lipschitz constants of the individual functions, it is enough to ensure the convolutions are contractive, and to use increasing activation functions with slope less than or equal to 1. The iResNet model ensures the convolutions are contractive by applying spectral normalization to their weights [Miy+18a].

In general, there is no analytical expression for the inverse $f^{-1}$. However, we can approximate $f^{-1}(x)$ using the following iterative procedure:

$$u_n = g(u_{n-1}) = x - F(u_{n-1}). \tag{23.34}$$

Banach’s fixed-point theorem guarantees that the sequence $u_0, u_1, u_2, \ldots$ will converge to $u_* = f^{-1}(x)$ for any choice of $u_0$, and it will do so at a rate of $O(L^n)$, where $L$ is the Lipschitz constant of $g$ (which is the same as the Lipschitz constant of $F$). In practice, it is convenient to choose $u_0 = x$.
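
A sketch of the iteration in Equation (23.34), assuming a toy residual block $F(u) = c \tanh(Wu + b)$ with $W$ rescaled to unit spectral norm and $c = 0.7$, so that $F$ has Lipschitz constant at most 0.7. The spectral norm is computed exactly here, whereas spectral normalization in practice typically estimates it (for example with power iteration):

```python
import numpy as np

rng = np.random.default_rng(0)
D = 5
W = rng.standard_normal((D, D))
W /= np.linalg.norm(W, 2)  # largest singular value is now 1
c, bias = 0.7, rng.standard_normal(D)

def F(u):
    # Lipschitz constant <= c * ||W||_2 * max|tanh'| = 0.7, so F is a contraction.
    return c * np.tanh(W @ u + bias)

def f(u):
    return u + F(u)

def f_inverse(x, n_iters=100):
    u = x.copy()  # u_0 = x
    for _ in range(n_iters):
        u = x - F(u)  # u_n = g(u_{n-1}) = x - F(u_{n-1})
    return u

u_true = rng.standard_normal(D)
x = f(u_true)
print(np.max(np.abs(f_inverse(x) - u_true)))
```
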

In addition, there is no analytical expression for the Jacobian determinant, whose exact computation costs $O(D^3)$. However, there is a computationally efficient stochastic estimator of the log Jacobian determinant. The idea is to express the log Jacobian determinant as a power series. Using the fact that $f(u) = u + F(u)$, we have

$$\log|\det J(f)| = \log|\det(I + J(F))| = \sum_{k=1}^{\infty} \frac{(-1)^{k+1}}{k} \operatorname{tr}\left[J(F)^k\right]. \tag{23.35}$$

This power series converges when the matrix norm of $J(F)$ is less than 1, which here is guaranteed exactly because $F$ is a contraction. The trace of $J(F)^k$ can be efficiently approximated using Jacobian-vector products via the Hutchinson trace estimator [Ski89; Hut89; Mey+21]:

$$\operatorname{tr}\left[J(F)^k\right] \approx v^\top J(F)^k v, \tag{23.36}$$

where $v$ is a sample from a distribution with zero mean and unit covariance, such as $\mathcal{N}(0, I)$. Finally, the infinite series can be approximated by a finite one either by truncation [Beh+19], which unfortunately yields a biased estimator, or by employing the Russian-roulette estimator [Che+19], which is unbiased.
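
A sketch of the truncated power series combined with the Hutchinson estimator. It uses an explicit random matrix $A$ with spectral norm 0.5 as a stand-in for $J(F)$ (a real implementation would use Jacobian-vector products instead of forming the matrix) and Rademacher probes, which have zero mean and unit covariance:

```python
import numpy as np

rng = np.random.default_rng(0)
D = 5
A = rng.standard_normal((D, D))
A *= 0.5 / np.linalg.norm(A, 2)  # stand-in for J(F), rescaled to spectral norm 0.5 < 1

def logdet_estimate(A, n_terms=30, n_probes=20_000):
    V = rng.choice([-1.0, 1.0], size=(n_probes, A.shape[0]))  # Rademacher probes v
    W = V.copy()
    total = 0.0
    for k in range(1, n_terms + 1):
        W = W @ A.T                               # row n now holds (A^k v_n)^T
        trace_k = np.mean(np.sum(V * W, axis=1))  # Hutchinson estimate of tr[A^k]
        total += (-1) ** (k + 1) / k * trace_k    # truncated series, Equation (23.35)
    return total

estimate = logdet_estimate(A)
exact = np.linalg.slogdet(np.eye(D) + A)[1]
print(estimate, exact)
```

Truncating at 30 terms makes the bias negligible here because the norm is 0.5; with norms closer to 1 the truncation bias grows, which is the motivation for the Russian-roulette estimator.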


23.2.5.2 Residual blocks with low-rank Jacobian 


There is an efficient way of computing the determinant of a matrix which is a low-rank perturbation of an identity matrix. Suppose $A$ and $B$ are matrices, where $A$ is $D \times M$ and $B$ is $M \times D$. The following formula is known as the Weinstein–Aronszajn identity², and is a special case of the more general matrix determinant lemma:

$$\det(I_D + AB) = \det(I_M + BA). \tag{23.37}$$

We write $I_D$ and $I_M$ for the $D \times D$ and $M \times M$ identity matrices respectively. The significance of this formula is that it turns a $D \times D$ determinant that costs $O(D^3)$ into an $M \times M$ determinant that costs $O(M^3)$. If $M$ is smaller than $D$, this saves computation.

2. See https://en.wikipedia.org/wiki/Weinstein-Aronszajn_identity.

With some restrictions on the residual block $F : \mathbb{R}^D \to \mathbb{R}^D$, we can apply this formula to compute the determinant of a residual connection efficiently. The trick is to create a bottleneck inside $F$. We do that by defining $F = F_2 \circ F_1$, where $F_1 : \mathbb{R}^D \to \mathbb{R}^M$, $F_2 : \mathbb{R}^M \to \mathbb{R}^D$ and $M < D$. The chain rule gives $J(F) = J(F_2) J(F_1)$, where $J(F_2)$ is $D \times M$ and $J(F_1)$ is $M \times D$. Now we can apply our determinant formula as follows:

$$\det J(f) = \det(I_D + J(F)) = \det(I_D + J(F_2) J(F_1)) = \det(I_M + J(F_1) J(F_2)). \tag{23.38}$$

Since the final determinant costs $O(M^3)$, we can make the Jacobian determinant efficient by reducing $M$, that is, by narrowing the bottleneck.

An example of the above is the planar flow of [RM15]. In this model, each residual block is an 
MLP with one hidden layer and one hidden unit. That is, 


$$f(u) = u + v\,\sigma(w^\top u + b), \tag{23.39}$$


where v ∈ R^D, w ∈ R^D and b ∈ R are the parameters, and σ is the activation function. The residual
block is the composition of F_1(u) = w^⊤u + b and F_2(z) = vσ(z), so M = 1. Their Jacobians
are J(F_1)(u) = w^⊤ and J(F_2)(z) = vσ'(z). Substituting these in the formula for the Jacobian
determinant we obtain:


$$\det J(f)(u) = 1 + w^\top v\,\sigma'(w^\top u + b), \tag{23.40}$$


which can be computed efficiently in O(D). Other examples include the radial flow of [RM15]
and the Sylvester flow of [Ber+18].
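The NumPy sketch below evaluates one planar-flow layer together with its Jacobian determinant from Equation (23.40), and compares the result with the determinant of the full D × D Jacobian I + σ'(w^⊤u + b) v w^⊤. Choosing σ = tanh and the random parameters are assumptions made for the example.

```python
import numpy as np

def planar_flow(u, v, w, b):
    """One planar-flow layer f(u) = u + v * tanh(w^T u + b), Eq. (23.39) with sigma = tanh.
    Returns the output and det J(f)(u) from Eq. (23.40), computed in O(D)."""
    a = w @ u + b
    x = u + v * np.tanh(a)
    dsigma = 1.0 - np.tanh(a) ** 2          # tanh'(a)
    det_J = 1.0 + (w @ v) * dsigma
    return x, det_J

rng = np.random.default_rng(0)
D = 8
u, v, w = rng.standard_normal((3, D))
b = 0.1
x, det_J = planar_flow(u, v, w, b)

# O(D^3) reference: the full Jacobian is I + tanh'(w^T u + b) * v w^T.
dsigma = 1.0 - np.tanh(w @ u + b) ** 2
det_full = np.linalg.det(np.eye(D) + dsigma * np.outer(v, w))
```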

This technique gives an efficient way of computing determinants of residual connections with
bottlenecks, but in general there is no guarantee that such functions are invertible. This means that
invertibility must be established on a case-by-case basis. For example, the planar flow is invertible when
σ is the hyperbolic tangent and w^⊤v ≥ −1, but otherwise it may not be.
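One way to guarantee w^⊤v ≥ −1 during training is to optimize an unconstrained vector and map it to one that satisfies the constraint. The sketch below uses a reparameterization that, to our understanding, follows the appendix of [RM15]: v̂ = v + (m(w^⊤v) − w^⊤v) w/||w||² with m(a) = −1 + log(1 + e^a), so that w^⊤v̂ = m(w^⊤v) > −1. The helper names are hypothetical; treat this as an illustration, not a reference implementation.

```python
import numpy as np

def constrain_v(v, w):
    """Map an unconstrained v to v_hat with w^T v_hat = -1 + softplus(w^T v) > -1."""
    wv = w @ v
    m = -1.0 + np.logaddexp(0.0, wv)      # -1 + log(1 + exp(wv)), computed stably
    return v + (m - wv) * w / (w @ w)

def planar_det(u, v_hat, w, b):
    # det J(f)(u) = 1 + w^T v_hat * tanh'(w^T u + b), Eq. (23.40) with sigma = tanh
    return 1.0 + (w @ v_hat) * (1.0 - np.tanh(w @ u + b) ** 2)

rng = np.random.default_rng(0)
w = rng.standard_normal(4)
v = -10.0 * w / (w @ w)                   # deliberately violates w^T v >= -1 (w^T v = -10)
v_hat = constrain_v(v, w)
dets = [planar_det(u, v_hat, w, 0.2) for u in rng.standard_normal((1000, 4))]
```

Because tanh'(a) lies in (0, 1] and w^⊤v̂ > −1, every determinant stays strictly positive, so the layer is monotone along w and hence invertible.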


23.2.6 Continuous-time flows
So far we have discussed flows that consist of a sequence of bijections f_1, ..., f_N. Starting from some
input x_0 = u, this creates a sequence of outputs x_1, ..., x_N where x_n = f_n(x_{n−1}). However, we can
also have flows where the input is transformed into the final output in a continuous way. That is,
we start from x_0 = x(0), create a continuously-indexed sequence x(t) for t ∈ [0, T] with some fixed
T, and take x(T) to be the final output. Thinking of t as analogous to time, we refer to these as
continuous-time flows.

The sequence x(t) is defined as the solution to a first-order ordinary differential equation (ODE)
of the form:

$$\frac{dx}{dt}(t) = F(x(t), t) \tag{23.41}$$


The function F : R^D × [0, T] → R^D is a time-dependent vector field that parameterizes the ODE. If
we think of x(t) as the position of a particle in D dimensions, the vector F(x(t), t) determines the
particle's velocity at time t.


The flow (for time T) is a function f : R^D → R^D that takes in an input x_0, solves the ODE
with initial condition x(0) = x_0, and returns x(T). The function f is a well-defined bijection if the
solution to the ODE exists for all t ∈ [0, T] and is unique. These conditions are not generally satisfied
for arbitrary F, but they are if F(·, t) is Lipschitz continuous with a Lipschitz constant that does not
depend on t. That is, f is a well-defined bijection if there exists a constant L such that for all x_1, x_2
and t ∈ [0, T] we have:

$$\|F(x_1, t) - F(x_2, t)\| \le L\,\|x_1 - x_2\|. \tag{23.42}$$


This result is a consequence of the Picard-Lindelöf theorem for ODEs.³ In practice, we can
parameterize F using any choice of model, provided the Lipschitz condition is met.

Usually the ODE cannot be solved analytically, but we can solve it approximately by discretizing
it. A simple example is Euler's method, which corresponds to the following discretization for some
small step size ε > 0:

$$x(t + \epsilon) = x(t) + \epsilon F(x(t), t). \tag{23.43}$$


This is equivalent to a residual connection with residual block εF(·, t), so the ODE solver can be
thought of as a deep residual network with O(T/ε) layers. A smaller step size leads to a more
accurate solution, but also to more computation. There are several other solution methods varying
in accuracy and sophistication, such as those in the broader Runge-Kutta family, some of which use
adaptive step sizes.
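To make Equation (23.43) concrete, here is a minimal Euler solver in NumPy, assuming the vector field F(x, t) is supplied as a Python function. The usage example uses the test ODE dx/dt = −x, whose exact solution x(T) = x(0)e^{−T} lets us see the error shrink as ε decreases (roughly linearly in ε, since Euler's method is first order).

```python
import numpy as np

def euler_flow(F, x0, T, n_steps):
    """Approximate x(T) for dx/dt = F(x, t) using n_steps Euler steps of size eps = T / n_steps.
    Each step x <- x + eps * F(x, t) is a residual connection with residual block eps * F(., t)."""
    eps = T / n_steps
    x = np.asarray(x0, dtype=float)
    for i in range(n_steps):
        x = x + eps * F(x, i * eps)
    return x

F = lambda x, t: -x                      # test vector field with a known solution
exact = np.exp(-1.0)                     # x(1) when x(0) = 1
err_coarse = abs(euler_flow(F, 1.0, 1.0, 10) - exact)
err_fine = abs(euler_flow(F, 1.0, 1.0, 1000) - exact)
```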

The inverse of f can be easily computed by solving the ODE in reverse. That is, to compute
f^{−1}(x_T) we solve the ODE with initial condition x(T) = x_T and return x(0). Unlike some other
flows (such as autoregressive flows), which are more expensive to compute in one direction than in
the other, continuous-time flows require the same amount of computation in either direction.
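The sketch below illustrates inversion by reverse-time integration, assuming a hypothetical vector field F that is Lipschitz in x. It uses the classical fourth-order Runge-Kutta method with a fixed step; integrating from T back to 0 just uses a negative step, and costs exactly as many evaluations of F as the forward pass.

```python
import numpy as np

def rk4_integrate(F, x, t0, t1, n_steps):
    """Integrate dx/dt = F(x, t) from t0 to t1 with fixed-step classical Runge-Kutta (RK4).
    If t1 < t0 the step h is negative, i.e. the ODE is solved backwards in time."""
    h = (t1 - t0) / n_steps
    x = np.asarray(x, dtype=float)
    t = t0
    for _ in range(n_steps):
        k1 = F(x, t)
        k2 = F(x + 0.5 * h * k1, t + 0.5 * h)
        k3 = F(x + 0.5 * h * k2, t + 0.5 * h)
        k4 = F(x + h * k3, t + h)
        x = x + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
        t += h
    return x

def F(x, t):
    # Hypothetical vector field, Lipschitz in x with constant 1.5 for every t.
    return np.tanh(x) * np.cos(t) - 0.5 * x

x0 = np.array([0.3, -1.2, 2.0])
xT = rk4_integrate(F, x0, 0.0, 1.0, 100)        # f(x0): solve from t = 0 to t = T
x0_rec = rk4_integrate(F, xT, 1.0, 0.0, 100)    # f^{-1}(xT): solve from t = T back to t = 0
```

The round trip recovers x0 only up to the solver's discretization error, which for RK4 with this step size is far below typical tolerances.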

In general, there is no analytical expression for the Jacobian determinant of f. However, we can
express it as the solution to a separate ODE, which we can then solve numerically. First, we define
f_t : R^D → R^D to be the flow for time t, that is, the function that takes x_0, solves the ODE with
initial condition x(0) = x_0 and returns x(t). Clearly, f_0 is the identity function and f_T = f. Let us
define L(t) = log|det J(f_t)(x_0)|. Because f_0 is the identity function, L(0) = 0, and because f_T = f,
L(T) gives the log Jacobian determinant of f that we are interested in. It can be shown that L satisfies
the following ODE:


$$\frac{dL}{dt}(t) = \operatorname{tr}\left[J(F(\cdot, t))(x(t))\right]. \tag{23.44}$$


That is, the rate of change of L at time t is equal to the Jacobian trace of F(·, t) evaluated at x(t). So
we can compute L(T) by solving the above ODE with initial condition L(0) = 0. Moreover, we can
compute x(T) and L(T) simultaneously, by combining their two ODEs into a single ODE operating
on the extended space (x, L).
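As a check on Equation (23.44), the following sketch integrates the augmented state (x, L) with RK4 for a small hypothetical vector field F(x, t) = tanh(Wx), whose Jacobian trace is available in closed form, and compares L(T) with the log-determinant of a finite-difference Jacobian of the resulting flow map. The matrix W, the step count, and the helper names are illustrative assumptions.

```python
import numpy as np

W = np.array([[0.5, -1.0], [0.8, 0.3]])   # hypothetical parameters of the vector field

def F(x, t):
    return np.tanh(W @ x)

def trace_jac_F(x, t):
    # J_F(x) = diag(1 - tanh(Wx)^2) W, so tr J_F(x) = sum_i (1 - tanh((Wx)_i)^2) W_ii
    return np.sum((1.0 - np.tanh(W @ x) ** 2) * np.diag(W))

def augmented_rhs(state, t):
    # state = (x, L): dx/dt = F(x, t) (Eq. 23.41) and dL/dt = tr[J(F(., t))(x)] (Eq. 23.44)
    x = state[:-1]
    return np.append(F(x, t), trace_jac_F(x, t))

def rk4(rhs, s, t0, t1, n_steps):
    """Classical 4th-order Runge-Kutta with a fixed step."""
    h = (t1 - t0) / n_steps
    t = t0
    for _ in range(n_steps):
        k1 = rhs(s, t)
        k2 = rhs(s + 0.5 * h * k1, t + 0.5 * h)
        k3 = rhs(s + 0.5 * h * k2, t + 0.5 * h)
        k4 = rhs(s + h * k3, t + h)
        s = s + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
        t += h
    return s

def flow_with_logdet(x0, T=1.0, n_steps=200):
    out = rk4(augmented_rhs, np.append(x0, 0.0), 0.0, T, n_steps)   # L(0) = 0
    return out[:-1], out[-1]

# Check L(T) against a central finite-difference Jacobian of the flow map x0 -> x(T).
x0 = np.array([0.4, -0.7])
xT, L = flow_with_logdet(x0)
delta = 1e-5
J_fd = np.column_stack([
    (flow_with_logdet(x0 + delta * e)[0] - flow_with_logdet(x0 - delta * e)[0]) / (2 * delta)
    for e in np.eye(2)
])
logdet_fd = np.log(abs(np.linalg.det(J_fd)))
```

In higher dimensions the exact trace becomes the bottleneck, which is where a Hutchinson estimate of tr[J(F)] can be substituted.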

An example of a continuous-time flow is the Neural ODE model of [Che+18c], which uses a 
neural network to parameterize F. To avoid backpropagating gradients through the ODE solver, 
which can be computationally demanding, they use the adjoint sensitivity method to express the 
time evolution of the gradient with respect to x(t) as a separate ODE. Solving this ODE gives the
required gradients, and can be thought of as the continuous-time analogue of backpropagation. 

Another example is the FFJORD model of [Gra+18]. This is similar to the Neural ODE model, 
except that it uses the Hutchinson trace estimator to approximate the Jacobian trace of F(·, t).
This usage of the Hutchinson trace estimator is analogous to that in contractive residual flows 
(Section 23.2.5.1), and it speeds up computation in exchange for a stochastic (but unbiased) estimate. 


3. See https://en.wikipedia.org/wiki/Picard-Lindel%C3%B6f_theorem
23.3 Applications


In this section, we highlight some applications of flows for canonical probabilistic machine learning 
tasks. 


23.3.1 Density estimation 


Flow models allow exact density computation and can be used to fit multi-modal densities to observed
data (see Figure 23.3 for an example). An early example is Gaussianization [CG00], which applied this
idea to fit low-dimensional densities. Tabak and Vanden-Eijnden [TVE10] and Tabak and Turner
[TT13] introduced the modern idea of flows (including the term ‘normalizing flows’), describing a
flow as a composition of simpler maps. Deep density models [RA13] were among the first to use neural
networks to parameterize flows for high-dimensional densities. There has been a rich line of follow-up
work including NICE [DKB15] and Real NVP [DSDB17]. (NVP stands for “non-volume preserving”,
which refers to the fact that the Jacobian determinant of the transform is not unity.) Masked autoregressive flows
(Section 23.2.4.2) further improved performance on unconditional and conditional density estimation
tasks.

Flows can be used for hybrid models which model the joint density of inputs and targets p(x, y), as
opposed to discriminative classification models which just model the conditional p(y|x) and density
models which just model the marginal p(x). Nalisnick et al. [Nal+19b] proposed a flow-based hybrid
model using invertible mappings for representation learning and showed that the joint density p(x, y)
can be computed efficiently, which can be useful for downstream tasks such as anomaly detection,
semi-supervised learning and selective classification. Flow-based hybrid models are memory-efficient
since most of the parameters are in the invertible representation, which is shared between the
discriminative and generative models; furthermore, the density p(x, y) can be computed in a single
forward pass, leading to computational savings. Residual flows [Che+19] use invertible residual
mappings [Beh+19] for hybrid modeling, which further improves performance. Flows have also been
used to fit densities to embeddings [Zha+20b; CZG20] for anomaly detection tasks.


23.3.2 Generative Modeling


Another task is generation, which involves generating novel samples from a fitted model p*(x).
Generation is a popular downstream task for normalizing flows, which have been applied to different
data modalities including images, video, audio, text and structured objects such as graphs and point
clouds. Images are arguably the most popular modality for deep generative models: GLOW [KD18b]
was one of the first flow-based models to generate compelling high-dimensional images, and has been
extended to video to produce RGB frames [Kum+19b]; residual flows [Che+19] have also been shown
to produce sharp images.

Oord et al. [Oor+18] used flows for audio synthesis by distilling WaveNet into an IAF (Section
23.2.4.3), which enables faster sampling than WaveNet. Other flow models for audio include
WaveFLOW [PVC19] and FlowWaveNet [Kim+19], which directly speed up WaveNet using coupling
layers.


Flows have also been used for text. Tran et al. [Tra+19] define a discrete flow over a vocabulary
for language-modeling tasks. Another popular approach is to define a latent variable model with a
discrete observation space but a continuous latent space. For example, Ziegler and Rush [ZR19a] use
normalizing flows in latent space for language modeling.


23.3.3 Inference 


Normalizing flows have been used for probabilistic inference. Rezende and Mohamed [RM15] 
popularized normalizing flows in machine learning, and showed how they can be used for modeling 
variational posterior distributions in latent variable models. Various extensions such as Householder 
flows [TW16], inverse autoregressive flows [Kin+16], multiplicative normalizing flows [LW17] and 
Sylvester flows [Ber+18] have been proposed for modeling the variational posterior for latent variable
models as well as posteriors for Bayesian neural networks. 

Flows have been used as complex proposal distributions for importance sampling; examples include 
neural importance sampling [Mül+19b] and Boltzmann generators [Noé+19]. Hoffman et al. [Hof+19]
used flows to improve the performance of Hamiltonian Monte Carlo (Section 12.5) by defining bijective 
transformations to transform random variables to simpler distributions and performing HMC in that 
space instead. 

Finally, flows can be used in the context of simulation-based inference, where the likelihood function 
of the parameters is not available, but simulating data from the model is possible. The main idea is 
to train a flow on data simulated from the model in order to approximate the posterior distribution 
or the likelihood function. The flow model can also be used to guide simulations in order to make 
inference more efficient [PSM19; GNM19]. This approach has been used for inference of simulation 
models in cosmology [Als+19] and computational neuroscience [Gon+20]. 
24 Energy-based models


This chapter was co-authored with Yang Song and Durk Kingma. 


24.1 Introduction 


We have now seen several ways of defining deep generative models, including VAEs (Chapter 21), 
auto-regressive models (Chapter 22) and normalizing flows (Chapter 23). All of the above models 
can be formulated in terms of directed graphical models (Chapter 4), where we generate the data 
one step at a time, using locally normalized distributions. In some cases, it is easier to specify a 
distribution in terms of a set of constraints that valid samples must satisfy, rather than a generative 
process. This can be done using an undirected graphical model (Chapter 4). 

Energy-based models (EBMs) can be written as a Gibbs distribution as follows:

$$p_\theta(x) = \frac{\exp(-E_\theta(x))}{Z_\theta} \tag{24.1}$$

where E_θ(x) ≥ 0 is known as the energy function with parameters θ, and Z_θ is the partition
function:


$$Z_\theta = \int \exp(-E_\theta(x))\,dx \tag{24.2}$$


This is constant w.r.t. x but is a function of θ. Since EBMs do not usually make any Markov
assumptions (unlike graphical models), evaluating this integral is usually intractable. Consequently
we usually need to use approximate methods, such as annealed importance sampling, discussed in
Section 11.5.4.1.
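For intuition about Equation (24.2): in one dimension, Z_θ can simply be computed by quadrature. The sketch assumes a toy energy E(x) = x²/2, for which Z = √(2π). The same grid approach needs n^D energy evaluations in D dimensions, which is why it is hopeless for the high-dimensional models considered in this chapter.

```python
import numpy as np

def partition_function_1d(energy, lo=-10.0, hi=10.0, n=20001):
    """Approximate Z = integral of exp(-E(x)) dx over [lo, hi] with the trapezoidal rule."""
    xs = np.linspace(lo, hi, n)
    vals = np.exp(-energy(xs))
    dx = xs[1] - xs[0]
    return dx * (np.sum(vals) - 0.5 * (vals[0] + vals[-1]))

Z = partition_function_1d(lambda x: 0.5 * x**2)   # exact value: sqrt(2 * pi)
```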

The advantage of an EBM over other generative models is that the energy function can be any
kind of function that returns a non-negative scalar; it does not need to integrate to 1. This allows
one to use a variety of neural network architectures for defining the energy. As such, EBMs have
found wide applications in many fields of machine learning, including image generation [Ngi+11;
Xie+16; DM19b], discriminative learning [Gra+20b], natural language processing [Mik+13; Den+20], density
estimation [Wen+19a; Son+19] and reinforcement learning [Haa+17; Haa+18a], to list a few. (More
examples can be found at https://github.com/yataobian/awesome-ebm.)
[Figure 24.1 panels: energy A | energy B | energy A+B]


Figure 24.1: Combining two energy functions in 2d by summation, which is equivalent to multiplying the 
corresponding probability densities. We also illustrate some sampled trajectories towards high probability (low 
energy) regions. From Figure 14 of [DM19a]. Used with kind permission of Yilun Du. 


24.1.1 Example: Products of experts (PoE) 


As an example of why energy based models are useful, suppose we want to create a generative model
of proteins that are thermally stable at room temperature, and which bind to the COVID-19 spike
receptor. Suppose p_1(x) can generate stable proteins and p_2(x) can generate proteins that bind.
(For example, both of these models could be autoregressive sequence models, trained on different
datasets.) We can view each of these models as “experts” about a particular aspect of the data.
On its own, neither is an adequate model of the data that we have (or want to have), but we
can combine them to represent the conjunction of features, by computing a product of
experts (PoE) [Hin02]:


$$p_{12}(x) = \frac{1}{Z_{12}}\, p_1(x)\, p_2(x) \tag{24.3}$$


This will assign high probability to proteins that are stable and which bind, and low probability to
all others. By contrast, a mixture of experts would either generate from p_1 or from p_2, but would
not combine features from both.


If the experts are represented as energy based models (EBMs), then the PoE model is also an EBM,
with an energy given by

$$E_{12}(x) = E_1(x) + E_2(x) \tag{24.4}$$


Intuitively, we can think of each component of energy as a “soft constraint” on the data. This idea 
is illustrated in Figure 24.1. 
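A one-dimensional sanity check of Equation (24.4), assuming two hypothetical Gaussian experts written as energies E_i(x) = (x − μ_i)²/(2σ_i²). Summing the energies and normalizing numerically should give a Gaussian whose precision is the sum of the experts' precisions (1 + 4 = 5, so variance 0.2) and whose mean is the precision-weighted mean ((−1)·1 + 2·4)/5 = 1.4.

```python
import numpy as np

def gaussian_energy(mu, sigma):
    return lambda x: 0.5 * ((x - mu) / sigma) ** 2

E1 = gaussian_energy(-1.0, 1.0)   # hypothetical expert 1
E2 = gaussian_energy(2.0, 0.5)    # hypothetical expert 2
E12 = lambda x: E1(x) + E2(x)     # Eq. (24.4): the PoE energy is the sum of energies

xs = np.linspace(-10.0, 10.0, 200001)
dx = xs[1] - xs[0]
p = np.exp(-E12(xs))
p /= p.sum() * dx                 # normalize numerically, i.e. compute Z_12 on the grid
mean = np.sum(xs * p) * dx
var = np.sum((xs - mean) ** 2 * p) * dx
```

Note that the product is sharper than either expert: it concentrates where both energies are low, which is the "soft constraint" view described above.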


24.1.2 Computational difficulties


Although the flexibility of EBMs can provide significant modeling advantages, computation of the
likelihood and drawing samples from the model are generally intractable. In this chapter, we will
discuss a variety of approximate methods to solve these problems.
24.2 Maximum Likelihood Training


The de facto standard for learning probabilistic models from i.i.d. data is maximum likelihood
estimation (MLE). Let p_θ(x) be a probabilistic model parameterized by θ, and p_data(x) be the
underlying data distribution of a dataset. We can fit p_θ(x) to p_data(x) by maximizing the expected
log-likelihood function over the data distribution, defined by


$$\mathcal{L}(\theta) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log p_\theta(x)\right] \tag{24.5}$$


as a function of θ. Here the expectation can be easily estimated with samples from the dataset.
Maximizing likelihood is equivalent to minimizing the KL divergence between p_data(x) and p_θ(x),
because


$$\mathcal{L}(\theta) = -D_{\mathrm{KL}}\left(p_{\text{data}}(x) \,\|\, p_\theta(x)\right) + \text{const} \tag{24.6}$$


where the constant is equal to E_{x∼p_data(x)}[log p_data(x)], which does not depend on θ.

We cannot usually compute the likelihood of an EBM because the normalizing constant Z_θ
is often intractable. Nevertheless, we can still estimate the gradient of the log-likelihood with
MCMC approaches, allowing for likelihood maximization with stochastic gradient ascent [You99]. In
particular, the gradient of the log-probability of an EBM decomposes as a sum of two terms:


$$\nabla_\theta \log p_\theta(x) = -\nabla_\theta E_\theta(x) - \nabla_\theta \log Z_\theta. \tag{24.7}$$

The first gradient term, −∇_θ E_θ(x), is straightforward to evaluate with automatic differentiation. The
challenge is in approximating the second gradient term, ∇_θ log Z_θ, which is intractable to compute
exactly. This gradient term can be rewritten as the following expectation:


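$$\nabla_\theta \log Z_\theta = \nabla_\theta \log \int \exp(-E_\theta(x))\,dx \tag{24.8}$$
$$\overset{(i)}{=} \left(\int \exp(-E_\theta(x))\,dx\right)^{-1} \nabla_\theta \int \exp(-E_\theta(x))\,dx \tag{24.9}$$
$$= \left(\int \exp(-E_\theta(x))\,dx\right)^{-1} \int \nabla_\theta \exp(-E_\theta(x))\,dx \tag{24.10}$$
$$\overset{(ii)}{=} \left(\int \exp(-E_\theta(x))\,dx\right)^{-1} \int \exp(-E_\theta(x))\left(-\nabla_\theta E_\theta(x)\right)dx \tag{24.11}$$
$$= \int \left(\int \exp(-E_\theta(x'))\,dx'\right)^{-1} \exp(-E_\theta(x))\left(-\nabla_\theta E_\theta(x)\right)dx \tag{24.12}$$
$$\overset{(iii)}{=} \int \frac{\exp(-E_\theta(x))}{Z_\theta}\left(-\nabla_\theta E_\theta(x)\right)dx \tag{24.13}$$
$$\overset{(iv)}{=} \int p_\theta(x)\left(-\nabla_\theta E_\theta(x)\right)dx \tag{24.14}$$
$$= \mathbb{E}_{x \sim p_\theta(x)}\left[-\nabla_\theta E_\theta(x)\right], \tag{24.15}$$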
where steps (i) and (ii) are due to the chain rule of gradients, and (iii) and (iv) are from the definitions
in Equations (24.1) and (24.2). Thus, we can obtain an unbiased Monte Carlo estimate of the
log-likelihood gradient by using


$$\nabla_\theta \log Z_\theta \approx -\frac{1}{S}\sum_{s=1}^{S} \nabla_\theta E_\theta(\tilde{x}_s), \tag{24.16}$$


where x̃_s ∼ p_θ(x), i.e., a random sample from the distribution over x given by the EBM. Therefore,
as long as we can draw random samples from the model, we have access to an unbiased Monte
Carlo estimate of the log-likelihood gradient, allowing us to optimize the parameters with stochastic
gradient ascent.
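Equation (24.16) can be checked on a toy EBM where everything is known in closed form. The sketch assumes E_θ(x) = θx²/2 with θ > 0, so that p_θ = N(0, 1/θ), log Z_θ = ½ log(2π/θ) and ∇_θ log Z_θ = −1/(2θ). Exact model samples are only available here because the toy model is Gaussian; in general they would come from MCMC, as discussed next.

```python
import numpy as np

# Toy EBM (an assumption for illustration): E_theta(x) = theta * x^2 / 2.
def grad_energy_theta(x, theta):
    return 0.5 * x**2                    # d/dtheta of theta * x^2 / 2

def mc_grad_log_Z(theta, S, rng):
    """Monte Carlo estimate of grad_theta log Z_theta, Eq. (24.16), from S model samples."""
    x_model = rng.normal(0.0, 1.0 / np.sqrt(theta), size=S)
    return -np.mean(grad_energy_theta(x_model, theta))

rng = np.random.default_rng(0)
theta = 2.0
estimate = mc_grad_log_Z(theta, S=200_000, rng=rng)
exact = -1.0 / (2.0 * theta)             # = -0.25

# Log-likelihood gradient of Eq. (24.7), averaged over hypothetical data from N(0, 1).
x_data = rng.normal(0.0, 1.0, size=1000)
grad_loglik = -np.mean(grad_energy_theta(x_data, theta)) - estimate
```

Here the data have variance 1 while the model has variance 1/θ = 0.5, so the gradient is negative and gradient ascent would decrease θ toward 1.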

Much of the literature has focused on methods for efficient MCMC sampling from EBMs. We 
discuss some of these methods below. 


24.2.1 Gradient-based MCMC methods 


Some efficient MCMC methods, such as Langevin MCMC (Section 12.5.6) or Hamiltonian Monte
Carlo (Section 12.5), make use of the fact that the gradient of the log-probability w.r.t. x (known
as the score function) is equal to the (negative) gradient of the energy, and is therefore easy to
calculate:

$$\nabla_x \log p_\theta(x) = -\nabla_x E_\theta(x) - \underbrace{\nabla_x \log Z_\theta}_{=0} = -\nabla_x E_\theta(x). \tag{24.17}$$


For example, when using Langevin MCMC to sample from p_θ(x), we first draw an initial sample x^0
from a simple prior distribution, and then simulate an overdamped Langevin diffusion process for K
steps with step size ε > 0:

$$x^{k+1} \leftarrow x^k + \frac{\epsilon^2}{2}\,\underbrace{\nabla_x \log p_\theta(x^k)}_{=-\nabla_x E_\theta(x^k)} + \epsilon\, z^k, \quad k = 0, 1, \dots, K-1 \tag{24.18}$$

where z^k ∼ N(0, I) is a Gaussian noise term. We show an example of this process in Figure 24.3d.
When ε → 0 and K → ∞, x^K is guaranteed to distribute as p_θ(x) under some regularity conditions.
In practice we have to use a small finite ε, but the discretization error is typically negligible, or can
be corrected with a Metropolis-Hastings step (Section 12.2), leading to the Metropolis-Adjusted
Langevin Algorithm (Section 12.5.6).
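A minimal NumPy sketch of the unadjusted Langevin update in Equation (24.18), run on many independent chains at once. The toy energy E(x) = x²/2 (so p = N(0, 1)), the uniform initialization, ε = 0.1 and K = 2000 are illustrative assumptions. For this Gaussian target the finite step size inflates the stationary variance only slightly, to 1/(1 − ε²/4) ≈ 1.0025.

```python
import numpy as np

def langevin_sample(grad_energy, x0, eps, K, rng):
    """Run K steps of unadjusted Langevin dynamics, Eq. (24.18):
    x <- x - (eps^2 / 2) * grad_x E(x) + eps * z, with z ~ N(0, I)."""
    x = np.array(x0, dtype=float)
    for _ in range(K):
        z = rng.standard_normal(x.shape)
        x = x - 0.5 * eps**2 * grad_energy(x) + eps * z
    return x

rng = np.random.default_rng(0)
# Toy EBM E(x) = x^2 / 2, so p(x) = N(0, 1) and grad_x E(x) = x.
x0 = rng.uniform(-5.0, 5.0, size=5000)        # 5000 independent chains from a simple "prior"
samples = langevin_sample(lambda x: x, x0, eps=0.1, K=2000, rng=rng)
```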


24.2.2  Contrastive divergence 


Running MCMC until convergence to obtain a sample x ∼ p_θ(x) can be computationally expensive.
Therefore we typically need approximations to make MCMC-based learning of EBMs practical. One
popular method for doing so is contrastive divergence (CD) [Hin02]. In CD, one initializes the
MCMC chain from the datapoint x, and proceeds to perform MCMC for a fixed number of steps.


One can show that T steps of CD minimizes the following objective:

$$\mathrm{CD}_T = D_{\mathrm{KL}}(p_0 \,\|\, p_\infty) - D_{\mathrm{KL}}(p_T \,\|\, p_\infty) \tag{24.19}$$

where p_T is the distribution over x after T MCMC updates, p_0 is the data distribution, and p_∞ is
the equilibrium distribution of the chain, i.e., the model distribution p_θ(x). Typically
we can get good results with a small value of T, sometimes just T = 1. We give the details below.


24.2.2.1 Fitting RBMs with CD 


CD was initially developed to fit a special kind of latent variable EBM known as a restricted 
Boltzmann machine (Section 4.3.3.2). This model was specifically designed to support fast block 
Gibbs sampling, which is required by CD (and can also be exploited by standard MCMC-based 
learning methods [AHS85].) 

For simplicity, we will assume the hidden and visible nodes are binary, and we use 1-step contrastive 
divergence. As discussed in Supplementary Section 4.3.1, the binary RBM has the following energy 
function: 


$$E(x, z; \theta) = -\sum_{d=1}^{D}\sum_{k=1}^{K} x_d z_k W_{dk} - \sum_{d=1}^{D} x_d b_d - \sum_{k=1}^{K} z_k c_k \tag{24.20}$$


(Henceforth we will drop the unary (bias) terms, which can be emulated by clamping z_k = 1 or x_d = 1.)
This is a log-linear model where we have one binary feature per edge. Thus from Equation (4.135) the
gradient of the log-likelihood is given by the clamped expectations minus the unclamped expectations:


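$$\frac{\partial \ell}{\partial w_{dk}} = \frac{1}{N}\sum_{n=1}^{N} \mathbb{E}\left[x_d z_k \mid x_n, \theta\right] - \mathbb{E}\left[x_d z_k \mid \theta\right] \tag{24.21}$$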


We can rewrite the above gradient in matrix-vector form as follows: 


$$\nabla_W \ell = \mathbb{E}_{p_{\mathcal{D}}(x)\,p(z|x,\theta)}\left[x z^\top\right] - \mathbb{E}_{p(z,x|\theta)}\left[x z^\top\right] \tag{24.22}$$


(We can derive a similar expression for the gradient of the bias terms by setting x_d = 1 or z_k = 1.)

The first term in the expression for the gradient in Equation (24.21), when x is fixed to a data
case, is sometimes called the clamped phase, and the second term, when x is free, is sometimes
called the unclamped phase. When the model expectations match the empirical expectations, the 
two terms cancel out, the gradient becomes zero and learning stops. 

We can also make a connection to the principle of Hebbian learning in neuroscience. In particular, 
Hebb’s rule says that the strength of connection between two neurons that are simultaneously active 
should be increased. (This theory is often summarized as “Cells that fire together wire together”.¹)
The first term in Equation (24.21) is therefore considered a Hebbian term, and the second term an 
anti-Hebbian term, due to the sign change. 

We can leverage the Markov structure of the bipartite graph to approximate the expectations as 
follows: 


$$z_n \sim p(z \mid x_n, \theta) \tag{24.23}$$
$$x'_n \sim p(x \mid z_n, \theta) \tag{24.24}$$
$$z'_n \sim p(z \mid x'_n, \theta) \tag{24.25}$$


1. See https://en.wikipedia.org/wiki/Hebbian_theory. 
Figure 24.2: Illustration of contrastive divergence sampling for an RBM. The visible nodes are initialized
at an example drawn from the data set. Then we sample a hidden vector, then another visible vector, etc. 
Eventually (at “infinity”) we will be producing samples from the joint distribution p(x, z|0). 


We can think of x'_n as the model's best attempt at reconstructing x_n after being encoded and then
decoded by the model. Such samples are sometimes called fantasy data. See Figure 24.2 for an
illustration. Given these samples, we then make the approximation
$$\mathbb{E}_{p(x,z|\theta)}\left[x z^\top\right] \approx x'_n (z'_n)^\top \tag{24.26}$$

In practice, it is common to use E[z|x'_n] instead of a sampled value z'_n in the above expression, since
this reduces the variance. However, it is not valid to use E[z|x_n] instead of sampling z_n ∼ p(z|x_n) in
Equation (24.23), because then each hidden unit would be able to pass more than 1 bit of information,
so it would not act as much of a bottleneck.
The whole procedure is summarized in Algorithm 38. For more details, see [Hin10; Swe+10].

Algorithm 38: CD-1 training for an RBM with binary hidden and visible units
1  Initialize weights W ∈ R^{D×K} randomly
2  for t = 1, 2, ... do
3      for each minibatch of size B do
4          Set minibatch gradient to zero, g := 0
5          for each case x_n in the minibatch do
6              Compute μ_n = E[z | x_n, W]
7              Sample z_n ∼ p(z | x_n, W)
8              Sample x'_n ∼ p(x | z_n, W)
9              Compute μ'_n = E[z | x'_n, W]
10             Compute gradient ∇_W = (x_n)(μ_n)^⊤ − (x'_n)(μ'_n)^⊤
11             Accumulate g := g + ∇_W
12         Update parameters W := W + (η_t/B) g
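Below is a vectorized NumPy sketch of Algorithm 38. It assumes no bias terms (as in the text above) and the energy E(x, z) = −x^⊤Wz, so that p(z_k = 1 | x) = sigmoid((x^⊤W)_k) and p(x_d = 1 | z) = sigmoid((Wz)_d). The function name, hyperparameters, and toy dataset are illustrative. Instead of looping over cases, each minibatch is processed as a matrix, which computes the sum of the per-case gradients of lines 10 and 11 in one step.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def cd1_train(X, K, n_epochs=50, batch_size=10, lr=0.1, seed=0):
    """CD-1 for a binary RBM without bias terms (a sketch of Algorithm 38).
    Assumes E(x, z) = -x^T W z, so p(z_k = 1 | x) = sigmoid((x^T W)_k)
    and p(x_d = 1 | z) = sigmoid((W z)_d)."""
    rng = np.random.default_rng(seed)
    N, D = X.shape
    W = 0.01 * rng.standard_normal((D, K))                     # line 1
    for _ in range(n_epochs):
        perm = rng.permutation(N)
        for start in range(0, N, batch_size):
            xb = X[perm[start:start + batch_size]]              # B x D minibatch
            mu = sigmoid(xb @ W)                                # line 6: E[z | x_n, W]
            zb = (rng.random(mu.shape) < mu).astype(float)      # line 7: z_n ~ p(z | x_n, W)
            px = sigmoid(zb @ W.T)
            xb_neg = (rng.random(px.shape) < px).astype(float)  # line 8: x'_n ~ p(x | z_n, W)
            mu_neg = sigmoid(xb_neg @ W)                        # line 9: E[z | x'_n, W]
            g = xb.T @ mu - xb_neg.T @ mu_neg                   # lines 10-11, summed over batch
            W += lr * g / xb.shape[0]                           # line 12
    return W

# Usage: two binary prototypes, 50 copies each (a hypothetical toy dataset).
X = np.array([[1, 1, 1, 0, 0, 0]] * 50 + [[0, 0, 0, 1, 1, 1]] * 50, dtype=float)
W = cd1_train(X, K=2)
recon = sigmoid(sigmoid(X @ W) @ W.T)    # mean-field reconstruction
recon_error = np.mean((X - recon) ** 2)  # equals 0.25 when W = 0
```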


24.2.2.2 Persistent CD 


One variant of CD that sometimes performs better is persistent contrastive divergence (PCD)
[Tie08; TH09; You99]. In this approach, a single MCMC chain with a persistent state is employed
to sample from the EBM. In PCD, we do not restart the MCMC chain when training on a new
datapoint; rather, we carry over the state of the previous MCMC chain and use it to initialize a new
MCMC chain for the next training step. See Algorithm 39 for some pseudocode. Hence there are two
dynamical processes running at different time scales: the states x change quickly, and the parameters
θ change slowly.


Algorithm 39: Persistent MCMC-SGD for fitting an EBM
1  Initialize parameters θ randomly
2  Initialize chains x̃_{1:S} randomly
3  Initialize learning rate η
4  for t = 1, 2, ... do
5      for x_b in minibatch of size B do
6          g_b = ∇_θ E_θ(x_b)
7      for sample s = 1 : S do
8          Sample x̃_s ∼ MCMC(target = p(·|θ), init = x̃_s, nsteps = N)
9          g̃_s = ∇_θ E_θ(x̃_s)
10     g = −((1/B) Σ_{b=1}^{B} g_b − (1/S) Σ_{s=1}^{S} g̃_s)
11     θ := θ + η g
12     Decrease step size η


A theoretical justification for this was given in [You89], who showed that we can start the MCMC
chain at its previous value, and just take a few steps, because p(x|θ_t) is likely to be close to p(x|θ_{t−1}),
since we only changed the parameters by a small amount in the intervening SGD step.
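A minimal sketch of persistent MCMC-SGD in the spirit of Algorithm 39, for a one-dimensional toy EBM E_θ(x) = θx²/2, using unadjusted Langevin dynamics (Equation (24.18)) as the MCMC kernel. For illustration we assume θ = exp(φ) to keep θ positive, and a constant learning rate instead of a decreasing one; all names and hyperparameters are hypothetical. For this model the maximum likelihood solution is θ = 1/E[x²], which gives something to compare against (the small finite Langevin step biases the answer slightly upward).

```python
import numpy as np

def pcd_fit_gaussian_ebm(X, n_chains=1000, n_iters=3000, batch_size=100,
                         mcmc_steps=20, eps=0.5, lr=0.01, seed=0):
    """Persistent MCMC-SGD for E_theta(x) = theta * x^2 / 2 with theta = exp(phi).
    The Langevin chains are never restarted: each iteration continues from the last state."""
    rng = np.random.default_rng(seed)
    phi = 0.0                                    # theta starts at 1
    chains = rng.standard_normal(n_chains)       # persistent chain states x~_{1:S}
    for _ in range(n_iters):
        theta = np.exp(phi)
        xb = rng.choice(X, size=batch_size)      # data minibatch
        for _ in range(mcmc_steps):              # continue the chains (grad_x E = theta * x)
            z = rng.standard_normal(n_chains)
            chains = chains - 0.5 * eps**2 * theta * chains + eps * z
        # dE/dphi = theta * x^2 / 2; log-likelihood gradient = -data term + model term
        g = -np.mean(theta * xb**2 / 2) + np.mean(theta * chains**2 / 2)
        phi += lr * g                            # gradient ascent on the log-likelihood
    return np.exp(phi)

rng = np.random.default_rng(1)
X = 2.0 * rng.standard_normal(2000)              # hypothetical data with variance 4
theta_hat = pcd_fit_gaussian_ebm(X)              # should land near 1 / mean(X^2), about 0.25
```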


24.2.2.3 Other methods 


PCD can be further improved by keeping multiple historical states of the MCMC chain in a replay
buffer and initializing new MCMC chains by randomly sampling from it [DM19b]. Other variants of
CD include mean field CD [WH02], and multi-grid CD [Gao+18].

EBMs trained with CD may not capture the data distribution faithfully, since truncated MCMC 
can lead to biased gradient updates that hurt the learning dynamics [SMB10; FI10; Nij+19]. There 
are several methods that focus on removing this bias for improved MCMC training. For example, one 
line of work proposes unbiased estimators of the gradient through coupled MCMC [JOA17; QZW19]; 
and Du et al. [Du+20] propose to reduce the bias by differentiating through the MCMC sampling 
algorithm and estimating an entropy correction term. 
24.3 Score Matching (SM)


If two continuously differentiable real-valued functions f(x) and g(x) have equal first derivatives
everywhere, then f(x) = g(x) + constant. When f(x) and g(x) are log probability density functions
(PDFs) with equal first derivatives, the normalization requirement (Equation (24.1)) implies that
∫ exp(f(x))dx = ∫ exp(g(x))dx = 1, and therefore f(x) = g(x). As a result, one can learn an EBM
by (approximately) matching the first derivatives of its log-PDF to the first derivatives of the log-PDF
of the data distribution. If they match, then the EBM captures the data distribution exactly. The
first-order gradient function of a log-PDF is also called the score of that distribution. For training
EBMs, it is useful to transform the equivalence of distributions to the equivalence of scores, because
the score of an EBM can be easily obtained as follows:


$$s_\theta(x) = \nabla_x \log p_\theta(x) = -\nabla_x E_\theta(x) \tag{24.27}$$


We see that this does not involve the typically intractable normalizing constant Z_θ.

Let p_data(x) be the underlying data distribution, from which we have a finite number of i.i.d.
samples but do not know its PDF. The score matching objective [Hyv05] minimizes a discrepancy
between two distributions called the Fisher divergence:


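$$D_F\left(p_{\text{data}}(x) \,\|\, p_\theta(x)\right) = \mathbb{E}_{p_{\text{data}}(x)}\left[\frac{1}{2}\left\|\nabla_x \log p_{\text{data}}(x) - \nabla_x \log p_\theta(x)\right\|^2\right] \tag{24.28}$$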


The expectation w.r.t. p_data(x), in this objective and its variants below, admits a trivial unbiased
Monte Carlo estimator using the empirical mean of samples x ∼ p_data(x). However, the data score
∇_x log p_data(x) appearing in Equation (24.28) is generally impractical to calculate since it requires knowing
the PDF of p_data(x). We discuss a solution to this below.


24.3.1 Basic score matching 


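Hyvärinen [Hyv05] shows that, under certain regularity conditions, the Fisher divergence can be
rewritten using integration by parts, with second derivatives of E_θ(x) replacing the unknown first
derivatives of p_data(x):

$$D_F\left(p_{\text{data}}(x) \,\|\, p_\theta(x)\right) = \mathbb{E}_{p_{\text{data}}(x)}\left[\sum_{i=1}^{d}\left(\frac{1}{2}\left(\frac{\partial E_\theta(x)}{\partial x_i}\right)^2 - \frac{\partial^2 E_\theta(x)}{\partial x_i^2}\right)\right] + \text{constant} \tag{24.29}$$
$$= \mathbb{E}_{p_{\text{data}}(x)}\left[\frac{1}{2}\left\|s_\theta(x)\right\|^2 + \operatorname{tr}\left(J_x s_\theta(x)\right)\right] + \text{constant} \tag{24.30}$$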


where d is the dimensionality of x, and J_x s_θ(x) is the Jacobian of the score function. The constant
does not affect optimization and thus can be dropped for training. It is shown by [Hyv05] that
estimators based on Score Matching are consistent under some regularity conditions, meaning that
the parameter estimator obtained by minimizing Equation (24.28) converges to the true parameters
in the limit of infinite data. See Figure 24.3 for an example.
An important downside of the objective Equation (24.30) is that it takes O(d²) time to compute
the trace of the Jacobian. For this reason, the implicit SM formulation of Equation (24.30) has only
been applied to relatively simple energy functions where computation of the second derivatives is
tractable.
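For a toy energy E_θ(x) = θ||x||²/2 in d dimensions (an assumption for illustration), the score is s_θ(x) = −θx and tr(J_x s_θ(x)) = −θd, so Equation (24.30) can be evaluated exactly. Setting its derivative to zero gives θ* = d/E||x||², which coincides with the maximum likelihood estimate for this Gaussian model. The NumPy sketch below confirms this by a grid search; for a neural energy the trace term is the expensive part, since it needs one extra derivative computation per input dimension.

```python
import numpy as np

def sm_objective(theta, X):
    """Eq. (24.30) without the constant, for the toy energy E_theta(x) = theta * ||x||^2 / 2."""
    d = X.shape[1]
    half_sq_score = 0.5 * theta**2 * np.sum(X**2, axis=1)   # 0.5 * ||s_theta(x)||^2
    trace_jac = -theta * d                                   # tr(J_x s_theta(x))
    return np.mean(half_sq_score + trace_jac)

rng = np.random.default_rng(0)
X = 2.0 * rng.standard_normal((5000, 3))       # hypothetical data with covariance 4 I
thetas = np.linspace(0.05, 1.0, 1901)          # grid with spacing 0.0005
best = thetas[np.argmin([sm_objective(t, X) for t in thetas])]
closed_form = X.shape[1] / np.mean(np.sum(X**2, axis=1))   # d / E||x||^2, about 0.25
```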

Score Matching assumes a continuous data distribution with positive density over the space, but it 
can be generalized to discrete or bounded data distributions [Hyv07b; Lyu12]. It is also possible to 
consider higher-order gradients of log-PDFs beyond first derivatives [PDL+12]. 


24.3.2 Denoising Score Matching (DSM) 


The Score Matching objective in Equation (24.30) requires several regularity conditions for log p_data(x),
e.g., it should be continuously differentiable and finite everywhere. However, these conditions may not
always hold in practice. For example, a distribution of digital images is typically discrete and bounded,
because the values of pixels are restricted to the range {0, 1, ..., 255}. Therefore, log p_data(x) in
this case is discontinuous and is negative infinity outside the range, and thus SM is not directly
applicable.

To alleviate this, one can add a bit of noise to each datapoint: x̃ = x + ε. As long as the noise
distribution p(ε) is smooth, the resulting noisy data distribution q(x̃) = ∫ q(x̃ | x) p_data(x) dx is also
smooth, and thus the Fisher divergence D_F(q(x̃) || p_θ(x̃)) is a proper objective. [KL10] showed that
the objective with noisy data can be approximated by the noiseless Score Matching objective of
Equation (24.30) plus a regularization term; this regularization makes Score Matching applicable to
a wider range of data distributions, but still requires expensive second-order derivatives.

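[Vin11] proposed an elegant and scalable solution to the above difficulty, by showing that:

$$D_F\left(q(\tilde{x}) \,\|\, p_\theta(\tilde{x})\right) = \mathbb{E}_{q(\tilde{x})}\left[\frac{1}{2}\left\|\nabla_{\tilde{x}} \log p_\theta(\tilde{x}) - \nabla_{\tilde{x}} \log q(\tilde{x})\right\|^2\right] \tag{24.31}$$
$$= \mathbb{E}_{q(x, \tilde{x})}\left[\frac{1}{2}\left\|\nabla_{\tilde{x}} \log p_\theta(\tilde{x}) - \nabla_{\tilde{x}} \log q(\tilde{x} \mid x)\right\|^2\right] + \text{constant} \tag{24.32}$$
$$= \frac{1}{2}\,\mathbb{E}_{q(x, \tilde{x})}\left[\left\|s_\theta(\tilde{x}) + \frac{\tilde{x} - x}{\sigma^2}\right\|^2\right] + \text{constant} \tag{24.33}$$

where s_θ(x̃) = ∇_x̃ log p_θ(x̃) is the estimated score function, and

$$\nabla_{\tilde{x}} \log q(\tilde{x} \mid x) = \nabla_{\tilde{x}} \log \mathcal{N}(\tilde{x} \mid x, \sigma^2 I) = -\frac{\tilde{x} - x}{\sigma^2} \tag{24.34}$$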


The expectation can be approximated by sampling x from p_data(x) and then sampling the noisy point x̃ ∼ q(x̃ | x).
(The constant term does not affect optimization and can be ignored without changing the optimal
solution.)
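Continuing the one-dimensional toy model E_θ(x) = θx²/2 (an assumption for illustration, with score s_θ(x̃) = −θx̃), the DSM objective of Equation (24.33) is a least-squares problem in θ and can be minimized in closed form. With p_data = N(0, 4) and noise scale σ = 1, the fitted θ comes out near 1/(4 + 1), i.e. the precision of the noisy distribution q(x̃) rather than that of p_data.

```python
import numpy as np

def dsm_fit_gaussian_ebm(X, sigma, seed=0):
    """Fit theta in the toy EBM E_theta(x) = theta * x^2 / 2 (score s_theta(x) = -theta * x)
    by minimizing the DSM objective, Eq. (24.33), with Gaussian noise of scale sigma."""
    rng = np.random.default_rng(seed)
    X_tilde = X + sigma * rng.standard_normal(X.shape)   # x_tilde ~ q(x_tilde | x)
    target = -(X_tilde - X) / sigma**2                    # grad log q(x_tilde | x), Eq. (24.34)
    # 0.5 * mean((-theta * x_tilde - target)^2) is quadratic in theta; minimize in closed form.
    return -np.mean(X_tilde * target) / np.mean(X_tilde**2)

rng = np.random.default_rng(1)
X = 2.0 * rng.standard_normal(200_000)     # hypothetical data with variance 4
theta_dsm = dsm_fit_gaussian_ebm(X, sigma=1.0)
# theta_dsm is close to 1 / (4 + 1), not 1 / 4.
```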

This estimation method is called Denoising Score Matching (DSM) by [Vin11]. Similar 
formulations were also explored by Raphan and Simoncelli [RS07; RS11] and can be traced back to 
Tweedie’s formula (Supplementary Section 3.3) and Stein’s Unbiased Risk Estimation [Ste81]. 


24.3.2.1 Difficulties 


The major drawback of adding noise to data arises when p_data(x) is already a well-behaved distribution
that satisfies the regularity conditions required by Score Matching. In this case, D_F(q(x̃) || p_θ(x̃)) ≠
D_F(p_data(x) || p_θ(x)), and DSM is not a consistent objective because the optimal EBM matches
the noisy distribution q(x̃), not p_data(x). This inconsistency becomes non-negligible when q(x̃)
significantly differs from p_data(x).

One way to attenuate the inconsistency of DSM is to choose $q \approx p_{\text{data}}$, i.e., use a small noise perturbation. However, this often significantly increases the variance of objective values and hinders optimization. As an example, suppose $q(\tilde{x} \mid x) = \mathcal{N}(\tilde{x} \mid x, \sigma^2 I)$ and $\sigma \approx 0$. The corresponding DSM objective is


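$$
\begin{align}
D_F(q(\tilde{x}) \,\|\, p_\theta(\tilde{x})) &= \mathbb{E}_{p_{\text{data}}(x)}\mathbb{E}_{z \sim \mathcal{N}(0, I)}\left[\frac{1}{2}\left\|\frac{z}{\sigma} + \nabla_x \log p_\theta(x + \sigma z)\right\|_2^2\right] \nonumber\\
&\simeq \frac{1}{2N}\sum_{i=1}^{N}\left\|\frac{z^{(i)}}{\sigma} + \nabla_x \log p_\theta(x^{(i)} + \sigma z^{(i)})\right\|_2^2 \tag{24.35}
\end{align}
$$

where $\{x^{(i)}\}_{i=1}^{N} \overset{\text{i.i.d.}}{\sim} p_{\text{data}}(x)$ and $\{z^{(i)}\}_{i=1}^{N} \overset{\text{i.i.d.}}{\sim} \mathcal{N}(0, I)$. When $\sigma \to 0$, we can leverage a Taylor series expansion to rewrite the Monte Carlo estimator in Equation (24.35) as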

$$
\frac{1}{2N}\sum_{i=1}^{N}\left[\frac{2}{\sigma}(z^{(i)})^\top \nabla_x \log p_\theta(x^{(i)}) + \frac{\|z^{(i)}\|_2^2}{\sigma^2}\right] + \text{constant} \tag{24.36}
$$
When estimating the above expectation with samples, the variances of $(z^{(i)})^\top \nabla_x \log p_\theta(x^{(i)})/\sigma$ and $\|z^{(i)}\|_2^2/\sigma^2$ will both grow unbounded as $\sigma \to 0$ due to division by $\sigma$ and $\sigma^2$. This enlarges the variance of DSM and makes optimization challenging. Various methods have been proposed to reduce this variance (see e.g., [Wan+20d]).
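The variance growth can be observed numerically. This sketch assumes toy data $\mathcal{N}(0, 1)$ and a model whose score is exactly the data score $-x$. It evaluates the per-sample terms of the Monte Carlo estimator in Equation (24.35) and reports their empirical standard deviation for decreasing $\sigma$. For small $\sigma$ this grows roughly like $1/\sigma^2$, even though the model itself is perfect.

```python
import numpy as np

def dsm_terms(score_fn, x, sigma, rng):
    """Per-sample terms 0.5 * ||z / sigma + s(x + sigma z)||^2 of Eq. (24.35)."""
    z = rng.standard_normal(x.shape)
    return 0.5 * np.sum((z / sigma + score_fn(x + sigma * z))**2, axis=-1)

def true_score(v):
    return -v                           # exact score of the toy p_data = N(0, 1)

rng = np.random.default_rng(0)
x = rng.standard_normal((100_000, 1))
stds = {s: dsm_terms(true_score, x, s, rng).std() for s in (1.0, 0.1, 0.01)}
for s, sd in stds.items():
    print(f"sigma={s:<5} std of per-sample DSM term = {sd:.3g}")
```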


24.3.3 Sliced Score Matching (SSM) 


By adding noise to data, DSM avoids the expensive computation of second-order derivatives. However, as mentioned before, the optimal EBM that minimizes the DSM objective corresponds to the distribution of noise-perturbed data $q(\tilde{x})$, not the original noise-free data distribution $p_{\text{data}}(x)$. In other words, DSM does not give a consistent estimator of the data distribution, i.e., one cannot directly obtain an EBM that exactly matches the data distribution even with unlimited data.


Sliced Score Matching (SSM) [Son+19] is one alternative to Denoising Score Matching that is both consistent and computationally efficient. Instead of minimizing the Fisher divergence between two vector-valued scores, SSM randomly samples a projection vector $v$, takes the inner product between $v$ and the two scores, and then compares the resulting two scalars. More specifically, Sliced Score Matching minimizes the following divergence called the sliced Fisher divergence:


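$$
D_{SF}(p_{\text{data}}(x) \,\|\, p_\theta(x)) = \mathbb{E}_{p_{\text{data}}(x)}\mathbb{E}_{p(v)}\left[\frac{1}{2}\left(v^\top \nabla_x \log p_{\text{data}}(x) - v^\top \nabla_x \log p_\theta(x)\right)^2\right] \tag{24.37}
$$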


where $p(v)$ denotes a projection distribution such that $\mathbb{E}_{p(v)}[vv^\top]$ is positive definite. Similar to the Fisher divergence, the sliced Fisher divergence has an implicit form that does not involve the unknown $\nabla_x \log p_{\text{data}}(x)$, which is given by


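$$
D_{SF}(p_{\text{data}}(x) \,\|\, p_\theta(x)) = \mathbb{E}_{p_{\text{data}}(x)}\mathbb{E}_{p(v)}\left[\frac{1}{2}\left(\sum_{i=1}^{d} \frac{\partial \mathcal{E}_\theta(x)}{\partial x_i} v_i\right)^2 - \sum_{i=1}^{d}\sum_{j=1}^{d} \frac{\partial^2 \mathcal{E}_\theta(x)}{\partial x_i \partial x_j} v_i v_j\right] + C \tag{24.38}
$$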


All expectations in the above objective can be estimated with empirical means, and again the constant term $C$ can be removed without affecting training. The second term involves second-order derivatives of $\mathcal{E}_\theta(x)$, but contrary to SM, it can be computed efficiently with a cost linear in the dimensionality $d$. This is because


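$$
\sum_{i=1}^{d}\sum_{j=1}^{d} \frac{\partial^2 \mathcal{E}_\theta(x)}{\partial x_i \partial x_j} v_i v_j = \sum_{i=1}^{d} \frac{\partial}{\partial x_i}\underbrace{\left(\sum_{j=1}^{d} \frac{\partial \mathcal{E}_\theta(x)}{\partial x_j} v_j\right)}_{:= f(x)} v_i \tag{24.39}
$$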


where $f(x)$ is the same for different values of $i$. Therefore, we only need to compute it once with $O(d)$ computation, plus another $O(d)$ computation for the outer sum to evaluate Equation (24.39), whereas the original SM objective requires $O(d^2)$ computation.

For many choices of $p(v)$, part of the SSM objective (Equation (24.38)) can be evaluated in closed form, potentially leading to lower variance. For example, when $p(v) = \mathcal{N}(0, I)$, we have


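$$
\mathbb{E}_{p_{\text{data}}(x)}\mathbb{E}_{p(v)}\left[\frac{1}{2}\left(\sum_{i=1}^{d} \frac{\partial \mathcal{E}_\theta(x)}{\partial x_i} v_i\right)^2\right] = \mathbb{E}_{p_{\text{data}}(x)}\left[\frac{1}{2}\sum_{i=1}^{d}\left(\frac{\partial \mathcal{E}_\theta(x)}{\partial x_i}\right)^2\right] \tag{24.40}
$$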


and as a result, 


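$$
\begin{align}
D_{SF}(p_{\text{data}}(x) \,\|\, p_\theta(x)) &= \mathbb{E}_{p_{\text{data}}(x)}\mathbb{E}_{v \sim \mathcal{N}(0, I)}\left[\frac{1}{2}\sum_{i=1}^{d}\left(\frac{\partial \mathcal{E}_\theta(x)}{\partial x_i}\right)^2 - \sum_{i=1}^{d}\sum_{j=1}^{d} \frac{\partial^2 \mathcal{E}_\theta(x)}{\partial x_i \partial x_j} v_i v_j\right] + C \tag{24.41}\\
&= \mathbb{E}_{p_{\text{data}}(x)}\mathbb{E}_{v \sim \mathcal{N}(0, I)}\left[\frac{1}{2}\|s_\theta(x)\|_2^2 + v^\top J v\right] + C \tag{24.42}
\end{align}
$$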

where $J = \nabla_x s_\theta(x)$ is the Jacobian of the score. (Note that $Jv$ can be computed using a Jacobian-vector product operation.)

The objective in Equation (24.41) can also be obtained by approximating the sum of second-order gradients in the standard SM objective (Equation (24.30)) with the Hutchinson trace estimator [Ski89; Hut89; Mey+21]. It often (but not always) has lower variance than Equation (24.38), and can perform better in some applications [Son+19].
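A minimal sketch of the objective in Equation (24.42) follows. It assumes $p(v) = \mathcal{N}(0, I)$ and a hypothetical diagonal Gaussian model with score $s_\theta(x) = -\lambda \odot x$, where $\lambda$ is a precision vector. In practice $Jv$ would come from an autodiff Jacobian-vector product (for example `jax.jvp`). Here it is replaced by a central finite difference, which also costs two extra score evaluations regardless of $d$. For this model the expected objective is $\sum_i\left(\tfrac12\lambda_i^2\,\mathrm{Var}[x_i] - \lambda_i\right)$, so it should be smallest at the true precision.

```python
import numpy as np

def ssm_vr_loss(score_fn, x, rng, eps=1e-4):
    """Monte Carlo estimate of Eq. (24.42) with v ~ N(0, I); the constant C is dropped."""
    v = rng.standard_normal(x.shape)
    s = score_fn(x)
    # Finite-difference stand-in for the Jacobian-vector product J v
    jv = (score_fn(x + eps * v) - score_fn(x - eps * v)) / (2 * eps)
    per_sample = 0.5 * np.sum(s**2, axis=-1) + np.sum(v * jv, axis=-1)
    return per_sample.mean()

rng = np.random.default_rng(0)
data_var = np.array([4.0, 0.25])
x = rng.standard_normal((100_000, 2)) * np.sqrt(data_var)   # toy p_data, diagonal Gaussian
true_prec = 1.0 / data_var

def make_score(prec):
    return lambda pts: -prec * pts         # hypothetical model score s(x) = -prec * x

# The same seed for every model gives common random numbers across the comparison.
losses = {scale: ssm_vr_loss(make_score(scale * true_prec), x, np.random.default_rng(1))
          for scale in (0.5, 1.0, 2.0)}
print(losses)                              # smallest at scale = 1.0
```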


24.3.4 Connection to Contrastive Divergence 


Though Score Matching and Contrastive Divergence (Section 24.2.2) are seemingly very different 
approaches, they are closely connected to each other. In fact, Score Matching can be viewed as a 
special instance of Contrastive Divergence in the limit of a particular MCMC sampler [Hyv07a]. 
Moreover, the Fisher divergence optimized by Score Matching is related to the derivative of KL 
divergence [Cov99], which is the underlying objective of Contrastive Divergence. 

Contrastive Divergence requires sampling from the Energy-Based Model $p_\theta(x)$, and one popular method for doing so is Langevin MCMC. Recall from Section 24.2.1 that given any initial data point $x^0$, the Langevin MCMC method executes the following


$$
x^{k+1} \leftarrow x^k - \frac{\epsilon}{2}\nabla_x \mathcal{E}_\theta(x^k) + \sqrt{\epsilon}\, z^k \tag{24.43}
$$
iteratively for $k = 0, 1, \cdots, K-1$, where $z^k \sim \mathcal{N}(0, I)$ and $\epsilon > 0$ is the step size.
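The update in Equation (24.43) needs only the gradient of the energy. Below is a minimal NumPy sketch that assumes a hypothetical quadratic energy $\mathcal{E}(x) = (x - 2)^2 / (2 \cdot 0.5^2)$, whose normalized density is $\mathcal{N}(2, 0.5^2)$. Many chains run in parallel. With a small step size their final states should have mean close to 2 and standard deviation close to 0.5, and the discretization adds only a small bias.

```python
import numpy as np

def langevin(grad_energy, x0, step, num_steps, rng):
    """Unadjusted Langevin MCMC, Eq. (24.43), run for all chains in x0 at once."""
    x = np.array(x0, dtype=float)
    for _ in range(num_steps):
        z = rng.standard_normal(x.shape)
        x = x - 0.5 * step * grad_energy(x) + np.sqrt(step) * z
    return x

MU, STD = 2.0, 0.5

def grad_energy(x):
    return (x - MU) / STD**2               # gradient of the hypothetical energy

rng = np.random.default_rng(0)
samples = langevin(grad_energy, np.zeros(5_000), step=0.01, num_steps=2_000, rng=rng)
print(samples.mean(), samples.std())       # roughly 2.0 and 0.5
```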
Suppose we only run one-step Langevin MCMC for Contrastive Divergence. In this case, the 
gradient of the log-likelihood is given by 


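$$
\begin{align}
\mathbb{E}_{p_{\text{data}}(x)}[\nabla_\theta \log p_\theta(x)] &= -\mathbb{E}_{p_{\text{data}}(x)}[\nabla_\theta \mathcal{E}_\theta(x)] + \mathbb{E}_{x \sim p_\theta(x)}[\nabla_\theta \mathcal{E}_\theta(x)] \nonumber\\
&\simeq -\mathbb{E}_{p_{\text{data}}(x)}[\nabla_\theta \mathcal{E}_\theta(x)] + \mathbb{E}_{p_{\text{data}}(x),\, z \sim \mathcal{N}(0, I)}\left[\nabla_\theta \mathcal{E}_\theta(x')\Big|_{x' = x - \frac{\epsilon}{2}\nabla_x \mathcal{E}_\theta(x) + \sqrt{\epsilon} z}\right] \tag{24.44}
\end{align}
$$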


After a Taylor series expansion with respect to $\epsilon$ followed by some algebraic manipulations, the above equation can be transformed to the following (see Hyvärinen [Hyv07a]):


$$
-\frac{\epsilon}{2}\nabla_\theta D_F(p_{\text{data}}(x) \,\|\, p_\theta(x)) + o(\epsilon) \tag{24.45}
$$


When $\epsilon$ is sufficiently small, it corresponds to the re-scaled gradient of the Score Matching objective.

In general, Score Matching minimizes the Fisher divergence $D_F(p_{\text{data}}(x) \,\|\, p_\theta(x))$, whereas Contrastive Divergence minimizes an objective related to the KL divergence $D_{\mathrm{KL}}(p_{\text{data}}(x) \,\|\, p_\theta(x))$, as shown in Equation (24.19). The above connection of Score Matching and Contrastive Divergence is a natural consequence of the connection between those two statistical divergences, as characterized by de Bruijn's identity [Cov99; Lyu12]:


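$$
\frac{d}{dt} D_{\mathrm{KL}}(q_t(\tilde{x}) \,\|\, p_{\theta,t}(\tilde{x})) = -\frac{1}{2}\mathbb{E}_{q_t(\tilde{x})}\left[\left\|\nabla_{\tilde{x}} \log q_t(\tilde{x}) - \nabla_{\tilde{x}} \log p_{\theta,t}(\tilde{x})\right\|_2^2\right] = -D_F(q_t(\tilde{x}) \,\|\, p_{\theta,t}(\tilde{x}))
$$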


Here $q_t(\tilde{x})$ and $p_{\theta,t}(\tilde{x})$ denote smoothed versions of $p_{\text{data}}(x)$ and $p_\theta(x)$, resulting from adding Gaussian noise to $x$ with variance $t$; i.e., $\tilde{x} \sim \mathcal{N}(x, tI)$.
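de Bruijn's identity can be sanity-checked in closed form for 1D Gaussians. For illustration only, assume $p_{\text{data}} = \mathcal{N}(0, a)$ and $p_\theta = \mathcal{N}(0, b)$. Their smoothed versions are then $\mathcal{N}(0, a+t)$ and $\mathcal{N}(0, b+t)$. The sketch compares a finite-difference derivative of the KL divergence in $t$ with $-\tfrac12\mathbb{E}_{q_t}\big[(\partial_x \log q_t - \partial_x \log p_{\theta,t})^2\big]$.

```python
import numpy as np

def kl_gauss(var_q, var_p):
    """KL(N(0, var_q) || N(0, var_p)) in 1D."""
    r = var_q / var_p
    return 0.5 * (r - 1.0 - np.log(r))

def half_mean_sq_score_diff(var_q, var_p):
    """0.5 * E_q[(d/dx log q - d/dx log p)^2] for q = N(0, var_q), p = N(0, var_p)."""
    # The scores are -x/var_q and -x/var_p, and E_q[x^2] = var_q.
    return 0.5 * var_q * (1.0 / var_p - 1.0 / var_q)**2

a, b, t, h = 1.0, 3.0, 0.5, 1e-5   # illustrative variances, smoothing level, FD step
dkl_dt = (kl_gauss(a + t + h, b + t + h) - kl_gauss(a + t - h, b + t - h)) / (2 * h)
rhs = -half_mean_sq_score_diff(a + t, b + t)
print(dkl_dt, rhs)                  # both approximately -0.1088
```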


24.3.5 Score-Based Generative Models
One typical application of EBMs is creating new samples that are similar to the training data. Towards this end, we can first train an EBM with Score Matching, and then sample from it with MCMC approaches. Many efficient sampling methods for EBMs, such as Langevin MCMC, rely on just the score of the EBM (see Equation (24.18)). In addition, Score Matching objectives (Equations (24.30), (24.33) and (24.38)) depend solely on the scores of EBMs. Therefore, we only need a model for the score when training with Score Matching and sampling with score-based MCMC, and do not have to model the energy explicitly. (Note that this loses the fact that the score is derived from the derivative of the scalar energy, which can be a useful constraint.) By building such a score model, we save the gradient computation of EBMs and can make training and sampling more efficient. These kinds of models are called score-based generative models [SE19; SE20b; Son+21]. (See also Chapter 25 for a discussion of the related approach of denoising diffusion probabilistic models.)


We can optimize the score function $s_\theta(x)$ using score matching, sliced score matching, or denoising score matching. Figure 24.3 gives a simple example in 2d. In Figure 24.3a, we show the swiss roll dataset. We estimate the score function by fitting an MLP with 2 hidden layers, each with 128 hidden units. In Figure 24.3b, we show the output of the network after training for 10,000 steps of SGD. We see that there are no major false negatives (since wherever the density of the data is highest, the gradient field is zero), but there are some false positives (since some regions of zero gradient do not correspond to data regions). The comparison of the predicted outputs with the empirical data density is shown more clearly in Figure 24.3c. In Figure 24.3d, we show some samples from the learned model.


24.3.5.1 Adding noise at multiple scales 


In general, score matching can have difficulty when there are regions of low data density. To see this, suppose $p_{\text{data}}(x) = \pi p_0(x) + (1 - \pi) p_1(x)$. Let $\mathcal{S}_0 := \{x \mid p_0(x) > 0\}$ and $\mathcal{S}_1 := \{x \mid p_1(x) > 0\}$ be the supports of $p_0(x)$ and $p_1(x)$ respectively. When they are disjoint from each other, the score of $p_{\text{data}}(x)$ is given by


$$
\nabla_x \log p_{\text{data}}(x) = \begin{cases} \nabla_x \log p_0(x), & x \in \mathcal{S}_0 \\ \nabla_x \log p_1(x), & x \in \mathcal{S}_1 \end{cases} \tag{24.46}
$$

which does not depend on the weight $\pi$. Hence score matching cannot correctly recover the true distribution. Furthermore, Langevin sampling will have difficulty traversing between modes. (In practice this will happen even when the different modes only have approximately disjoint supports.)

Song and Ermon [SE19; SE20b] and Song et al. [Son+21] overcome this difficulty by perturbing 
training data with different scales of noise. Specifically, they use 


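$$
\begin{align}
q_\sigma(\tilde{x} \mid x) &= \mathcal{N}(\tilde{x} \mid x, \sigma^2 I) \tag{24.47}\\
p_\sigma(\tilde{x}) &= \int p_{\text{data}}(x)\, q_\sigma(\tilde{x} \mid x)\, dx \tag{24.48}
\end{align}
$$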


For a large noise perturbation, different modes are connected due to added noise, and the estimated weights between them are therefore accurate. For a small noise perturbation, different modes are more disconnected, but the noise-perturbed distribution is closer to the original unperturbed data distribution. Using a sampling method such as annealed Langevin dynamics [SE19; SE20b; Son+21] or diffusion sampling [SD+15a; HJA20; Son+21], we can sample from the most noise-perturbed distribution first, then smoothly reduce the magnitude of noise scales until reaching the smallest one. This procedure helps combine information from all noise scales, and maintains the correct estimation of weights from higher noise perturbations when sampling from smaller ones.
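The following is a minimal sketch of annealed Langevin dynamics in the spirit of [SE19], under strong simplifying assumptions. The data are 1D draws from a two-component Gaussian mixture with weights 0.8 and 0.2. The *exact* score of the noise-perturbed $p_\sigma$ (Equation (24.48)) stands in for a trained network. The step size $\alpha_i = \epsilon\,\sigma_i^2/\sigma_L^2$ follows [SE19]. The values of `eps`, the noise levels, and the number of steps per level are illustrative choices. The chains start uniformly, and the fraction that ends near each mode should roughly match the mixture weights. The clean-data score alone (Equation (24.46)) cannot encode that information.

```python
import numpy as np

weights = np.array([0.8, 0.2])      # illustrative mixture weights pi and 1 - pi
means = np.array([-4.0, 4.0])
comp_std = 0.5

def noisy_mixture_score(x, sigma):
    """Exact score of p_sigma, i.e. p_data convolved with N(0, sigma^2), Eq. (24.48)."""
    var = comp_std**2 + sigma**2                       # every component widens equally
    log_w = np.log(weights) - 0.5 * (x[:, None] - means)**2 / var
    resp = np.exp(log_w - log_w.max(axis=1, keepdims=True))
    resp /= resp.sum(axis=1, keepdims=True)            # posterior over components
    return np.sum(resp * (means - x[:, None]) / var, axis=1)

def annealed_langevin(score_fn, sigmas, n_samples, rng, eps=2e-5, n_steps=100):
    x = rng.uniform(-8.0, 8.0, size=n_samples)         # arbitrary initialization
    for sigma in sigmas:                                # largest noise level first
        alpha = eps * sigma**2 / sigmas[-1]**2          # step size schedule from [SE19]
        for _ in range(n_steps):
            z = rng.standard_normal(n_samples)
            x = x + 0.5 * alpha * score_fn(x, sigma) + np.sqrt(alpha) * z
    return x

sigmas = np.geomspace(10.0, 0.01, num=10)               # sigma_1 > ... > sigma_L
samples = annealed_langevin(noisy_mixture_score, sigmas, 10_000, np.random.default_rng(0))
print(np.mean(samples < 0.0))                           # roughly 0.8
```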


In practice, all score models share weights and are implemented with a single neural network conditioned on the noise scale; this is called a Noise Conditional Score Network, and has the form $s_\theta(x, \sigma)$. Scores of different scales are estimated by training a mixture of Score Matching objectives, one per noise scale. If we use the denoising score matching objective in Equation (24.33), we get


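$$
\begin{align}
\mathcal{L}(\theta; \sigma) &= \mathbb{E}_{q_\sigma(x, \tilde{x})}\left[\frac{1}{2}\left\|\nabla_{\tilde{x}} \log p_\theta(\tilde{x}, \sigma) - \nabla_{\tilde{x}} \log q_\sigma(\tilde{x} \mid x)\right\|_2^2\right] \tag{24.49}\\
&= \frac{1}{2}\mathbb{E}_{p_{\text{data}}(x)}\mathbb{E}_{\tilde{x} \sim \mathcal{N}(x, \sigma^2 I)}\left[\left\|s_\theta(\tilde{x}, \sigma) + \frac{\tilde{x} - x}{\sigma^2}\right\|_2^2\right] \tag{24.50}
\end{align}
$$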


These losses are combined in a weighted fashion using


$$
\mathcal{L}(\theta; \sigma_{1:L}) = \frac{1}{L}\sum_{i=1}^{L} \lambda(\sigma_i)\, \mathcal{L}(\theta; \sigma_i) \tag{24.51}
$$


where we choose $\sigma_1 > \sigma_2 > \cdots > \sigma_L$, and the weighting term satisfies $\lambda(\sigma) > 0$.

Empirically, Song and Ermon [SE19] found that setting $\lambda(\sigma_i) = \sigma_i^2$ works well. To set the $\sigma_i$, we choose $\sigma_L$ small enough such that $p_{\sigma_L}(\tilde{x}) \approx p_{\text{data}}(\tilde{x})$, and we choose $\sigma_1$ large enough that $p_{\sigma_1}(\tilde{x}) \approx \mathcal{N}(\tilde{x} \mid 0, \sigma_1^2 I)$. (See [SE20b] for further discussion.) In [Son+21], this technique is extended to an infinite number of noise levels, by working with stochastic differential equations in continuous time.
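The weighted multi-scale objective of Equation (24.51), with $\lambda(\sigma_i) = \sigma_i^2$ and a geometric noise schedule, can be sketched as follows. The noise-conditional score model is a hypothetical one-parameter family $s_\theta(\tilde{x}, \sigma) = -\tilde{x}/(c + \sigma^2)$, and the toy data are $\mathcal{N}(0, 1)$. The loss should therefore be smallest at $c = 1$. The noise draws are shared across values of $c$ so the comparison is not swamped by Monte Carlo noise. As motivated in [SE19], the weight $\sigma_i^2$ offsets the $1/\sigma$ scale of the DSM target $-(\tilde{x} - x)/\sigma^2$, keeping the per-level terms roughly comparable.

```python
import numpy as np

def ncsn_loss(score_fn, x, sigmas, noise):
    """Eq. (24.51) with lambda(sigma) = sigma^2; each level uses the DSM loss (24.50)."""
    total = 0.0
    for sigma, z in zip(sigmas, noise):
        x_tilde = x + sigma * z                          # x~ ~ N(x, sigma^2 I)
        residual = score_fn(x_tilde, sigma) + z / sigma  # s(x~, sigma) + (x~ - x)/sigma^2
        level_loss = 0.5 * np.mean(np.sum(residual**2, axis=-1))
        total += sigma**2 * level_loss
    return total / len(sigmas)

sigmas = np.geomspace(10.0, 0.01, num=10)                # sigma_1 > ... > sigma_L (illustrative)
rng = np.random.default_rng(0)
x = rng.standard_normal((100_000, 1))                    # toy p_data = N(0, 1)
noise = rng.standard_normal((len(sigmas),) + x.shape)    # shared across the models below

def make_model(c):
    return lambda x_tilde, sigma: -x_tilde / (c + sigma**2)   # hypothetical s_theta

losses = {c: ncsn_loss(make_model(c), x, sigmas, noise) for c in (0.5, 1.0, 2.0)}
print(losses)                                            # smallest at c = 1.0
```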


24.4 Noise Contrastive Estimation 


Another principle for learning the parameters of EBMs is Noise Contrastive Estimation (NCE), 
introduced by [GH10]. It is based on the idea that we can learn an Energy-Based Model by contrasting 
it with another distribution with known density. 

Let $p_{\text{data}}(x)$ be our data distribution, and let $p_n(x)$ be a chosen distribution with known density, called a noise distribution. This noise distribution is usually simple and has a tractable PDF, like $\mathcal{N}(0, I)$, such that we can compute the PDF and generate samples from it efficiently. Strategies exist to learn the noise distribution, as referenced below. Furthermore, let $y$ be a binary variable with Bernoulli distribution, which we use to define a mixture distribution of noise and data: $p_{n,\text{data}}(x) = p(y = 0)\, p_n(x) + p(y = 1)\, p_{\text{data}}(x)$. According to Bayes' rule, given a sample $x$ from this mixture, the posterior probability of $y = 0$ is


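$$
p_{n,\text{data}}(y = 0 \mid x) = \frac{p_{n,\text{data}}(x \mid y = 0)\, p(y = 0)}{p_{n,\text{data}}(x)} = \frac{p_n(x)}{p_n(x) + \nu\, p_{\text{data}}(x)} \tag{24.52}
$$

where $\nu = p(y = 1)/p(y = 0)$.

Let our Energy-Based Model $p_\theta(x)$ be defined as:

$$
p_\theta(x) = \exp(-\mathcal{E}_\theta(x))/Z_\theta \tag{24.53}
$$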


Contrary to most other EBMs, $Z_\theta$ is treated as a learnable (scalar) parameter in NCE. Given this model, similar to the mixture of noise and data above, we can define a mixture of noise and the model distribution: $p_{n,\theta}(x) = p(y = 0)\, p_n(x) + p(y = 1)\, p_\theta(x)$. The posterior probability of $y = 0$ given this noise/model mixture is:


$$
p_{n,\theta}(y = 0 \mid x) = \frac{p_n(x)}{p_n(x) + \nu\, p_\theta(x)} \tag{24.54}
$$


In NCE, we indirectly fit $p_\theta(x)$ to $p_{\text{data}}(x)$ by fitting $p_{n,\theta}(y \mid x)$ to $p_{n,\text{data}}(y \mid x)$ through a standard conditional maximum likelihood objective:


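$$
\begin{align}
\theta^* &= \operatorname*{argmin}_\theta\, \mathbb{E}_{p_{n,\text{data}}(x)}\left[D_{\mathrm{KL}}(p_{n,\text{data}}(y \mid x) \,\|\, p_{n,\theta}(y \mid x))\right] \tag{24.55}\\
&= \operatorname*{argmax}_\theta\, \mathbb{E}_{p_{n,\text{data}}(x, y)}[\log p_{n,\theta}(y \mid x)] \tag{24.56}
\end{align}
$$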


which can be solved using stochastic gradient ascent. Just like any other deep classifier, when the model is sufficiently powerful, $p_{n,\theta^*}(y \mid x)$ will match $p_{n,\text{data}}(y \mid x)$ at the optimum. In that case:


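$$
\begin{align}
& p_{n,\theta^*}(y = 0 \mid x) \equiv p_{n,\text{data}}(y = 0 \mid x) \tag{24.57}\\
\iff\ & \frac{p_n(x)}{p_n(x) + \nu\, p_{\theta^*}(x)} \equiv \frac{p_n(x)}{p_n(x) + \nu\, p_{\text{data}}(x)} \tag{24.58}\\
\iff\ & p_{\theta^*}(x) \equiv p_{\text{data}}(x) \tag{24.59}
\end{align}
$$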


Consequently, $\mathcal{E}_{\theta^*}(x)$ is an unnormalized energy function that matches the data distribution $p_{\text{data}}(x)$, and $Z_{\theta^*}$ is the corresponding normalizing constant.
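A minimal NCE sketch with $\nu = 1$, meaning equal numbers of data and noise samples, follows. It assumes a hypothetical 1D EBM $\mathcal{E}_\theta(x) = \tfrac12 a x^2$ with a learnable $b = \log Z_\theta$. The toy data are $\mathcal{N}(0, 1.5^2)$ and the noise is $p_n = \mathcal{N}(0, 2^2)$. With this parameterization the classifier logit $\log p_\theta(x) - \log p_n(x)$ is linear in $(a, b)$. Maximizing Equation (24.56) is then a logistic regression, which this sketch solves with Newton's method instead of stochastic gradient ascent. The learned $b$ should approximate the true log normalizer $\log(1.5\sqrt{2\pi})$.

```python
import numpy as np

def log_normal(x, std):
    return -0.5 * (x / std)**2 - np.log(std * np.sqrt(2 * np.pi))

def fit_nce(x_data, noise_std, rng, n_iters=30):
    """NCE with nu = 1 for E(x) = 0.5 * a * x^2 and a learnable b = log Z."""
    x_noise = rng.normal(0.0, noise_std, size=x_data.shape)   # one noise sample per datum
    x = np.concatenate([x_data, x_noise])
    y = np.concatenate([np.ones_like(x_data), np.zeros_like(x_noise)])  # y = 1 marks data
    feats = np.stack([-0.5 * x**2, -np.ones_like(x)], axis=1)  # log p_theta = feats @ (a, b)
    offset = -log_normal(x, noise_std)                          # -log p_n(x)
    # Start from p_theta = p_n, so every initial classifier probability is 0.5.
    params = np.array([1.0 / noise_std**2, np.log(noise_std * np.sqrt(2 * np.pi))])
    for _ in range(n_iters):                                    # Newton ascent on Eq. (24.56)
        logits = feats @ params + offset
        p1 = 0.5 * (1.0 + np.tanh(0.5 * logits))               # stable sigmoid: p(y = 1 | x)
        grad = feats.T @ (y - p1)
        hess = -(feats * (p1 * (1.0 - p1))[:, None]).T @ feats
        params = params - np.linalg.solve(hess, grad)
    return params

rng = np.random.default_rng(0)
data_std = 1.5
x_data = rng.normal(0.0, data_std, size=50_000)
a, b = fit_nce(x_data, noise_std=2.0, rng=rng)
print(a, 1 / data_std**2)                          # fitted vs. true precision (about 0.444)
print(b, np.log(data_std * np.sqrt(2 * np.pi)))    # fitted vs. true log Z (about 1.324)
```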

Unlike Contrastive Divergence and Score Matching, NCE provides the normalizing constant of an Energy-Based Model as a by-product of its training procedure. When the EBM is very expressive, e.g., a deep neural network with many parameters, we can assume it is able to approximate a normalized probability density and absorb $Z_\theta$ into the parameters of $\mathcal{E}_\theta(x)$ [MT12], or equivalently, fix $Z_\theta = 1$. The resulting EBM trained with NCE will be self-normalized, i.e., have a normalizing constant close to 1.

In practice, choosing the right noise distribution $p_n(x)$ is critical to the success of NCE, especially for structured and high-dimensional data. As argued in Gutmann and Hirayama [GH12], NCE works best when the noise distribution is close to the data distribution (but not exactly the same). Many methods have been proposed to automatically tune the noise distribution, such as Adversarial Contrastive Estimation [BLC18], Conditional NCE [CG18] and Flow Contrastive Estimation [Gao+20]. NCE can be further generalized using Bregman divergences (Section 5.1.8), where the formulation introduced here reduces to a special case.


24.4.1 Connection to Score Matching 


Noise Contrastive Estimation provides a family of objectives that vary for different $p_n(x)$ and $\nu$. This flexibility may allow adaptation to special properties of a task with hand-tuned $p_n(x)$ and $\nu$, and may also give a unified perspective for different approaches. In particular, when using an appropriate $p_n(x)$ and a slightly different parameterization of $p_{n,\theta}(y \mid x)$, we can recover Score Matching from NCE [GH12].


Specifically, we choose the noise distribution $p_n(x)$ to be a perturbed data distribution: given a small (deterministic) vector $v$, let $p_n(x) = p_{\text{data}}(x - v)$. It is efficient to sample from this $p_n(x)$, since we can first draw any datapoint $x' \sim p_{\text{data}}(x')$ and then compute $x = x' + v$. It is, however, difficult to evaluate the density of $p_n(x)$ because $p_{\text{data}}(x)$ is unknown. Since the original parameterization of $p_{n,\theta}(y \mid x)$ in NCE (Equation (24.54)) depends on the PDF of $p_n(x)$, we cannot directly apply the standard NCE objective. Instead, we replace $p_n(x)$ with $p_\theta(x - v)$ and parameterize $p_{n,\theta}(y = 0 \mid x)$ with the following form


$$
p_{n,\theta}(y = 0 \mid x) := \frac{p_\theta(x - v)}{p_\theta(x) + p_\theta(x - v)} \tag{24.60}
$$


In this case, the NCE objective (Equation (24.56)) reduces to:


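$$
\theta^* = \operatorname*{argmin}_\theta\, \mathbb{E}_{p_{\text{data}}(x)}\left[\log\left(1 + \exp(\mathcal{E}_\theta(x) - \mathcal{E}_\theta(x - v))\right) + \log\left(1 + \exp(\mathcal{E}_\theta(x) - \mathcal{E}_\theta(x + v))\right)\right] \tag{24.61}
$$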


At 0*, we have a solution where: 


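$$
\begin{align}
& p_{n,\theta^*}(y = 0 \mid x) \equiv p_{n,\text{data}}(y = 0 \mid x) \tag{24.62}\\
\implies\ & \frac{p_{\theta^*}(x - v)}{p_{\theta^*}(x) + p_{\theta^*}(x - v)} \equiv \frac{p_{\text{data}}(x - v)}{p_{\text{data}}(x) + p_{\text{data}}(x - v)} \tag{24.63}
\end{align}
$$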


which implies that $p_{\theta^*}(x) \equiv p_{\text{data}}(x)$, i.e., our model matches the data distribution.

As noted in Gutmann and Hirayama [GH12] and Song et al. [Son+19], when $\|v\|_2 \to 0$, the NCE objective Equation (24.56) has the following equivalent form by Taylor expansion


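$$
\operatorname*{argmin}_\theta\, \mathbb{E}_{p_{\text{data}}(x)}\left[\frac{1}{2}\left(\frac{1}{2}\left(\sum_{i=1}^{d} \frac{\partial \mathcal{E}_\theta(x)}{\partial x_i} v_i\right)^2 - \sum_{i=1}^{d}\sum_{j=1}^{d} \frac{\partial^2 \mathcal{E}_\theta(x)}{\partial x_i \partial x_j} v_i v_j\right) + 2\log 2\right] + o(\|v\|_2^2) \tag{24.64}
$$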


Comparing against Equation (24.38), we immediately see that, up to a positive factor of $\frac{1}{2}$ and an additive constant (neither of which changes the minimizer), the above objective equals that of SSM, if we ignore small additional terms hidden in $o(\|v\|_2^2)$ and take the expectation with respect to $v$ over a user-specified distribution $p(v)$.
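The expansion can be checked numerically at a single datapoint. This sketch assumes a hypothetical 2D energy $\mathcal{E}(x) = \tfrac12 x^\top A x + 0.1\sum_i x_i^4$ with analytic gradient and Hessian. It evaluates the per-datapoint objective inside Equation (24.61) and compares it with $2\log 2 + \tfrac14(v^\top\nabla_x \mathcal{E})^2 - \tfrac12 v^\top \nabla_x^2 \mathcal{E}\, v$. That is the second-order expansion obtained from $\log(1+e^u) = \log 2 + u/2 + u^2/8 + O(u^4)$. Odd-order terms cancel between the two logistic terms, so the gap should shrink roughly like $\|v\|_2^4$.

```python
import numpy as np

A = np.array([[2.0, 0.5], [0.5, 1.0]])    # hypothetical energy parameters

def energy(x):
    return 0.5 * x @ A @ x + 0.1 * np.sum(x**4)

def grad_energy(x):
    return A @ x + 0.4 * x**3

def hess_energy(x):
    return A + np.diag(1.2 * x**2)

def nce_term(x, v):
    """Per-datapoint objective in Eq. (24.61); Z_theta has cancelled out."""
    return (np.logaddexp(0.0, energy(x) - energy(x - v))
            + np.logaddexp(0.0, energy(x) - energy(x + v)))

def second_order(x, v):
    g, H = grad_energy(x), hess_energy(x)
    return 2 * np.log(2) + 0.25 * (g @ v)**2 - 0.5 * v @ H @ v

x = np.array([0.3, -0.7])
u = np.array([0.6, 0.8])                   # unit direction for v
gaps = {h: nce_term(x, h * u) - second_order(x, h * u) for h in (1e-1, 1e-2)}
print(gaps)                                # the h = 0.01 gap is orders of magnitude smaller
```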


24.5 Other Methods 


Aside from MCMC-based training, Score Matching and Noise Contrastive Estimation, there are 
also other methods for learning EBMs. Below we briefly survey some examples of them. Interested 
readers can learn more details from references therein. 


24.5.1 Minimizing Differences/Derivatives of KL Divergences 


The overarching strategy for learning probabilistic models from data is to minimize the KL divergence 
between data and model distributions. However, because the normalizing constants of EBMs are 
typically intractable, it is hard to directly evaluate the KL divergence when the model is an EBM (see 
the discussion in Section 24.2.1). One generic idea that has frequently circumvented this difficulty 
is to consider differences/derivatives of KL divergences. It turns out that the unknown partition 
functions of EBMs are often cancelled out after taking the difference of two closely related KL 
divergences, or computing the derivatives. 

Typical examples of this strategy include minimum velocity learning [Mov08; Wan+20d], minimum probability flow [SDBD11] and minimum KL contraction [Lyu11], to name a few. In minimum velocity learning and minimum probability flow, a Markov chain is designed such that it starts from the data distribution $p_{\text{data}}(x)$ and converges to the EBM distribution $p_\theta(x) = e^{-\mathcal{E}_\theta(x)}/Z_\theta$. Specifically, the Markov chain satisfies $p_0(x) = p_{\text{data}}(x)$ and $p_\infty(x) = p_\theta(x)$, where we denote by $p_t(x)$ the state distribution at time $t \geq 0$.

This Markov chain will evolve towards $p_\theta(x)$ unless $p_{\text{data}}(x) = p_\theta(x)$. Therefore, we can fit the EBM distribution $p_\theta(x)$ to $p_{\text{data}}(x)$ by minimizing the modulus of the "velocity" of this evolution, defined by


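$$
\frac{d}{dt} D_{\mathrm{KL}}(p_t(x) \,\|\, p_\theta(x)) \Big|_{t=0} \quad \text{or} \quad \frac{d}{dt} D_{\mathrm{KL}}(p_{\text{data}}(x) \,\|\, p_t(x)) \Big|_{t=0} \tag{24.65}
$$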
in minimum velocity learning and minimum probability flow respectively. These objectives typically do not require computing the normalizing constant $Z_\theta$.

In minimum KL contraction [Lyu11], a distribution transformation $\Phi$ is chosen such that


$$
D_{\mathrm{KL}}(p(x) \,\|\, q(x)) \geq D_{\mathrm{KL}}(\Phi\{p(x)\} \,\|\, \Phi\{q(x)\}) \tag{24.66}
$$


with equality if and only if $p(x) = q(x)$. We can leverage this $\Phi$ to train an EBM, by minimizing


$$
D_{\mathrm{KL}}(p_{\text{data}}(x) \,\|\, p_\theta(x)) - D_{\mathrm{KL}}(\Phi\{p_{\text{data}}(x)\} \,\|\, \Phi\{p_\theta(x)\}) \tag{24.67}
$$


This objective does not require computing the partition function $Z_\theta$ whenever $\Phi$ is linear.

Minimum velocity learning, minimum probability flow, and minimum KL contraction can all be viewed as generalizations of Score Matching and Noise Contrastive Estimation [Mov08; SDBD11; Lyu11].


24.5.2 Minimizing the Stein Discrepancy 


We can train EBMs by minimizing the Stein discrepancy, defined by 


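$$
D_{\text{Stein}}(p_{\text{data}}(x) \,\|\, p_\theta(x)) := \sup_{f \in \mathcal{F}} \mathbb{E}_{p_{\text{data}}(x)}\left[\nabla_x \log p_\theta(x)^\top f(x) + \operatorname{trace}(\nabla_x f(x))\right] \tag{24.68}
$$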


where $\mathcal{F}$ is a family of vector-valued functions, and $\nabla_x f(x)$ denotes the Jacobian of $f(x)$. With some regularity conditions [GM15; LLJ16], we have $D_{\text{Stein}}(p_{\text{data}}(x) \,\|\, p_\theta(x)) \geq 0$, where the equality holds if and only if $p_{\text{data}}(x) = p_\theta(x)$. Similar to Score Matching (Equation (24.30)), the objective Equation (24.68) only involves the score function of $p_\theta(x)$, and does not require computing the EBM's partition function. Still, the trace term in Equation (24.68) may demand expensive computation, and does not scale well to high dimensional data.


There are two common methods that sidestep this difficulty. Gorham and Mackey [GM15] and Liu, Lee, and Jordan [LLJ16] discovered that when $\mathcal{F}$ is a unit ball in a Reproducing Kernel Hilbert Space (RKHS) with a fixed kernel, the Stein discrepancy becomes the kernelized Stein discrepancy, where the trace term is a constant and does not affect optimization. Otherwise, $\operatorname{trace}(\nabla_x f(x))$ can be approximated with the Skilling-Hutchinson trace estimator [Ski89; Hut89; Gra+20c].
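For the RKHS case the supremum has a closed form. The NumPy sketch below computes the standard U-statistic estimate of the squared kernelized Stein discrepancy. It assumes an RBF kernel $k(x, x') = \exp(-\|x - x'\|^2 / (2h^2))$ with a hand-picked bandwidth $h$. The kernel and estimator are illustrative choices, not necessarily those of [GM15; LLJ16]. Only the model score $\nabla_x \log p_\theta(x)$ enters, so $Z_\theta$ is never needed. For samples from $\mathcal{N}(0, 1)$ the estimate should be near zero under the matching score. For a model shifted by one unit it should be near $1/\sqrt{3} \approx 0.577$, which is $\mathbb{E}[k(x, x')]$ for this pair of Gaussians when $h = 1$.

```python
import numpy as np

def ksd_rbf(x, score_fn, bandwidth=1.0):
    """U-statistic estimate of the squared kernelized Stein discrepancy (RBF kernel)."""
    n, d = x.shape
    s = score_fn(x)                                   # model scores, shape (n, d)
    diff = x[:, None, :] - x[None, :, :]              # x_i - x_j, shape (n, n, d)
    sq = np.sum(diff**2, axis=-1)
    h2 = bandwidth**2
    k = np.exp(-sq / (2 * h2))
    t1 = (s @ s.T) * k                                # s(x_i)^T s(x_j) k
    t2 = np.einsum("id,ijd->ij", s, diff) / h2 * k    # s(x_i)^T grad_{x_j} k
    t3 = -np.einsum("jd,ijd->ij", s, diff) / h2 * k   # grad_{x_i} k^T s(x_j)
    t4 = (d / h2 - sq / h2**2) * k                    # trace(grad_{x_i} grad_{x_j} k)
    u = t1 + t2 + t3 + t4
    np.fill_diagonal(u, 0.0)                          # U-statistic: drop the i == j terms
    return u.sum() / (n * (n - 1))

rng = np.random.default_rng(0)
x = rng.standard_normal((1500, 1))                    # samples from p_data = N(0, 1)
ksd_match = ksd_rbf(x, lambda pts: -pts)              # model N(0, 1), matches p_data
ksd_shift = ksd_rbf(x, lambda pts: -(pts - 1.0))      # model N(1, 1), shifted
print(ksd_match, ksd_shift)                           # near 0 and near 0.577
```

The pairwise arrays make memory grow as $O(n^2 d)$, so this direct form only suits small sample sizes.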


24.5.3 Adversarial Training


Recall from Section 24.2.1 that when training EBMs with maximum likelihood estimation (MLE), we need to sample from the EBM per training iteration. However, sampling using multiple MCMC steps is expensive and requires careful tuning of the Markov chain. One way to avoid this difficulty is to use non-MLE methods that do not need sampling, such as Score Matching and Noise Contrastive Estimation. Here we introduce another family of methods that sidestep costly MCMC sampling by learning an auxiliary model through adversarial training, which allows fast sampling.


Using the definition of EBMs, we can rewrite the maximum likelihood objective by introducing a variational distribution $q_\phi(x)$ parameterized by $\phi$:


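$$
\begin{align}
\mathbb{E}_{p_{\text{data}}(x)}[\log p_\theta(x)] &= \mathbb{E}_{p_{\text{data}}(x)}[-\mathcal{E}_\theta(x)] - \log Z_\theta \nonumber\\
&= \mathbb{E}_{p_{\text{data}}(x)}[-\mathcal{E}_\theta(x)] - \log \int e^{-\mathcal{E}_\theta(x)}\, dx \nonumber\\
&= \mathbb{E}_{p_{\text{data}}(x)}[-\mathcal{E}_\theta(x)] - \log \int q_\phi(x) \frac{e^{-\mathcal{E}_\theta(x)}}{q_\phi(x)}\, dx \nonumber\\
&\overset{(i)}{\leq} \mathbb{E}_{p_{\text{data}}(x)}[-\mathcal{E}_\theta(x)] - \int q_\phi(x) \log \frac{e^{-\mathcal{E}_\theta(x)}}{q_\phi(x)}\, dx \nonumber\\
&= \mathbb{E}_{p_{\text{data}}(x)}[-\mathcal{E}_\theta(x)] + \mathbb{E}_{q_\phi(x)}[\mathcal{E}_\theta(x)] - H(q_\phi(x)) \tag{24.69}
\end{align}
$$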


where $H(q_\phi(x))$ denotes the entropy of $q_\phi(x)$. Step (i) is due to Jensen's inequality. Equation (24.69) provides an upper bound to the expected log-likelihood. For EBM training, we can first minimize the upper bound Equation (24.69) with respect to $q_\phi(x)$ so that it is closer to the likelihood objective, and then maximize Equation (24.69) with respect to $\mathcal{E}_\theta(x)$ as a surrogate for maximizing likelihood. This amounts to using the following maximin objective


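$$
\max_\theta \min_\phi\ \mathbb{E}_{q_\phi(x)}[\mathcal{E}_\theta(x)] - \mathbb{E}_{p_{\text{data}}(x)}[\mathcal{E}_\theta(x)] - H(q_\phi(x)) \tag{24.70}
$$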


Optimizing the above objective is similar to training GANs (Chapter 26), and can be achieved by adversarial training. The variational distribution $q_\phi(x)$ should allow both fast sampling and efficient entropy evaluation to make Equation (24.70) tractable. This limits the model family of $q_\phi(x)$, and usually restricts our choice to invertible probabilistic models, such as inverse autoregressive flow (Section 23.2.4.3). See Dai et al. [Dai+19b] for an example on designing $q_\phi(x)$ and training EBMs with Equation (24.70).
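Step (i) of Equation (24.69) amounts to the inequality $\log Z_\theta \ge \mathbb{E}_{q_\phi(x)}[-\mathcal{E}_\theta(x)] + H(q_\phi(x))$, with equality iff $q_\phi = p_\theta$. The small numerical check below assumes a hypothetical 1D energy $\mathcal{E}(x) = x^4/4$ and a Gaussian family $q_\phi = \mathcal{N}(0, s^2)$. For that family $\mathbb{E}_q[x^4] = 3s^4$ and the entropy is $\tfrac12\log(2\pi e s^2)$. $\log Z$ is computed by a Riemann sum on a fine grid and compared with its closed form $\log\big(2 \cdot 4^{-3/4}\,\Gamma(1/4)\big)$. Every bound value should stay below $\log Z$. The best Gaussian comes close, because the gap equals $D_{\mathrm{KL}}(q_\phi \,\|\, p_\theta)$.

```python
import math
import numpy as np

def energy(x):
    return 0.25 * x**4                       # hypothetical non-Gaussian energy

grid = np.linspace(-10.0, 10.0, 200_001)     # exp(-E) is negligible outside [-10, 10]
dx = grid[1] - grid[0]
log_z = np.log(np.sum(np.exp(-energy(grid))) * dx)

def gaussian_bound(s):
    """E_q[-E(x)] + H(q) in closed form for q = N(0, s^2)."""
    return -0.75 * s**4 + 0.5 * np.log(2 * np.pi * np.e * s**2)

stds = np.linspace(0.3, 2.0, 50)
bounds = gaussian_bound(stds)
exact = math.log(2 * 4**-0.75 * math.gamma(0.25))   # closed-form log Z for this energy
print(log_z, exact, bounds.max())                    # about 0.941, 0.941, 0.894
```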

Kim and Bengio [KB16] and Zhai et al. [Zha+16] propose to represent $q_\phi(x)$ with neural samplers, like the generator of GANs. A neural sampler is a deterministic mapping $g_\phi$ that maps a random Gaussian noise $z \sim \mathcal{N}(0, I)$ directly to a sample $x = g_\phi(z)$. When using a neural sampler as $q_\phi(x)$, it is efficient to draw samples through the deterministic mapping, but $H(q_\phi(x))$ is intractable since the density of $q_\phi(x)$ is unknown. Kim and Bengio [KB16] and Zhai et al. [Zha+16] propose several heuristics to approximate this entropy function. Kumar et al. [Kum+19c] propose to estimate the entropy through its connection to mutual information: $H(q_\phi(x)) = I(g_\phi(z), z)$, which can be estimated from samples with variational lower bounds [NWJ10b; NCT16b]. Dai et al. [Dai+19a] noticed that when defining $p_\theta(x) = p_0(x)\, e^{-\mathcal{E}_\theta(x)}/Z_\theta$, with $p_0(x)$ being a fixed base distribution, the entropy term $-H(q_\phi(x))$ in Equation (24.70) can be replaced by $D_{\mathrm{KL}}(q_\phi(x) \,\|\, p_0(x))$, which can also be approximated with variational lower bounds using samples from $q_\phi(x)$ and $p_0(x)$, without requiring the density of $q_\phi(x)$.

Grathwohl et al. [Gra+20a] represent $q_\phi(x)$ as a noisy neural sampler, where samples are obtained via $g_\phi(z) + \sigma\epsilon$, assuming $z, \epsilon \sim \mathcal{N}(0, I)$. With a noisy neural sampler, $\nabla_\phi H(q_\phi(x))$ becomes particularly easy to estimate, which allows gradient-based optimization for the minimax objective in Equation (24.70). A related approach is proposed in Xie et al. [Xie+18], where the authors train a noisy neural sampler with samples obtained from MCMC, and initialize new MCMC chains with samples generated from the neural sampler. This cooperative sampling scheme improves the convergence of MCMC, but may still require multiple MCMC steps for sample generation. It does not optimize the objective in Equation (24.69).
When using both adversarial training and MCMC sampling, Yu et al. [Yu+20] noticed that EBMs can be trained with an arbitrary $f$-divergence, including KL, reverse KL, total variation, Hellinger, etc. The method proposed by Yu et al. [Yu+20] allows us to explore the trade-offs and inductive bias of different statistical divergences for more flexible EBM training.
25 Diffusion models


In this section, we consider a class of models which we call diffusion models or DMs, that includes denoising diffusion models [SD+15b], score-based generative models (see Section 24.3.5), variational diffusion models [HJA20; Kin+21], as well as methods based on stochastic differential equations (see e.g., [Son+21]).

The basic insight is the following: it can be hard to convert noise into structured data, but it is easy to convert structured data into noise. In particular, we can use a forwards process or diffusion process to gradually convert the observed data $x_0$ into a noisy version $x_T$ by passing the data through $T$ steps of a stochastic encoder $q(x_t \mid x_{t-1})$. After enough steps, we have $x_T \sim \mathcal{N}(0, I)$, or some other convenient reference distribution. We then learn a reverse process to undo this, by passing the noise through $T$ steps of a decoder $p_\theta(x_{t-1} \mid x_t)$ until we generate $x_0$. See Figure 25.1 for an overall sketch of the approach. We give a brief summary of the method below. For more details, see the recent tutorial at https://cvpr2022-tutorial-diffusion-models.github.io/.


25.1 Variational diffusion models 


In this section, we present the VDM or variational diffusion model of [Kin+21]. The VDM is a special case of a multi-layer VAE (Chapter 21) where each latent layer $x_t$ has the same size as the input $x_0$, but has increasing amounts of noise added to it. Since the encoder only adds a small amount of noise at each step $t$, we can easily learn a decoder to invert this operation, and we don't need to worry about posterior collapse (Section 21.4). Combining a sequence of such decoder steps lets us gradually map from unstructured noise to data. This is similar to a normalizing flow (Chapter 23), except we don't require the mapping function to be invertible.


25.1.1 Encoder 


We define each step of the encoder to have the form 


$$
q(x_t \mid x_s) = \mathcal{N}(x_t \mid \alpha_{t|s} x_s, \sigma_{t|s}^2 I) \tag{25.1}
$$


1. In [RHS22] they propose to diffuse the input by blurring it instead of adding Gaussian noise. This process can also 
be (approximately) inverted, and hence can be tackled using VDM-like methods. 
The only learnable parameters for the encoder are the scalars $\alpha_t$ and $\sigma_t$. We usually set $\alpha_t = \sqrt{1 - \sigma_t^2}$ to ensure the noise is variance preserving. If we set $\alpha_t = 1$, we get the variance exploding diffusion process of [SE20b].
Figure 25.1: Denoising diffusion probabilistic model. The forwards diffusion process, $q(x_t \mid x_{t-1})$, implements the (non-learned) inference network; this just adds Gaussian noise at each step. The reverse diffusion process, $p_\theta(x_{t-1} \mid x_t)$, implements the decoder; this is a learned Gaussian model. From Figure 1 of [HJA20]. Used with kind permission of Ajay Jain.


where the parameters $\alpha_{t|s}$ and $\sigma_{t|s}$ are defined below, and $t = s + 1$. If we have $T$ layers, the overall encoder can be represented as follows:


$$
q(x_{1:T} \mid x_0) = \prod_{t=1}^{T} q(x_t \mid x_{t-1}) \tag{25.2}
$$


Since the encoder is a linear Gaussian Markov chain, we can compute the marginal $q(x_t \mid x_0)$ analytically. It has the following form


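$$
q(x_t \mid x_0) = \mathcal{N}(x_t \mid \alpha_t x_0, \sigma_t^2 I) \tag{25.3}
$$

where

$$
\begin{align}
\alpha_{t|s} &= \alpha_t / \alpha_s \tag{25.4}\\
\sigma_{t|s}^2 &= \sigma_t^2 - \alpha_{t|s}^2 \sigma_s^2 \tag{25.5}
\end{align}
$$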


(See Section 25.1.1.2 for a derivation of these equations.) 


25.1.1.1 Signal to noise ratio 


An alternative parameterization is in terms of the signal to noise ratio

$$
\text{SNR}(t) = \alpha_t^2 / \sigma_t^2 \tag{25.6}
$$

which we assume is monotonically decreasing in $t$, so higher layers, with $t > s$, are more noisy (smaller SNR). We compute the SNR using


$$
\text{SNR}(t) = \exp(-\gamma_\eta(t)) \tag{25.7}
$$


where $\gamma_\eta(t)$ is a monotonic neural network. (In the [Kin+21] paper, they use $\gamma_\eta(t) = l_1(t) + l_3(\phi(l_2(l_1(t))))$, where the linear layers $l_1, l_2, l_3$ have positive weights, $\phi$ is the sigmoid nonlinearity, layer 2 has 1024 units, and the other layers are scalar.) From this, we get


$$
\sigma_t^2 = \sigma(\gamma_\eta(t)), \qquad \alpha_t^2 = \sigma(-\gamma_\eta(t)) \tag{25.8}
$$


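where $\sigma(\cdot)$ (without a subscript) is the sigmoid function.

To ensure the computation of $\sigma_{t|s}^2$ is numerically stable, we can rewrite it in terms of the log SNR, $\lambda_t = \log(\alpha_t^2/\sigma_t^2)$. We have $e^{\lambda_t - \lambda_s} = \alpha_{t|s}^2 \sigma_s^2 / \sigma_t^2$, and hence we can rewrite the variance term as

$$
\sigma_{t|s}^2 = \sigma_t^2 - \alpha_{t|s}^2 \sigma_s^2 = \sigma_t^2 \left(1 - e^{\lambda_t - \lambda_s}\right) = -\sigma_t^2\, \operatorname{expm1}(\lambda_t - \lambda_s) \tag{25.9}
$$

where $\operatorname{expm1}(u) = e^u - 1$, for which numerically stable implementations exist.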
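The sketch below computes these quantities under one assumption: a hypothetical fixed linear $\gamma(t)$ replaces the learned monotone network. It computes $\alpha_t$, $\sigma_t^2$ and the log SNR $\lambda_t = -\gamma(t)$. It then evaluates $\sigma_{t|s}^2$ with `np.expm1` as in Equation (25.9) and checks the result against the direct formula (25.5). It also checks the variance-preserving property $\alpha_t^2 + \sigma_t^2 = 1$ implied by Equation (25.8), and the composition identity $\alpha_{t|s}^2\sigma_s^2 + \sigma_{t|s}^2 = \sigma_t^2$.

```python
import numpy as np

def sigmoid(u):
    return 1.0 / (1.0 + np.exp(-u))

def gamma(t, gamma_min=-10.0, gamma_max=10.0):
    """Hypothetical fixed schedule; VDM instead learns a monotone network."""
    return gamma_min + (gamma_max - gamma_min) * t

def schedule(t):
    g = gamma(t)
    alpha2, sigma2 = sigmoid(-g), sigmoid(g)     # Eq. (25.8)
    return np.sqrt(alpha2), sigma2, -g           # alpha_t, sigma_t^2, lambda_t = log SNR

def sigma2_t_given_s(s, t):
    """sigma_{t|s}^2 = -sigma_t^2 * expm1(lambda_t - lambda_s), Eq. (25.9)."""
    _, sigma2_t, lam_t = schedule(t)
    _, _, lam_s = schedule(s)
    return -sigma2_t * np.expm1(lam_t - lam_s)

s, t = 0.3, 0.7
alpha_s, sigma2_s, _ = schedule(s)
alpha_t, sigma2_t, _ = schedule(t)
alpha_ts = alpha_t / alpha_s                     # Eq. (25.4)
direct = sigma2_t - alpha_ts**2 * sigma2_s       # Eq. (25.5)
stable = sigma2_t_given_s(s, t)
print(direct, stable, alpha_t**2 + sigma2_t)     # the first two agree; the last is 1.0
```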


25.1.1.2 Markov chain encoder 


The function $q(x_t \mid x_0)$ lets us map the input directly to any given noise level, but we can also sequentially add noise by using the conditional distribution


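$$
q(x_t \mid x_s) = \mathcal{N}(x_t \mid \alpha_{t|s} x_s, \sigma_{t|s}^2 I) \tag{25.10}
$$

where

$$
\alpha_{t|s} = \alpha_t / \alpha_s, \qquad \sigma_{t|s}^2 = \sigma_t^2 - \alpha_{t|s}^2 \sigma_s^2 \tag{25.11}
$$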


To see this, note that 


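$$
\begin{align}
x_t &= \alpha_{t|s} x_s + \sigma_{t|s} \epsilon_t \tag{25.12}\\
&= \alpha_{t|s}(\alpha_s x_0 + \sigma_s \epsilon_s) + \sigma_{t|s} \epsilon_t \tag{25.13}\\
&= \frac{\alpha_t}{\alpha_s}\alpha_s x_0 + \underbrace{\alpha_{t|s}\sigma_s}_{\sigma_1} \epsilon_s + \underbrace{\sigma_{t|s}}_{\sigma_2} \epsilon_t \tag{25.14}\\
&= \alpha_t x_0 + \sqrt{\sigma_1^2 + \sigma_2^2}\, \epsilon \tag{25.15}
\end{align}
$$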


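where $\epsilon \sim \mathcal{N}(0, I)$, since the sum of the two independent noise terms satisfies $\mathcal{N}(0, \sigma_1^2 I) + \mathcal{N}(0, \sigma_2^2 I) = \mathcal{N}(0, (\sigma_1^2 + \sigma_2^2) I)$, and the variance of the combined noise is

$$
\sigma_1^2 + \sigma_2^2 = \alpha_{t|s}^2 \sigma_s^2 + \sigma_{t|s}^2 = \sigma_t^2 \tag{25.16}
$$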


which follows from the definitions in Equation (25.11). 


25.1.2 Decoder 
For a finite number T of layers, the decoder has the form 


$$
p_\theta(x_{0:T}) = p(x_T) \prod_{t=0}^{T-1} p_\theta(x_t \mid x_{t+1}) \tag{25.17}
$$

where $p(x_T) = \mathcal{N}(x_T \mid 0, I)$ and $p_\theta(x_t \mid x_{t+1})$ is specified below.
We would like the decoder to reverse the effect of the encoder, so we would like to set the likelihood to $p(x_0 \mid x_1) = q(x_0 \mid x_1)$. Unfortunately, reversing the forwards process exactly is intractable.
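However, for a small enough noise level $\sigma_1$ in the first latent layer, we have $p(x_0 \mid x_1) \propto q(x_1 \mid x_0)$, so we set $p(x_0 \mid x_1) = \prod_d p(x_{0,d} \mid x_{1,d})$, where

$$
p(x_{0,d} \mid x_{1,d}) = \frac{q(x_{1,d} \mid x_{0,d})}{\sum_{x'_{0,d}} q(x_{1,d} \mid x'_{0,d})}
$$

where we normalize over all possible values of $x_{0,d}$.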


For each step of the decoder, we would like $p_\theta(x_s \mid x_t)$ to be close to $q(x_s \mid x_t, x_0)$, since this will minimize the KL divergence in the ELBO (see Section 25.1.3). By the Markov properties and Bayes' rule, we have


$$
q(x_s \mid x_t, x_0) = \frac{q(x_t \mid x_s, x_0)\, q(x_s \mid x_0)}{q(x_t \mid x_0)} \tag{25.18}
$$


Since all the terms are conditionally Gaussian, we can use Bayes rule for linear Gaussian systems. 
Restating the results from Section 2.2.6.2, we have the following: 


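$$
\begin{align}
q(x_s \mid x_0) &= \mathcal{N}(x_s \mid \mu_s, \Sigma_s) = \mathcal{N}(x_s \mid \alpha_s x_0, \sigma_s^2 I) \tag{25.19}\\
q(x_t \mid x_s) &= \mathcal{N}(x_t \mid W x_s, \Sigma_t) = \mathcal{N}(x_t \mid \alpha_{t|s} x_s, \sigma_{t|s}^2 I) \tag{25.20}\\
q(x_s \mid x_t, x_0) &= \mathcal{N}(x_s \mid \tilde{\mu}, \tilde{\Sigma}) \tag{25.21}\\
\tilde{\Sigma}^{-1} &= \Sigma_s^{-1} + W^\top \Sigma_t^{-1} W = \frac{1}{\sigma_s^2} I + \frac{\alpha_{t|s}^2}{\sigma_{t|s}^2} I = \frac{1}{\sigma_{s|t}^2} I \tag{25.22}\\
\tilde{\mu} &= \tilde{\Sigma}\left[W^\top \Sigma_t^{-1} x_t + \Sigma_s^{-1} \mu_s\right] = \sigma_{s|t}^2\left[\frac{\alpha_{t|s}}{\sigma_{t|s}^2} x_t + \frac{\alpha_s}{\sigma_s^2} x_0\right] \tag{25.23}
\end{align}
$$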


We can simplify the above expressions as follows. First, for the posterior precision, we have 


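$$\frac{1}{\sigma_{s|t}^2} = \frac{1}{\sigma_s^2} + \frac{\alpha_{t|s}^2}{\sigma_{t|s}^2} = \frac{\sigma_{t|s}^2 + \alpha_{t|s}^2\sigma_s^2}{\sigma_s^2\sigma_{t|s}^2} = \frac{\sigma_t^2}{\sigma_s^2\sigma_{t|s}^2} \tag{25.24}$$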


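where we leveraged the fact that $\sigma_{t|s}^2 = \sigma_t^2 - \alpha_{t|s}^2\sigma_s^2$ from Equation (25.11). Hence the variance can be written as

$$\sigma_{s|t}^2 = \frac{\sigma_s^2\sigma_{t|s}^2}{\sigma_t^2} = \sigma_s^2\left(1 - \frac{\mathrm{SNR}(t)}{\mathrm{SNR}(s)}\right) \tag{25.25}$$

where $\mathrm{SNR}(t) = \alpha_t^2/\sigma_t^2$ is the signal-to-noise ratio. Similarly, the posterior mean simplifies to

$$\mu_{s|t}(x_t, x_0) = \frac{\alpha_{t|s}\sigma_s^2}{\sigma_t^2}x_t + \frac{\alpha_s\sigma_{t|s}^2}{\sigma_t^2}x_0 \tag{25.26}$$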
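The posterior of Equations (25.24) to (25.26) is easy to code directly. The sketch below (hypothetical schedule values, NumPy rather than the JAX used in the book's notebooks) computes it and cross-checks it against the generic linear-Gaussian formulas (25.22) and (25.23).

```python
import numpy as np

def posterior_q(x_t, x0, alpha_s, sigma_s, alpha_t, sigma_t):
    """Mean and variance of q(x_s | x_t, x_0), following Eqs. (25.24)-(25.26)."""
    alpha_ts = alpha_t / alpha_s
    var_ts = sigma_t**2 - alpha_ts**2 * sigma_s**2            # Eq. (25.11)
    var_post = sigma_s**2 * var_ts / sigma_t**2               # Eq. (25.25)
    mean_post = (alpha_ts * sigma_s**2 / sigma_t**2) * x_t \
        + (alpha_s * var_ts / sigma_t**2) * x0                # Eq. (25.26)
    return mean_post, var_post

# Hypothetical schedule values at times s < t.
alpha_s, alpha_t = 0.9, 0.6
sigma_s, sigma_t = np.sqrt(1 - alpha_s**2), np.sqrt(1 - alpha_t**2)
x0, x_t = np.array([1.0, -2.0]), np.array([0.3, 0.8])
mean_post, var_post = posterior_q(x_t, x0, alpha_s, sigma_s, alpha_t, sigma_t)

# Cross-check against the generic linear-Gaussian formulas (25.22)-(25.23).
alpha_ts = alpha_t / alpha_s
var_ts = sigma_t**2 - alpha_ts**2 * sigma_s**2
precision = 1 / sigma_s**2 + alpha_ts**2 / var_ts
mean_generic = (alpha_ts / var_ts * x_t + alpha_s / sigma_s**2 * x0) / precision
print(np.isclose(1 / precision, var_post), np.allclose(mean_generic, mean_post))
```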


Based on the above derivation, we choose the decoder $p_\theta(x_s|x_t)$ to have the same form as the encoder, except we replace the observed $x_0$ with the predicted value $\hat{x}_\theta(x_t; t)$, i.e.,

$$p_\theta(x_s|x_t) = \mathcal{N}(x_s \mid \hat{\mu}_{s|t}(x_t), \sigma_{s|t}^2 I) \tag{25.27}$$

where

$$\hat{\mu}_{s|t}(x_t) = \mu_{s|t}(x_t, x_0 = \hat{x}_\theta(x_t; t)) \tag{25.28}$$
25.1. VARIATIONAL DIFFUSION MODELS


Rather than predicting the clean image $x_0$ from the noisy version $x_t$, we can equivalently predict the noise $\epsilon$ that was added, defined by

$$\hat{\epsilon}_\theta(x_t; t) = (x_t - \alpha_t\hat{x}_\theta(x_t; t))/\sigma_t \tag{25.29}$$


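In this case, we can write the mean of the decoder as

$$\hat{\mu}_{s|t}(x_t) = \frac{1}{\alpha_{t|s}}\left(x_t - \frac{\sigma_{t|s}^2}{\sigma_t}\hat{\epsilon}_\theta(x_t; t)\right) \tag{25.30}$$

We can interpret this as predicting the noise, then subtracting it off from $x_t$, to generate the mean of the less noisy version $x_s$.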
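To confirm that the noise-prediction form (25.30) is only a reparameterization, the following sketch takes a hypothetical denoiser output `x0_hat`, converts it to `eps_hat` with Equation (25.29), and checks that both expressions for the decoder mean agree.

```python
import numpy as np

# Hypothetical schedule values at times s < t.
alpha_s, alpha_t = 0.9, 0.6
sigma_s, sigma_t = np.sqrt(1 - alpha_s**2), np.sqrt(1 - alpha_t**2)
alpha_ts = alpha_t / alpha_s
var_ts = sigma_t**2 - alpha_ts**2 * sigma_s**2

x_t = np.array([0.3, 0.8])
x0_hat = np.array([1.1, -1.7])   # hypothetical output of a denoiser x_hat_theta(x_t; t)

# Eq. (25.29): the noise prediction implied by x0_hat.
eps_hat = (x_t - alpha_t * x0_hat) / sigma_t

# Eq. (25.26) with x0 replaced by x0_hat, i.e. Eq. (25.28).
mean_x0_form = (alpha_ts * sigma_s**2 / sigma_t**2) * x_t + (alpha_s * var_ts / sigma_t**2) * x0_hat
# Eq. (25.30): the same mean written in terms of eps_hat.
mean_eps_form = (x_t - var_ts / sigma_t * eps_hat) / alpha_ts
print(np.allclose(mean_x0_form, mean_eps_form))
```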

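Finally, we can make a connection to score-based generative models (Section 24.3.5) by noting that the score function has the form

$$\begin{aligned}
\nabla_{x_t}\log q(x_t|x_0) &= \nabla_{x_t}\log\mathcal{N}(x_t|\alpha_t x_0, \sigma_t^2 I) = \nabla_{x_t}\left(-\frac{1}{2\sigma_t^2}\|x_t - \alpha_t x_0\|^2\right) && (25.31)\\
&= -\frac{1}{2\sigma_t^2}\nabla_{x_t}\left(x_t^\top x_t - 2\alpha_t x_t^\top x_0 + \alpha_t^2 x_0^\top x_0\right) = \frac{1}{\sigma_t^2}(\alpha_t x_0 - x_t) && (25.32)
\end{aligned}$$

Hence we can rewrite the predicted noise as $\hat{\epsilon}_\theta(x_t; t) = -\sigma_t s_\theta(x_t; t)$, where

$$s_\theta(x_t; t) = (\alpha_t\hat{x}_\theta(x_t; t) - x_t)/\sigma_t^2 \tag{25.33}$$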


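Thus we can sample from the model by moving in the direction of the score function and then adding noise:

$$x_s = \frac{1}{\alpha_{t|s}}\left(x_t + \sigma_{t|s}^2 s_\theta(x_t; t)\right) + \sigma_{s|t}\epsilon, \qquad \epsilon \sim \mathcal{N}(0, I) \tag{25.34}$$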


This is a form of ancestral sampling, and is similar to the Langevin sampling method used in 
score-based generative models (see Section 24.3.5.1). 
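A minimal sketch of ancestral sampling with Equation (25.34), under two loudly stated assumptions: a hypothetical schedule with $\sigma_t^2 = \mathrm{sigmoid}(\gamma(t))$, $\alpha_t^2 = 1 - \sigma_t^2$ and linear $\gamma(t)$, and a toy dataset consisting of a single point $c$, for which the ideal denoiser is simply $\hat{x}_\theta(x_t; t) = c$ (so no network is needed). The chain should therefore end up close to $c$.

```python
import numpy as np

def sigmoid(u):
    return 1.0 / (1.0 + np.exp(-u))

def alpha_sigma(t):
    """Hypothetical variance-preserving schedule: sigma_t^2 = sigmoid(gamma(t)), gamma linear in t."""
    gamma = -10.0 + 20.0 * t
    return np.sqrt(sigmoid(-gamma)), np.sqrt(sigmoid(gamma))

c = np.array([2.0, -1.0])  # toy "dataset": a point mass, so the ideal denoiser is x0_hat = c

def score(x_t, t):
    # Eq. (25.33) with the ideal denoiser x0_hat(x_t; t) = c.
    a_t, s_t = alpha_sigma(t)
    return (a_t * c - x_t) / s_t**2

def ancestral_sample(T=100, seed=0):
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(c.shape)               # x_1 ~ N(0, I), the prior
    for i in range(T, 0, -1):
        t, s = i / T, (i - 1) / T
        a_t, sig_t = alpha_sigma(t)
        a_s, sig_s = alpha_sigma(s)
        a_ts = a_t / a_s
        var_ts = sig_t**2 - a_ts**2 * sig_s**2     # Eq. (25.11)
        var_st = sig_s**2 * var_ts / sig_t**2      # Eq. (25.25)
        # Eq. (25.34): move along the score, then add fresh Gaussian noise.
        x = (x + var_ts * score(x, t)) / a_ts + np.sqrt(var_st) * rng.standard_normal(c.shape)
    return x

print(ancestral_sample())   # close to c = [2, -1]
```

With a learned network, `score` would instead be computed from $\hat{x}_\theta(x_t; t)$ via (25.33); the loop itself is unchanged.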


25.1.3 Model fitting 


To fit the model, we minimize the cross entropy between the empirical distribution $p_D$ and the model $p_\theta$, or equivalently we maximize the expected log marginal likelihood:

$$J = \int p_D(x_0)\log p_\theta(x_0)\, dx_0 \tag{25.35}$$


This expression is intractable, but we can derive a variational lower bound (VLB), also called the 
evidence lower bound (Section 10.1.2). 


25.1.3.1 Deriving the ELBO 


We can write the marginal probability of a single observation as follows: 


$$p_\theta(x_0) = \int dx_{1:T}\, p_\theta(x_{0:T}) = \int dx_{1:T}\, q(x_{1:T}|x_0)\frac{p_\theta(x_{0:T})}{q(x_{1:T}|x_0)} \tag{25.36}$$
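By Jensen's inequality we get the following evidence lower bound on the log marginal likelihood, which we want to maximize:

$$\begin{aligned}
\log p_\theta(x_0) &= \log\int dx_{1:T}\, q(x_{1:T}|x_0)\, p(x_T)\prod_{t=1}^T\frac{p_\theta(x_{t-1}|x_t)}{q(x_t|x_{t-1})} && (25.37)\\
&\ge \int dx_{1:T}\, q(x_{1:T}|x_0)\log\left(p(x_T)\prod_{t=1}^T\frac{p_\theta(x_{t-1}|x_t)}{q(x_t|x_{t-1})}\right) && (25.38)\\
&= \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log p(x_T) + \sum_{t=1}^T\log\frac{p_\theta(x_{t-1}|x_t)}{q(x_t|x_{t-1})}\right] \triangleq \mathcal{L}(x_0) && (25.39)
\end{aligned}$$

We now discuss how to compute the terms in the ELBO. By the Markov property we have $q(x_t|x_{t-1}) = q(x_t|x_{t-1}, x_0)$, and by Bayes rule, we have

$$q(x_t|x_{t-1}, x_0) = \frac{q(x_{t-1}|x_t, x_0)\, q(x_t|x_0)}{q(x_{t-1}|x_0)} \tag{25.40}$$

Plugging Equation (25.40) into the ELBO we get

$$\begin{aligned}
\mathcal{L}(x_0) &= \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log p(x_T) + \sum_{t=2}^T\log\frac{p_\theta(x_{t-1}|x_t)}{q(x_t|x_{t-1}, x_0)} + \log\frac{p_\theta(x_0|x_1)}{q(x_1|x_0)}\right] && (25.41)\\
&= \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log p(x_T) + \sum_{t=2}^T\log\frac{p_\theta(x_{t-1}|x_t)}{q(x_{t-1}|x_t, x_0)} + \underbrace{\sum_{t=2}^T\log\frac{q(x_{t-1}|x_0)}{q(x_t|x_0)}}_{*} + \log\frac{p_\theta(x_0|x_1)}{q(x_1|x_0)}\right] && (25.42)
\end{aligned}$$

The term marked $*$ is a telescoping sum, and can be simplified as follows:

$$\begin{aligned}
* &= \log q(x_{T-1}|x_0) + \log q(x_{T-2}|x_0) + \cdots + \log q(x_1|x_0) && (25.43)\\
&\quad - \log q(x_T|x_0) - \log q(x_{T-1}|x_0) - \cdots - \log q(x_2|x_0) && (25.44)\\
&= \log q(x_1|x_0) - \log q(x_T|x_0) && (25.45)
\end{aligned}$$

Hence we want to minimize

$$\begin{aligned}
-\mathcal{L}(x_0) &= -\mathbb{E}_{q(x_{1:T}|x_0)}\left[\log\frac{p(x_T)}{q(x_T|x_0)} + \sum_{t=2}^T\log\frac{p_\theta(x_{t-1}|x_t)}{q(x_{t-1}|x_t, x_0)} + \log p_\theta(x_0|x_1)\right] && (25.46)\\
&= \underbrace{D_{\mathbb{KL}}(q(x_T|x_0)\,\|\,p(x_T))}_{\text{prior loss}}\;\underbrace{-\,\mathbb{E}_{q(x_1|x_0)}\left[\log p_\theta(x_0|x_1)\right]}_{\text{reconstruction loss}} + \underbrace{\mathcal{L}_D(x_0)}_{\text{diffusion loss}} && (25.47)
\end{aligned}$$

where the diffusion loss is given by

$$\mathcal{L}_D(x_0) = \sum_{t=2}^T\underbrace{\mathbb{E}_{q(x_t|x_0)}D_{\mathbb{KL}}(q(x_{t-1}|x_t, x_0)\,\|\,p_\theta(x_{t-1}|x_t))}_{\mathrm{KL}_t} \tag{25.48}$$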


The prior loss term is a constant, since q has no free parameters. The diffusion loss can be optimized 
using the reparameterization trick. We discuss how to compute the KL terms in the diffusion loss 
below. 


25.1.3.2 Deriving the KL terms 


Recall that $p_\theta(x_s|x_t)$ and $q(x_s|x_t, x_0)$ are both Gaussians, and the KL between two Gaussians has a closed form. In particular, from Equation (5.80), we have


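$$D_{\mathbb{KL}}(\mathcal{N}(x|\mu_1, \Sigma_1)\,\|\,\mathcal{N}(x|\mu_2, \Sigma_2)) = \frac{1}{2}\left[\mathrm{tr}(\Sigma_2^{-1}\Sigma_1) + (\mu_2 - \mu_1)^\top\Sigma_2^{-1}(\mu_2 - \mu_1) - D + \log\frac{\det\Sigma_2}{\det\Sigma_1}\right] \tag{25.49}$$

where $D$ is the dimensionality of $x$.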


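Since the two Gaussians have the same covariance $\sigma_{s|t}^2 I$, the KL simplifies to

$$\begin{aligned}
D_{\mathbb{KL}}(q(x_s|x_t, x_0)\,\|\,p_\theta(x_s|x_t)) &= \frac{1}{2\sigma_{s|t}^2}\|\mu_{s|t}(x_t, x_0) - \hat{\mu}_{s|t}(x_t)\|^2 \\
&= \frac{1}{2}\left(\mathrm{SNR}(s) - \mathrm{SNR}(t)\right)\|x_0 - \hat{x}_\theta(x_t; t)\|^2 && (25.50)
\end{aligned}$$

where we used Equations (25.25) and (25.26), together with $\frac{\alpha_s^2\sigma_{t|s}^2}{\sigma_s^2\sigma_t^2} = \mathrm{SNR}(s) - \mathrm{SNR}(t)$.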
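The simplification from (25.49) to (25.50) can be checked numerically: the sketch below evaluates the general Gaussian KL with both distributions sharing the covariance $\sigma_{s|t}^2 I$ and compares it with the SNR-weighted squared error. The schedule values and the denoiser output are hypothetical.

```python
import numpy as np

def kl_gauss(mu1, cov1, mu2, cov2):
    """D_KL(N(mu1, cov1) || N(mu2, cov2)), Eq. (25.49)."""
    d = mu1.shape[0]
    cov2_inv = np.linalg.inv(cov2)
    diff = mu2 - mu1
    _, logdet1 = np.linalg.slogdet(cov1)
    _, logdet2 = np.linalg.slogdet(cov2)
    return 0.5 * (np.trace(cov2_inv @ cov1) + diff @ cov2_inv @ diff - d + logdet2 - logdet1)

# Hypothetical schedule values at times s < t.
alpha_s, alpha_t = 0.9, 0.6
sigma_s, sigma_t = np.sqrt(1 - alpha_s**2), np.sqrt(1 - alpha_t**2)
alpha_ts = alpha_t / alpha_s
var_ts = sigma_t**2 - alpha_ts**2 * sigma_s**2
var_st = sigma_s**2 * var_ts / sigma_t**2

def post_mean(x_t, x0):
    # Eq. (25.26); passing x0_hat instead of x0 gives the decoder mean (25.28).
    return (alpha_ts * sigma_s**2 / sigma_t**2) * x_t + (alpha_s * var_ts / sigma_t**2) * x0

x_t = np.array([0.3, 0.8, -0.5])
x0 = np.array([1.0, -2.0, 0.5])
x0_hat = np.array([0.7, -1.5, 0.9])  # hypothetical denoiser output
cov = var_st * np.eye(3)             # both Gaussians share this covariance

kl_full = kl_gauss(post_mean(x_t, x0), cov, post_mean(x_t, x0_hat), cov)
snr = lambda a, s: a**2 / s**2
kl_simple = 0.5 * (snr(alpha_s, sigma_s) - snr(alpha_t, sigma_t)) * np.sum((x0 - x0_hat) ** 2)
print(np.isclose(kl_full, kl_simple))
```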


25.1.3.3 Weighting terms 


It is possible to train the model using a loss function that is different from the negative log likelihood. For example, we may want to put more weight on larger noise levels, to discourage the model from wasting capacity modeling low-level details. We can do this by adding a weighting factor $w(t)$ in front of each KL term in the diffusion loss. In particular, suppose we rewrite the KL term in Equation (25.50) as

$$\mathrm{KL}_t = w(t)\|x_0 - \hat{x}_\theta(x_t; t)\|^2 \tag{25.51}$$

Now suppose we have $\alpha_t = \sigma_t = 1$. Then Equation (25.29) simplifies to

$$\hat{\epsilon}_\theta(x_t; t) = x_t - \hat{x}_\theta(x_t; t) \tag{25.52}$$

where $x_t = x_0 + \epsilon_t$, with $\epsilon_t \sim \mathcal{N}(0, I)$. Hence

$$\mathrm{KL}_t = w(t)\|x_0 - x_t + \hat{\epsilon}_\theta(x_t; t)\|^2 = w(t)\|\epsilon_t - \hat{\epsilon}_\theta(x_t; t)\|^2 \tag{25.53}$$

This is a very simple objective to optimize, and often results in better perceptual quality (see e.g., [HJA20]).


25.1.3.4 Stochastic estimator for deep models 


For very deep models, it is expensive to compute all $T$ diffusion terms. Fortunately we can compute an unbiased approximation by simply sampling a layer $t \sim \mathrm{Unif}(1, T)$ and a noise term $\epsilon_t \sim \mathcal{N}(0, I)$, and then computing

$$\mathcal{L}_D(x_0) = \frac{T}{2}\mathbb{E}_{\epsilon_t\sim\mathcal{N}(0,I),\, t\sim\mathrm{Unif}(1,T)}\left[(\mathrm{SNR}(t-1) - \mathrm{SNR}(t))\|x_0 - \hat{x}_\theta(x_t; t)\|^2\right] \tag{25.54}$$
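The sketch below illustrates that the single-layer estimator (25.54) is unbiased. To keep the exact sum computable, it replaces the network's squared error $\|x_0 - \hat{x}_\theta(x_t; t)\|^2$ by a hypothetical deterministic function of $t$, and uses a hypothetical SNR schedule $\mathrm{SNR}(t) = e^{5 - 10t}$ on the grid $t = i/T$.

```python
import numpy as np

T = 100

def snr(t):
    """Hypothetical SNR schedule: SNR(t) = exp(5 - 10 t), decreasing in t."""
    return np.exp(5.0 - 10.0 * t)

def sq_err(t):
    """Hypothetical stand-in for ||x0 - x0_hat(x_t; t)||^2 at time t."""
    return 0.5 + t

i_all = np.arange(1, T + 1)
t_all, s_all = i_all / T, (i_all - 1) / T
# Exact diffusion loss: the sum of all T terms 0.5 (SNR(s) - SNR(t)) * error.
exact = np.sum(0.5 * (snr(s_all) - snr(t_all)) * sq_err(t_all))

# Eq. (25.54): sample one layer per draw and rescale by T.
rng = np.random.default_rng(0)
i = rng.integers(1, T + 1, size=200_000)   # uniform over {1, ..., T}
t, s = i / T, (i - 1) / T
estimates = 0.5 * T * (snr(s) - snr(t)) * sq_err(t)
print(exact, estimates.mean())             # the average of the estimates approaches the exact sum
```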


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 
25.1.3.5 Infinite depth $T \to \infty$


One can show that, for a fixed SNR function, using more layers leads to a lower diffusion loss. This suggests considering an infinitely deep model. To model this, we assume $t \in [0, 1]$, and we denote the latent variables by $z_t$, where $z_0$ is the least noisy version of the data, and $z_1$ is the most noisy. For a finite number of layers, we discretize time uniformly into $T$ timesteps of width $\tau = 1/T$. Defining $s(i) = (i-1)/T$ and $t(i) = i/T$, the overall encoder can then be represented as follows:


$$q(z_{0:1}|x_0) = q(z_0|x_0)\prod_{i=1}^T q(z_{t(i)}|z_{s(i)}) \tag{25.55}$$

and the decoder becomes

$$p_\theta(x_0, z_{0:1}) = p(z_1)\prod_{i=1}^T p_\theta(z_{s(i)}|z_{t(i)})\; p_\theta(x_0|z_0) \tag{25.56}$$
i=1 


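We can calculate the limit of the diffusion loss as $T \to \infty$ as follows. Let $\tau = 1/T$ be the step size. Then we have

$$\mathcal{L}_D(x_0) = \frac{1}{2}\mathbb{E}_{\epsilon\sim\mathcal{N}(0,I),\, i\sim\mathrm{Unif}(1,T)}\left[\frac{\mathrm{SNR}(t - \tau) - \mathrm{SNR}(t)}{\tau}\|x_0 - \hat{x}_\theta(z_t; t)\|_2^2\right] \tag{25.57}$$

where $t = t(i)$. As $\tau \to 0$, we have $T \to \infty$, and the loss becomes

$$\begin{aligned}
\mathcal{L}_D(x_0) &= -\frac{1}{2}\mathbb{E}_{\epsilon\sim\mathcal{N}(0,I),\, t\sim\mathrm{Unif}(0,1)}\left[\mathrm{SNR}'(t)\|x_0 - \hat{x}_\theta(z_t; t)\|_2^2\right] && (25.58)\\
&= -\frac{1}{2}\mathbb{E}_{\epsilon\sim\mathcal{N}(0,I)}\int_0^1\mathrm{SNR}'(t)\|x_0 - \hat{x}_\theta(z_t; t)\|_2^2\, dt && (25.59)
\end{aligned}$$

where $\mathrm{SNR}'(t)$ is the derivative of the SNR function.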
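Using a hypothetical SNR schedule $\mathrm{SNR}(t) = e^{5 - 10t}$ and a hypothetical stand-in for the squared error that grows with $t$, the following sketch compares the discrete diffusion loss for increasing $T$ with the continuous-time integral (25.59). In this toy setting, refining the time grid lowers the loss toward the integral, consistent with the claim that more layers help.

```python
import numpy as np

def snr(t):
    return np.exp(5.0 - 10.0 * t)        # hypothetical SNR schedule

def snr_prime(t):
    return -10.0 * snr(t)                # its derivative

def sq_err(t):
    return 0.5 + t                       # hypothetical per-time squared error, increasing in t

def discrete_loss(T):
    """Diffusion loss with T uniform layers: sum_i 0.5 (SNR(s) - SNR(t)) * err(t)."""
    t = np.arange(1, T + 1) / T
    s = t - 1.0 / T
    return np.sum(0.5 * (snr(s) - snr(t)) * sq_err(t))

# Eq. (25.59): -0.5 * int_0^1 SNR'(t) err(t) dt, evaluated with a fine trapezoidal rule.
grid = np.linspace(0.0, 1.0, 200_001)
vals = -0.5 * snr_prime(grid) * sq_err(grid)
continuous = np.sum(0.5 * (vals[1:] + vals[:-1]) * np.diff(grid))

for T in [10, 100, 1000]:
    print(T, discrete_loss(T), continuous)
```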


25.1.4 Connection to DDPM

The denoising diffusion probabilistic model of [HJA20] is a precursor to VDM that uses a slightly different parameterization. In particular, it assumes


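$$q(x_t|x_{t-1}) = \mathcal{N}(x_t \mid \sqrt{1 - \beta_t}\,x_{t-1}, \beta_t I) \tag{25.60}$$

From this one can show² that

$$q(x_t|x_0) = \mathcal{N}(x_t \mid \sqrt{\bar{\alpha}_t}\,x_0, (1 - \bar{\alpha}_t) I) \tag{25.61}$$

where $\alpha_t = 1 - \beta_t$ and $\bar{\alpha}_t = \prod_{s=1}^t \alpha_s$. The decoder is chosen to have the following form:

$$p_\theta(x_{t-1}|x_t) = \mathcal{N}(x_{t-1} \mid \mu_\theta(x_t, t), \Sigma_\theta(x_t, t)) \tag{25.62}$$

where

$$\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}}\epsilon_\theta(x_t, t)\right) \tag{25.63}$$

In practice, the VDM parameterization is more stable and easier to train.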
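For the DDPM parameterization, the closed form (25.61) can be verified by iterating (25.60). The sketch below uses the linear $\beta_t$ schedule from $10^{-4}$ to $0.02$ over $T = 1000$ steps used in [HJA20], and also defines the decoder mean (25.63) as a function of a supplied noise prediction (here a hypothetical argument `eps_pred`).

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)     # linear schedule, as in [HJA20]
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)        # alpha_bar_t = prod_{s <= t} alpha_s

# Propagate the mean and variance of x_t | x_0 step by step using Eq. (25.60) ...
x0 = 1.0
mean, var = x0, 0.0
for beta in betas:
    mean = np.sqrt(1.0 - beta) * mean
    var = (1.0 - beta) * var + beta
# ... and compare with the closed form of Eq. (25.61) at t = T.
print(mean, np.sqrt(alpha_bars[-1]) * x0)
print(var, 1.0 - alpha_bars[-1])

def ddpm_mean(x_t, eps_pred, t):
    """Decoder mean of Eq. (25.63); t is a 0-based index into the schedule arrays."""
    return (x_t - betas[t] / np.sqrt(1.0 - alpha_bars[t]) * eps_pred) / np.sqrt(alphas[t])
```

If `eps_pred` equals the true noise, `ddpm_mean` reduces to the mean of the Gaussian posterior $q(x_{t-1}|x_t, x_0)$, which is what the test below checks.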


2. See e.g., https://lilianweng.github.io/posts/2021-07-11-diffusion-models/


[Figure 25.2 panels: rows labelled Encoder q, Decoder p, and residual ε; columns are time slices, including t = 0.5 and t = 1.]


Figure 25.2: Continuous time VDM diffusion model on 2d swiss roll. Top row: samples from the encoder $q$ at 3 different time slices. Middle row: samples from the generator $p_\theta$ at 3 different time slices. Bottom row: Visualization of the predicted error term $\hat{\epsilon}_\theta(x_t; t) = x_t - \hat{x}_\theta(x_t; t)$. We see that this resembles the score function, in that it points towards regions of high data density. Adapted from Figure 1 of [SD+15b]. Generated by vdm_2d.ipynb. Used with kind permission of Durk Kingma.


25.1.5 2d Example 


In Figure 25.2, we show an example of a VDM model fit to 1024 samples from the 2d Swiss roll 
dataset. The model uses 3 dense layers with swish activations, and is fit to a discretized version of 
the data (using 256 values for each point) with a softmax likelihood. Adding the Fourier features 
described in [Kin+21] further improves the results. The model uses T = 200 time steps, and takes 12 
minutes to train for 20,000 SGD steps on a TPU. 


25.1.6 Image generation 


Diffusion models are often used to generate images. We show a simple example in Figure 25.3. By 
training big models (billions of parameters) for a long time (days) on lots of data (millions of images), 
diffusion models can be made to generate very high quality images (see e.g., Figure 20.3). The 
most common architecture is based on the U-net model [RFB15], originally proposed for semantic 
segmentation. This is augmented with attention layers, and is then trained to predict a clean version 
of the image from a noisy version. Recently, the most impressive results in image generation have come
from conditional diffusion models, where guidance is provided about what kinds of images to generate. 
We discuss such models in Section 25.2. 


Figure 25.3: Some sample images generated by a small diffusion model trained on MNIST for about 1 hour on 1 GPU. Generated by diffusion_mnist.ipynb. Used with kind permission of Winnie Xu.


25.2 Conditional diffusion models 


In this section, we discuss how to generate samples from a diffusion model where we condition on some side information $c$, such as a class label or text prompt. Our presentation is based in part on [Die22].

25.2.1 Classifier guidance


The simplest way to do this is to train the diffusion model to maximize the conditional likelihood $p(x|c)$. We can do this by making all the encoder and decoder terms depend on $c$ as an auxiliary input. Unfortunately, this gives us limited control over the generation process.

An alternative approach, known as classifier guidance, was proposed in [DN21]. This leverages Bayes rule to convert the conditional distribution $p(x|c)$ into an unconditional term $p(x)$ and a likelihood $p(c|x)$. Hence

$$\log p(x|c) = \log p(c|x) + \log p(x) - \log p(c) \tag{25.64}$$

so, since $\log p(c)$ does not depend on $x$, the score function becomes

$$\nabla_x\log p(x|c) = \nabla_x\log p(x) + \nabla_x\log p(c|x) \tag{25.65}$$


If we have a pre-trained image classifier, $p(c|x)$, then we can use it to convert an unconditional generative model into a conditional one. We can further amplify the influence of the conditioning
25.2. CONDITIONAL DIFFUSION MODELS


signal by scaling it by a factor $w > 1$:

$$\nabla_x\log p_w(x|c) = \nabla_x\log p(x) + w\nabla_x\log p(c|x) \tag{25.66}$$

In [DN21], they found that setting $w \approx 10$ gave samples with much higher perceptual quality than using $w = 1$.
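A one-dimensional toy makes Equation (25.66) concrete. Everything here is hypothetical: the unconditional model is $p(x) = \mathcal{N}(0, 1)$ and the "classifier" is a fixed logistic model $p(c = 1|x) = \mathrm{sigmoid}(ax + b)$. The closed-form guided score is checked against a finite-difference gradient of $\log p(x) + w\log p(c|x)$.

```python
import numpy as np

a, b = 3.0, -1.0   # hypothetical logistic classifier p(c = 1 | x) = sigmoid(a x + b)

def log_prior(x):
    """Unconditional model p(x) = N(0, 1), up to an additive constant."""
    return -0.5 * x**2

def log_classifier(x):
    """log p(c = 1 | x) = -log(1 + exp(-(a x + b))), computed stably."""
    return -np.logaddexp(0.0, -(a * x + b))

def guided_score(x, w):
    """Eq. (25.66): grad log p(x) + w * grad log p(c = 1 | x)."""
    sig = 1.0 / (1.0 + np.exp(-(a * x + b)))
    return -x + w * a * (1.0 - sig)

x, w, h = np.linspace(-2, 2, 9), 10.0, 1e-5
fd = ((log_prior(x + h) + w * log_classifier(x + h))
      - (log_prior(x - h) + w * log_classifier(x - h))) / (2 * h)
print(np.max(np.abs(fd - guided_score(x, w))))   # small: the two gradients agree
```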


25.2.2 Classifier-free guidance 


Unfortunately, $p(c|x)$ is a discriminative model that may ignore many details of the input $x$. Hence optimizing along the directions specified by $\nabla_x\log p(c|x)$ can give poor results. In [HS21], they proposed a technique called classifier-free guidance. Here the idea is to derive the classifier from the conditional generative model. That is, we use Bayes rule to compute $p(c|x) = p(x|c)p(c)/p(x)$, and hence

$$\log p(c|x) = \log p(x|c) + \log p(c) - \log p(x) \tag{25.67}$$

so the score of the induced classifier becomes

$$\nabla_x\log p(c|x) = \nabla_x\log p(x|c) - \nabla_x\log p(x) \tag{25.68}$$

Plugging this into the scaled guidance signal of Equation (25.66) gives

$$\begin{aligned}
\nabla_x\log p_w(x|c) &= \nabla_x\log p(x) + w\left(\nabla_x\log p(x|c) - \nabla_x\log p(x)\right) && (25.69)\\
&= (1 - w)\nabla_x\log p(x) + w\nabla_x\log p(x|c) && (25.70)
\end{aligned}$$


For w = 0 we recover the unconditional model, and for w = 1 we recover the standard conditional 
model. However, the best results are obtained when w > 1. 
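For Gaussians, Equation (25.70) can be worked out in closed form. In the hypothetical setting below, $p(x|c) = \mathcal{N}(\mu_c, 1)$ and $p(x) = \mathcal{N}(0, s^2)$ with $s^2 > 1$; the guided score is linear in $x$, and its zero (the mode of the implied guided density) moves beyond $\mu_c$ when $w > 1$, illustrating how guidance exaggerates the conditioning signal.

```python
import numpy as np

mu_c, s2 = 1.0, 4.0   # hypothetical: p(x | c) = N(mu_c, 1) and p(x) = N(0, s2)

def score_uncond(x):
    return -x / s2

def score_cond(x):
    return -(x - mu_c)

def cfg_score(x, w):
    """Eq. (25.70): (1 - w) * unconditional score + w * conditional score."""
    return (1.0 - w) * score_uncond(x) + w * score_cond(x)

# The guided score is linear in x, so it is the score of a Gaussian with
# precision (1 - w)/s2 + w and mean w * mu_c / precision.
w = 3.0
precision = (1.0 - w) / s2 + w
mode = w * mu_c / precision
print(mode)   # 1.2: guidance pushes samples past mu_c = 1
```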

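In practice, we can implement this method with a small change to the ancestral sampling routine. At each step $t$, we predict the guided error term, which has the form

$$\tilde{\epsilon}_t = w\,\epsilon_\theta(z_t, t, c) + (1 - w)\,\epsilon_\theta(z_t, t) \tag{25.71}$$

We then compute the sample

$$z_s = \hat{\mu}_{s|t}(z_t) + \sqrt{(\sigma_{s|t}^2)^{1-\gamma}(\sigma_{t|s}^2)^{\gamma}}\;\epsilon, \qquad \epsilon \sim \mathcal{N}(0, I) \tag{25.72}$$

where $\hat{\mu}_{s|t}(z_t)$ is defined in terms of $\tilde{\epsilon}_t$ as in Equation (25.30), and $\gamma$ controls the stochasticity of the sampler [ND21].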
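The guided sampling step of Equations (25.71) and (25.72) can be sketched as follows. The arrays `eps_cond` and `eps_uncond` stand in for the two network evaluations, the schedule values are hypothetical, and the noise variance uses the geometric interpolation $(\sigma_{s|t}^2)^{1-\gamma}(\sigma_{t|s}^2)^{\gamma}$, which is our reading of (25.72) and should be treated as an assumption.

```python
import numpy as np

def guided_step(z_t, eps_cond, eps_uncond, w, a_s, sig_s, a_t, sig_t, gamma, noise):
    """One guided ancestral step, a sketch of Eqs. (25.71)-(25.72).

    eps_cond / eps_uncond stand in for the network outputs eps_theta(z_t, t, c)
    and eps_theta(z_t, t); `noise` is a standard normal draw supplied by the caller.
    """
    eps_tilde = w * eps_cond + (1.0 - w) * eps_uncond        # Eq. (25.71)
    a_ts = a_t / a_s
    var_ts = sig_t**2 - a_ts**2 * sig_s**2                    # Eq. (25.11)
    var_st = sig_s**2 * var_ts / sig_t**2                     # Eq. (25.25)
    mean = (z_t - var_ts / sig_t * eps_tilde) / a_ts          # Eq. (25.30) with eps_tilde
    # Assumed variance: geometric interpolation between sigma_{s|t}^2 and sigma_{t|s}^2.
    var = var_st ** (1.0 - gamma) * var_ts ** gamma
    return mean + np.sqrt(var) * noise

a_s, sig_s = 0.9, np.sqrt(1.0 - 0.9**2)   # hypothetical schedule values at time s
a_t, sig_t = 0.6, np.sqrt(1.0 - 0.6**2)   # ... and at time t > s
z_t = np.array([0.3, -0.4])
eps_c, eps_u = np.array([0.5, 1.0]), np.array([-0.2, 0.3])   # stand-in network outputs
rng = np.random.default_rng(0)
z_s = guided_step(z_t, eps_c, eps_u, w=3.0, a_s=a_s, sig_s=sig_s, a_t=a_t, sig_t=sig_t,
                  gamma=0.3, noise=rng.standard_normal(2))
print(z_s)
```

Because the mean is an affine function of $\tilde{\epsilon}_t$, guidance could equivalently be applied to the two decoder means instead of the two noise predictions.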

Instead of training a conditional and an unconditional model, we can just train a single conditional model, provided we set $c = \emptyset$ to emulate unconditional sampling (in practice, by randomly replacing the conditioning information with $\emptyset$ for a fraction of the training examples). The advantage of this approach is that we have a single, self-consistent model, where the guidance signal is derived from a generative model, for which gradients in input space are more meaningful. This classifier-free guidance method is used in several SOTA conditional image generation models, as we discuss in Section 25.2.3.
25.2.3 Conditional image generation


Recently several large text-conditioned image diffusion models have been created, which use the classifier-free guidance trick to get very impressive results. Examples include OpenAI's DALL-E 2 model [Ram+22] and Google's Imagen model [Sah+22]. The latter uses a large pre-trained text encoder, based on T5-XXL [Raf+20a], combined with a VDM model (Section 25.1) based on the U-net architecture. The base model generates images of size 64 × 64. These are then upsampled to 256 × 256 and then 1024 × 1024 by two separately trained image-to-image diffusion models [Ho+21]. Some samples from the Imagen model are shown in Figure 20.3.

In addition to conditioning on text, it is possible to condition on another image to create an 
image-to-image mapping model. For example, we can map a gray-scale image c to a color image x, 
or a corrupted or occluded image c to a clean version x. This can be done by training a multi-task 
conditional diffusion model, as explained in [Sah+21]. See Figure 25.4 for some sample outputs. 


25.2.4 Other forms of conditional generation 


In addition to generating images, diffusion models can be used to generate audio sequences. Some
recent examples of this are DiffWave [Kon+21] and WaveGrad [Che+21].

They can also be used to generate video sequences. Some recent examples of this are [Ho+22; 
VJMP22; Har+22]. 


25.3 Speeding up the generation process 


DDMs are relatively fast to train, since we can just sample a random layer and optimize a squared loss on the denoising prediction at that layer, as shown in Section 25.1.3.4. However, DDMs are slow to use for generation, since they require many layers $T$ to get good results.


Several different strategies have been proposed to speed up generation from DDMs. In [SME21], they propose a method called DDIM (Denoising Diffusion Implicit Models) which "turns off" the sampling process at all layers of the model except for the initial prior, $x_T$. This makes the sampling process deterministic, and allows the use of fewer steps when performing generation.
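A minimal sketch of a deterministic DDIM-style update, written in the notation of Section 25.1: predict $\hat{x}$, recover $\hat{\epsilon}$ via (25.29), and set $x_s = \alpha_s\hat{x} + \sigma_s\hat{\epsilon}$ with no fresh noise (the $\eta = 0$ case of [SME21]). The schedule and the point-mass toy data (for which the ideal denoiser simply returns $c$) are hypothetical; with that ideal denoiser, four steps suffice to land near $c$.

```python
import numpy as np

def sigmoid(u):
    return 1.0 / (1.0 + np.exp(-u))

def alpha_sigma(t):
    gamma = -10.0 + 20.0 * t        # hypothetical variance-preserving schedule
    return np.sqrt(sigmoid(-gamma)), np.sqrt(sigmoid(gamma))

c = np.array([2.0, -1.0])           # toy point-mass data; the ideal denoiser returns c

def ddim_sample(x1, num_steps):
    x = x1
    times = np.linspace(1.0, 0.0, num_steps + 1)
    for t, s in zip(times[:-1], times[1:]):
        a_t, sig_t = alpha_sigma(t)
        a_s, sig_s = alpha_sigma(s)
        x0_hat = c                                  # ideal denoiser for this toy data
        eps_hat = (x - a_t * x0_hat) / sig_t        # Eq. (25.29)
        x = a_s * x0_hat + sig_s * eps_hat          # deterministic (eta = 0) update
    return x

x1 = np.random.default_rng(0).standard_normal(2)    # the only source of randomness
print(ddim_sample(x1, num_steps=4))
```

Because no noise is injected after $x_1$, the same starting point always maps to the same output, which is what makes step-skipping (and later distillation) possible.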


In [SH22], they show how to distill a deterministic diffusion sampler into one that takes half as many steps. This process can be repeated recursively. They are able to take a SOTA model that takes 8192 steps and distill it down into a model that just takes 4 steps, while maintaining high quality.


Another approach is proposed in [DB+21], based on the concept of a Schrödinger bridge. The goal is to find the joint distribution over paths $\pi(x_{0:T})$ that is as close as possible to the forward diffusion process $q(x_{0:T})$, but which satisfies the marginal constraints $\pi_0 = p_D$ and $\pi_T = p_{\mathrm{ref}}$, where $p_{\mathrm{ref}}$ is the target or reference distribution. This can be solved using the iterative proportional fitting algorithm, which alternates between matching each of the two marginal constraints while optimizing $\pi$. Standard methods for fitting diffusion models correspond to a single step of this algorithm, but by performing more iterations, it is possible to fit much shallower models.
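Iterative proportional fitting is easiest to see in its classical discrete form, sketched below: a positive joint table (a hypothetical stand-in for the reference process) is alternately rescaled to match each of two prescribed marginals. This is only the discrete analogue; [DB+21] apply the same alternation to path measures using learned diffusion models, which this sketch does not attempt.

```python
import numpy as np

def ipf(joint, row_marginal, col_marginal, num_iters=200):
    """Iterative proportional fitting: alternately rescale a positive joint table so its
    row sums match row_marginal and its column sums match col_marginal."""
    pi = joint.copy()
    for _ in range(num_iters):
        pi *= (row_marginal / pi.sum(axis=1))[:, None]   # match the first marginal
        pi *= (col_marginal / pi.sum(axis=0))[None, :]   # match the second marginal
    return pi

rng = np.random.default_rng(0)
reference = rng.random((4, 5)) + 0.1          # hypothetical reference joint (positive entries)
reference /= reference.sum()
p_start = np.array([0.1, 0.2, 0.3, 0.4])      # hypothetical "data" marginal
p_end = np.full(5, 0.2)                       # hypothetical "prior" marginal
pi = ipf(reference, p_start, p_end)
print(pi.sum(axis=1), pi.sum(axis=0))
```

Each half-step is the KL projection of the current table onto one marginal constraint, which is why the limit is the table closest in KL to the reference among those with both marginals.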


Another approach to speeding up diffusions is to replace the Gaussian noise process with a more expressive nonlinear mapping, implemented by a GAN. This idea is explored in the denoising diffusion GAN paper [XKV22], where they are able to get speedups of a factor of 2000.
25.3. SPEEDING UP THE GENERATION PROCESS


[Figure 25.4 panels: Colorization, Inpainting, Uncropping, JPEG restoration.]


Figure 25.4: Illustration of some image-to-image tasks using the Palette conditional diffusion model. From 
Figure 1 of [Sah+21]. Used with kind permission of Chitwan Saharia. 
26 Generative adversarial networks


This chapter is written by Mihaela Rosca, Shakir Mohamed and Balaji Lakshminarayanan. 


26.1 Introduction 


In this chapter, we focus on implicit generative models, which are a kind of probabilistic model without an explicit likelihood function [ML16]. This includes the family of Generative Adversarial Networks or GANs [Goo16]. In this chapter, we provide an introduction to this topic, focusing on a probabilistic perspective.

To develop a probabilistic formulation for GANs, it is useful to first distinguish between two types of probabilistic models: “prescribed probabilistic models” and “implicit probabilistic models” [DG84]. Prescribed probabilistic models, which we will call explicit probabilistic models, provide an explicit parametric specification of the distribution of an observed random variable $x$, specifying a log-likelihood function $\log q_\theta(x)$ with parameters $\theta$. Most models we encountered in this book thus far are of this form, whether they be state-of-the-art classifiers, large-vocabulary sequence models, or fine-grained spatio-temporal models. Alternatively, we can specify an implicit probabilistic model that defines a stochastic procedure to directly generate data. Such models are the natural approach for problems in climate and weather, population genetics, and ecology, since the mechanistic understanding of such systems can be used to directly describe the generative model. We illustrate the difference between implicit and explicit models in Figure 26.1.

The form of implicit generative models we focus on in this chapter can be expressed as a probabilistic latent variable model, similar to VAEs (Chapter 21). Implicit generative models use a latent variable $z$ and transform it using a deterministic function $G_\theta$ that maps from $\mathbb{R}^m \to \mathbb{R}^d$ using parameters $\theta$. Implicit generative models do not include a likelihood function or observation model. Instead, the generating procedure defines a valid density on the output space that forms an effective likelihood function:

$$x = G_\theta(z'); \qquad z' \sim q(z) \tag{26.1}$$

$$q_\theta(x) = \frac{\partial}{\partial x_1}\cdots\frac{\partial}{\partial x_d}\int_{\{G_\theta(z) \le x\}} q(z)\, dz \tag{26.2}$$


where $q(z)$ is a distribution over latent variables that provides the external source of randomness. Equation (26.2) is the definition of the transformed density $q_\theta(x)$ defined as the derivative of a cumulative distribution function, and hence integrates the distribution $q(z)$ over all events defined
(a) Prescribed generative model. (b) Implicit generative model.


Figure 26.1: Visualizing the difference between prescribed and implicit generative models. Prescribed models provide direct access to the learned density (sometimes unnormalized). Implicit models only provide access to a simulator which can be used to generate samples from an implied density. Generated by genmo_types_implicit_explicit.ipynb


by the set $\{G_\theta(z) \le x\}$. When the latent and data dimensions are equal ($m = d$) and the function $G_\theta(z)$ is invertible or has easily characterized roots, we recover the rule for transformations of probability distributions. This transformation of variables property is also used in normalizing flows (Chapter 23). In diffusion models (Chapter 25), we also transform noise into data and vice versa, but the transformation is not strictly invertible.
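To make Equation (26.2) concrete in one dimension, the sketch below uses a hypothetical invertible generator $G(z) = e^z$ with $z \sim \mathcal{N}(0, 1)$. It computes the implied CDF $P(G(z) \le x)$, differentiates it numerically, and compares the result with the change-of-variables density and with samples drawn by simply pushing noise through $G$.

```python
import math
import numpy as np

def generator(z):
    return np.exp(z)   # hypothetical generator G(z) = exp(z), with m = d = 1

def cdf_implied(x):
    """P(G(z) <= x) for z ~ N(0, 1): integrates q(z) over the set {z : exp(z) <= x}."""
    return 0.5 * (1.0 + math.erf(math.log(x) / math.sqrt(2.0)))

def density_via_cdf(x, h=1e-5):
    """Eq. (26.2) in 1D: the density is the derivative of the implied CDF (central difference)."""
    return (cdf_implied(x + h) - cdf_implied(x - h)) / (2 * h)

def density_change_of_vars(x):
    """Change-of-variables rule for the invertible G: q(z = log x) * |dz/dx|."""
    z = math.log(x)
    return math.exp(-0.5 * z * z) / math.sqrt(2 * math.pi) / x

# Sampling needs only G and q(z); the empirical CDF should match the implied one.
rng = np.random.default_rng(0)
samples = generator(rng.standard_normal(100_000))
for xv in [0.5, 1.0, 2.0]:
    print(xv, density_via_cdf(xv), density_change_of_vars(xv), np.mean(samples <= xv), cdf_implied(xv))
```

When $G$ is a deep network with $d > m$, neither the set $\{G_\theta(z) \le x\}$ nor its probability is available in this way, but sampling still works exactly as in the last lines.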

We can develop more general and flexible implicit generative models where the function $G$ is a non-linear function with $d > m$, e.g., specified by a deep network. Such models are sometimes called generator networks or generative neural samplers; they can also be thought of as differentiable simulators. Unfortunately the integral in Equation (26.2) is intractable in these kinds of models, and we may not even be able to determine the set $\{G_\theta(z) \le x\}$. Of course, intractability is also a challenge for explicit latent variable models such as VAEs (Chapter 21), but in the GAN case, the lack of a likelihood term makes the learning problem even harder. Therefore this problem is called likelihood-free inference or simulation-based inference.

Likelihood-free inference also forms the basis of the field known as Approximate Bayesian Computation or ABC, which we briefly discuss in Section 13.6.5. ABC and GANs give us two different algorithmic frameworks for learning in implicit generative models. Both approaches rely on a learning principle based on comparing real and simulated data. This type of learning by comparison instantiates a core principle of likelihood-free inference, and expanding on this idea is the focus of the next section. The subsequent sections will then focus on GANs specifically, to develop a more detailed foundation and practical considerations. (See also https://poloclub.github.io/ganlab/ for an interactive tutorial.)


26.2 Learning by Comparison


In most of this book, we rely on the principle of maximum likelihood for learning. By maximizing the likelihood we effectively minimize the KL divergence between the model $q_\theta$ (with parameters $\theta$)
26.2. LEARNING BY COMPARISON




Figure 26.2: The aim of implicit generative modelling objectives: to measure distances between distributions 
only from samples, in order to distinguish between distributions which are further apart (left) compared to 
those which are closer (right). 


[Figure 26.3 diagram nodes: Learning by Comparison; f-Divergences; Integral Probability Metrics; Moment Matching.]


Figure 26.3: Overview of approaches for learning in implicit generative models 


and the unknown true data distribution $p^*$. Recalling Equation (26.2), in implicit models we cannot evaluate $q_\theta(x)$, and thus cannot use maximum likelihood training. As implicit models provide a sampling procedure, we instead search for learning principles that only use samples from the model.

Figure 26.2 shows that the task of learning in implicit models is to determine, from two sets of 
samples, whether their distributions are close to each other and to quantify the distance between 
them. We can think of this as a ‘two sample’ or likelihood-free approach to learning by comparison. 
There are many ways of doing this, including using distributional divergences or distances through 
binary classification, the method of moments, and other approaches. Figure 26.3 shows an overview 
of different approaches for learning by comparison. 


26.2.1 Guiding principles 


We are looking for objectives D(p*,q) that satisfy the following requirements: 
1. Provide guarantees about learning the data distribution: $\arg\min_q D(p^*, q) = p^*$.
2. Can be evaluated only using samples from the data and model distribution. 


3. Are computationally cheap to evaluate. 
Many distributional distances and divergences satisfy the first requirement, since by definition they satisfy the following:

$$D(p^*, q) \ge 0; \qquad D(p^*, q) = 0 \iff p^* = q \tag{26.3}$$


Many distributional distances and divergences, however, fail to satisfy the other two requirements: they either cannot be evaluated using only samples (such as the KL divergence), or are computationally intractable (such as the Wasserstein distance). The main approach to overcome these challenges is to approximate the desired quantity through optimization by introducing a comparison model, often called a discriminator or a critic $D$, such that:

$$D(p^*, q) = \max_D \mathcal{F}(D, p^*, q) \tag{26.4}$$


where $\mathcal{F}$ is a functional that depends on $p^*$ and $q$ only through samples. For the cases we discuss, both the model and the critic are parametric with parameters $\theta$ and $\phi$ respectively; instead of optimizing over distributions or functions, we optimize with respect to parameters. For the critic, this results in the optimization problem $\arg\max_\phi \mathcal{F}(D_\phi, p^*, q_\theta)$. For the model parameters $\theta$, the exact objective $D(p^*, q_\theta)$ is replaced with the tractable approximation provided through the use of $D_\phi$.

A convenient approach to ensure that $\mathcal{F}(D_\phi, p^*, q_\theta)$ can be estimated using only samples from the model and the unknown data distribution is to depend on the two distributions only in expectation:

$$\mathcal{F}(D_\phi, p^*, q_\theta) = \mathbb{E}_{p^*(x)} f(x, \phi) + \mathbb{E}_{q_\theta(x)} g(x, \phi) \tag{26.5}$$


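where $f$ and $g$ are real valued functions whose choice will define $\mathcal{F}$. In the case of implicit generative models, this can be rewritten to use the sampling path $x = G_\theta(z)$, $z \sim q(z)$:

$$\mathcal{F}(D_\phi, p^*, q_\theta) = \mathbb{E}_{p^*(x)} f(x, \phi) + \mathbb{E}_{q(z)} g(G_\theta(z), \phi) \tag{26.6}$$

which can be estimated using Monte Carlo estimation

$$\mathcal{F}(D_\phi, p^*, q_\theta) \approx \frac{1}{N}\sum_{i=1}^N f(\hat{x}_i, \phi) + \frac{1}{M}\sum_{j=1}^M g(G_\theta(\hat{z}_j), \phi); \qquad \hat{x}_i \sim p^*(x), \;\; \hat{z}_j \sim q(z) \tag{26.7}$$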
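The Monte Carlo estimate (26.7) needs only samples. The sketch below plugs in the Bernoulli choice of Equation (26.10), $f = \frac{1}{2}\log D_\phi$ and $g = \frac{1}{2}\log(1 - D_\phi)$, with a hypothetical logistic critic, a hypothetical shift generator $G_\theta(z) = z + \theta$, and $p^* = \mathcal{N}(1, 1)$ standing in for the data. A quadrature reference is computed only because this toy has known densities.

```python
import numpy as np

def log_sigmoid(u):
    return -np.logaddexp(0.0, -u)

phi = np.array([-0.5, 1.0])   # hypothetical critic parameters: D_phi(x) = sigmoid(phi0 + phi1 x)
theta = 0.3                   # hypothetical generator parameter

def f(x, phi):
    """Bernoulli choice, Eq. (26.10): f(x, phi) = 0.5 * log D_phi(x)."""
    return 0.5 * log_sigmoid(phi[0] + phi[1] * x)

def g(x, phi):
    """g(x, phi) = 0.5 * log(1 - D_phi(x)), using 1 - sigmoid(u) = sigmoid(-u)."""
    return 0.5 * log_sigmoid(-(phi[0] + phi[1] * x))

def generator(z, theta):
    return z + theta          # hypothetical G_theta(z), with q(z) = N(0, 1)

rng = np.random.default_rng(0)
N = M = 100_000
x_hat = 1.0 + rng.standard_normal(N)      # stand-in samples from p*(x) = N(1, 1)
z_hat = rng.standard_normal(M)            # samples from q(z) = N(0, 1)
F_mc = f(x_hat, phi).mean() + g(generator(z_hat, theta), phi).mean()   # Eq. (26.7)

# Reference value by Gauss-Hermite quadrature (only possible because both densities are known here).
nodes, weights = np.polynomial.hermite_e.hermegauss(60)
weights = weights / np.sqrt(2.0 * np.pi)
F_quad = np.sum(weights * f(1.0 + nodes, phi)) + np.sum(weights * g(theta + nodes, phi))
print(F_mc, F_quad)
```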


Next, we will see how to instantiate these guiding principles in order to find the functions $f$ and $g$ and thus the objective $\mathcal{F}$ which can be used to train implicit models: class probability estimation (Section 26.2.2), bounds on f-divergences (Section 26.2.3), Integral Probability Metrics (Section 26.2.4) and moment matching (Section 26.2.5).


26.2.2 Density ratio estimation using binary classifiers 


One way to compare two distributions $p^*$ and $q_\theta$ is to compute their density ratio $r(x) = \frac{p^*(x)}{q_\theta(x)}$. The distributions are the same if and only if the ratio is 1 everywhere in the support of $q_\theta$. Since we cannot evaluate the densities of implicit models, we must instead develop techniques to compute the density ratio from samples alone, following the guiding principles established above.

Fortunately, we can use the trick from Section 2.7.5, which converts density ratio estimation into a binary classification problem, to write (assuming the two classes are equally likely a priori)

$$\frac{p^*(x)}{q_\theta(x)} = \frac{D(x)}{1 - D(x)} \tag{26.8}$$
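The classifier trick of Equation (26.8) can be checked on a toy problem with a known answer. With equal numbers of samples from $p^* = \mathcal{N}(1, 1)$ and $q_\theta = \mathcal{N}(0, 1)$ (both hypothetical), the true log ratio $x - \frac{1}{2}$ is linear in $x$, so a logistic-regression discriminator fitted by plain gradient descent should recover it.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 20_000
x_p = 1.0 + rng.standard_normal(n)   # samples from p*(x) = N(1, 1), label y = 1
x_q = rng.standard_normal(n)         # samples from q_theta(x) = N(0, 1), label y = 0
x = np.concatenate([x_p, x_q])
y = np.concatenate([np.ones(n), np.zeros(n)])
features = np.stack([np.ones_like(x), x], axis=1)

# Logistic-regression discriminator D(x) = sigmoid(w0 + w1 x), fit by full-batch
# gradient descent on the Bernoulli log-loss with a fixed step size.
w = np.zeros(2)
for _ in range(2000):
    d = 1.0 / (1.0 + np.exp(-features @ w))
    grad = features.T @ (d - y) / len(y)
    w -= 0.5 * grad

# Eq. (26.8): log r(x) = log D(x) - log(1 - D(x)) = w0 + w1 x.
# The true log ratio is log N(x; 1, 1) - log N(x; 0, 1) = x - 1/2.
print(w)   # approximately [-0.5, 1.0]
```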


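Loss | Objective function ($D := D(x; \phi) \in [0, 1]$)
--- | ---
Bernoulli loss | $\mathbb{E}_{p^*(x)}[\log D] + \mathbb{E}_{q_\theta(x)}[\log(1 - D)]$
Brier score | $\mathbb{E}_{p^*(x)}[-(1 - D)^2] + \mathbb{E}_{q_\theta(x)}[-D^2]$
Exponential loss | $\mathbb{E}_{p^*(x)}\left[-\left(\frac{1 - D}{D}\right)^{1/2}\right] + \mathbb{E}_{q_\theta(x)}\left[-\left(\frac{D}{1 - D}\right)^{1/2}\right]$
Misclassification | $\mathbb{E}_{p^*(x)}[-\mathbb{I}[D \le 0.5]] + \mathbb{E}_{q_\theta(x)}[-\mathbb{I}[D > 0.5]]$
Hinge loss | $\mathbb{E}_{p^*(x)}\left[-\max\left(0, 1 - \log\frac{D}{1 - D}\right)\right] + \mathbb{E}_{q_\theta(x)}\left[-\max\left(0, 1 + \log\frac{D}{1 - D}\right)\right]$
Spherical | $\mathbb{E}_{p^*(x)}[\alpha D] + \mathbb{E}_{q_\theta(x)}[\alpha(1 - D)]$; $\alpha = (1 - 2D + 2D^2)^{-1/2}$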


Table 26.1: Proper scoring rules that can be maximized in class probability-based learning of implicit generative 
models. Based on [ML16]. 
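The entries of Table 26.1 are proper in the sense that, pointwise, the expected score $p^*(x)S_1(D) + q_\theta(x)S_0(D)$ is maximized at $D^* = p^*/(p^* + q_\theta)$. The grid search below checks this for four of the rules at a single hypothetical point; the misclassification and hinge rules are omitted because their maximizers do not single out $D^*$.

```python
import numpy as np

# Each rule maps D in (0, 1) to (score if x came from p*, score if x came from q_theta).
rules = {
    "bernoulli": lambda D: (np.log(D), np.log(1 - D)),
    "brier": lambda D: (-(1 - D) ** 2, -D ** 2),
    "exponential": lambda D: (-np.sqrt((1 - D) / D), -np.sqrt(D / (1 - D))),
    "spherical": lambda D: (D / np.sqrt(1 - 2 * D + 2 * D**2),
                            (1 - D) / np.sqrt(1 - 2 * D + 2 * D**2)),
}

p_star, q_theta = 0.3, 0.9           # hypothetical density values at a single point x
D_grid = np.linspace(1e-4, 1 - 1e-4, 100_001)
for name, rule in rules.items():
    s_p, s_q = rule(D_grid)
    expected = p_star * s_p + q_theta * s_q   # pointwise contribution to the objective
    print(name, D_grid[np.argmax(expected)], p_star / (p_star + q_theta))
```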


where $D(x)$ is the discriminator or critic, which is trained to distinguish samples coming from $p^*$ vs $q_\theta$.

For parametric classification, we can learn discriminators $D_\phi(x) \in [0, 1]$ with parameters $\phi$. Using knowledge and insight about probabilistic classification, we can learn the parameters by optimizing any proper scoring rule [GR07b] (see also Section 14.2.1). For the familiar Bernoulli log-loss (or binary cross entropy loss), and assuming equal class priors $p(y = 1) = p(y = 0) = \frac{1}{2}$, we obtain the objective:

$$\begin{aligned}
V(q_\theta, p^*) &= \max_\phi \mathbb{E}_{p(x|y)p(y)}\left[y\log D_\phi(x) + (1 - y)\log(1 - D_\phi(x))\right] \\
&= \max_\phi \mathbb{E}_{p(x|y=1)p(y=1)}\log D_\phi(x) + \mathbb{E}_{p(x|y=0)p(y=0)}\log(1 - D_\phi(x)) && (26.9)\\
&= \max_\phi \frac{1}{2}\mathbb{E}_{p^*(x)}\log D_\phi(x) + \frac{1}{2}\mathbb{E}_{q_\theta(x)}\log(1 - D_\phi(x)). && (26.10)
\end{aligned}$$
27 
The same procedure can be extended beyond the Bernoulli log-loss to other proper scoring rules used for binary classification, such as those presented in Table 26.1, adapted from [ML16]. The optimal discriminator $D^*$ is $\frac{p^*(x)}{p^*(x) + q_\theta(x)}$, since:

$$\frac{p^*(x)}{q_\theta(x)} = \frac{D^*(x)}{1 - D^*(x)} \implies D^*(x) = \frac{p^*(x)}{p^*(x) + q_\theta(x)} \tag{26.11}$$
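By substituting the optimal discriminator into the scoring rule (26.10), we can show that the objective $V$ can also be interpreted as the minimization of the Jensen-Shannon divergence:

$$\begin{aligned}
V^*(q_\theta, p^*) &= \frac{1}{2}\mathbb{E}_{p^*(x)}\left[\log\frac{p^*(x)}{p^*(x) + q_\theta(x)}\right] + \frac{1}{2}\mathbb{E}_{q_\theta(x)}\left[\log\left(1 - \frac{p^*(x)}{p^*(x) + q_\theta(x)}\right)\right] && (26.12)\\
&= \frac{1}{2}\mathbb{E}_{p^*(x)}\left[\log\frac{p^*(x)}{\frac{p^*(x) + q_\theta(x)}{2}}\right] + \frac{1}{2}\mathbb{E}_{q_\theta(x)}\left[\log\frac{q_\theta(x)}{\frac{p^*(x) + q_\theta(x)}{2}}\right] - \log 2 && (26.13)\\
&= \frac{1}{2}D_{\mathbb{KL}}\left(p^* \,\middle\|\, \frac{p^* + q_\theta}{2}\right) + \frac{1}{2}D_{\mathbb{KL}}\left(q_\theta \,\middle\|\, \frac{p^* + q_\theta}{2}\right) - \log 2 && (26.14)\\
&= \mathrm{JSD}(p^*, q_\theta) - \log 2 && (26.15)
\end{aligned}$$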
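where JSD denotes the Jensen-Shannon divergence:

$$\mathrm{JSD}(p^*, q_\theta) = \frac{1}{2}D_{\mathbb{KL}}\left(p^* \,\middle\|\, \frac{p^* + q_\theta}{2}\right) + \frac{1}{2}D_{\mathbb{KL}}\left(q_\theta \,\middle\|\, \frac{p^* + q_\theta}{2}\right) \tag{26.16}$$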
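The identity $V^* = \mathrm{JSD}(p^*, q_\theta) - \log 2$ from (26.12) to (26.15) can be checked exactly for discrete distributions. The sketch uses two hypothetical random distributions on six points.

```python
import numpy as np

def kl(p, q):
    return np.sum(p * np.log(p / q))

rng = np.random.default_rng(0)
p_star = rng.random(6) + 0.01; p_star /= p_star.sum()     # hypothetical "data" distribution
q_theta = rng.random(6) + 0.01; q_theta /= q_theta.sum()  # hypothetical "model" distribution

d_opt = p_star / (p_star + q_theta)                       # Eq. (26.11)
v_star = (0.5 * np.sum(p_star * np.log(d_opt))
          + 0.5 * np.sum(q_theta * np.log(1 - d_opt)))     # Eq. (26.12)
m = 0.5 * (p_star + q_theta)
jsd = 0.5 * kl(p_star, m) + 0.5 * kl(q_theta, m)          # Eq. (26.16)
print(v_star, jsd - np.log(2))                            # equal up to rounding
```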


This establishes a connection between optimal binary classification and distributional divergences. By using binary classification, we were able to compute the distributional divergence using only samples, which is the important property needed for learning implicit generative models; as expressed in the guiding principles (Section 26.2.1), we have turned an intractable estimation problem (how to estimate the JSD) into an optimization problem (how to learn a classifier which can be used to approximate that divergence).

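We would like to train the parameters $\theta$ of the generative model to minimize the divergence:

$$\begin{aligned}
\min_\theta \mathrm{JSD}(p^*, q_\theta) &= \min_\theta V^*(q_\theta, p^*) + \log 2 && (26.17)\\
&= \min_\theta \frac{1}{2}\mathbb{E}_{p^*(x)}\log D^*(x) + \frac{1}{2}\mathbb{E}_{q_\theta(x)}\log(1 - D^*(x)) + \log 2 && (26.18)
\end{aligned}$$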


Since we do not have access to the optimal classifier $D^*$ but only to the neural approximation $D_\phi$ obtained using the optimization in (26.10), this results in a min-max optimization problem:

$$\min_\theta\max_\phi \frac{1}{2}\mathbb{E}_{p^*(x)}[\log D_\phi(x)] + \frac{1}{2}\mathbb{E}_{q_\theta(x)}[\log(1 - D_\phi(x))] \tag{26.19}$$


By replacing the generating procedure (26.1) in (26.19) we obtain the objective in terms of the latent variables $z$ of the implicit generative model:

$$\min_\theta\max_\phi \frac{1}{2}\mathbb{E}_{p^*(x)}[\log D_\phi(x)] + \frac{1}{2}\mathbb{E}_{q(z)}[\log(1 - D_\phi(G_\theta(z)))] \tag{26.20}$$


which recovers the definition proposed in the original Generative Adversarial Network (GAN) [Goo+14]. The core principle behind GANs is to train a discriminator, in this case a binary classifier, to approximate a distance or divergence between the model and data distributions, and to then train the generative model to minimize this approximation of the divergence or distance.


Beyond the Bernoulli scoring rule used above, other scoring rules have been used to train generative models via min-max optimization. The Brier scoring rule, which under discriminator optimality conditions can be shown to correspond to minimizing the Pearson $\chi^2$ divergence via similar arguments as the ones shown above, has led to LS-GAN [Mao+17]. The hinge scoring rule has become popular [Miy+18b; BDS18], and under discriminator optimality conditions corresponds to minimizing the total variation distance [NWJ+09].


The connection between proper scoring rules and distributional divergences allows the construction of convergence guarantees for the learning criteria above, under infinite capacity of the discriminator and generator: since the minimizer of distributional divergence is the true data distribution (Equation 26.3), if the discriminator is optimal and the generator has enough capacity, it will learn the data distribution. In practice, however, this assumption will not hold, as discriminators are rarely optimal; we will discuss this at length in Section 26.3.
26.2. LEARNING BY COMPARISON


26.2.3 Bounds on f-divergences 


As we saw with the appearance of the Jensen-Shannon divergence in the previous section, we can consider directly using a measure of distributional divergence to derive methods for learning in implicit models. One general class of divergences is the class of f-divergences (Section 2.7.1), defined as:


$$D_f\left[p^*(x) \,\|\, q_\theta(x)\right] = \int q_\theta(x)\, f\left(\frac{p^*(x)}{q_\theta(x)}\right) dx \tag{26.21}$$


where f is a convex function such that f(1) = 0. For different choices of f, we can recover known distributional divergences such as the KL, reverse KL, and Jensen-Shannon divergences. We discuss such connections in Section 2.7.1, and provide a summary in Table 26.2.

To evaluate Equation (26.21) we would need to evaluate the density of the data p*(x) and of the model q_θ(x), neither of which is available. In the previous section we overcame the challenge of evaluating the density ratio by transforming it into a problem of binary classification. In this section, we instead turn to lower bounds on f-divergences, an approach to tractability that is also used for variational inference (Chapter 10).

f-divergences have a widely-developed theory in convex analysis and information theory. Since the 
function f in Equation (26.21) is convex, we know that we can find a tangent that bounds it from 
below. The variational formulation of the f-divergence is [NWJ10b; NCT16c]: 


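$$D_f\left[p^*(x) \,\|\, q_\theta(x)\right] = \int q_\theta(x)\, f\left(\frac{p^*(x)}{q_\theta(x)}\right) dx \tag{26.22}$$
$$= \int q_\theta(x) \sup_{t:\mathcal{X}\to\mathbb{R}} \left[\frac{p^*(x)}{q_\theta(x)}\, t(x) - f^\dagger(t(x))\right] dx \tag{26.23}$$
$$= \sup_{t:\mathcal{X}\to\mathbb{R}} \int p^*(x)\, t(x) - q_\theta(x)\, f^\dagger(t(x))\, dx \tag{26.24}$$
$$\ge \sup_{t\in\mathcal{T}} \mathbb{E}_{p^*(x)}\left[t(x)\right] - \mathbb{E}_{q_\theta(x)}\left[f^\dagger(t(x))\right] \tag{26.25}$$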


In the second line we use the result from convex analysis, discussed in Supplementary Section 6.3, that re-expresses the convex function f using f(u) = sup_t [ut − f†(t)], where f† is the convex conjugate of the function f, and t is a parameter we optimize over. Since we apply f at u = p*(x)/q_θ(x) for all x ∈ X, we make the parameter t a function t(x); because the supremum can be taken separately for each x, the third line moves it outside the integral as a supremum over functions. The final inequality comes from replacing the supremum over all functions from the data domain X to ℝ with the supremum over a family of functions T (such as the family of functions expressible by a neural network architecture), which might not be able to capture the true supremum. The function t takes the role of the discriminator or critic.

The final expression in Equation (26.25) follows the general desired form of Equation 26.5: it is the 
difference of two expectations, and these expectations can be computed by Monte Carlo estimation 
using only samples, as in Equation (26.7); despite starting with an objective (Equation 26.21) which 
contravened the desired principles for training implicit generative models, variational bounds have 
allowed us to construct an approximation which satisfies all desiderata. 
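To make the variational bound concrete, the sketch below instantiates Equation (26.25) for the KL divergence on a hypothetical three-outcome example, with the critic t(x) represented as one free value per outcome rather than a neural network. With the critic t(x) = 1 + log r(x) (the KL entry of Table 26.2) the bound is tight; any other critic, such as the deliberately shrunken one used here, gives a strictly smaller value, which is the gap a learned critic has to close in practice.

```python
import math

# Hypothetical discrete distributions with the same support (not from the text).
p_star = [0.5, 0.3, 0.2]
q_theta = [0.2, 0.3, 0.5]

def f_kl(u):
    """Generator function of the KL divergence: f(u) = u log u."""
    return u * math.log(u)

def f_kl_conjugate(t):
    """Convex conjugate of f(u) = u log u: f^dagger(t) = exp(t - 1)."""
    return math.exp(t - 1.0)

def f_divergence(p, q, f):
    """Equation (26.21) for discrete x: sum_x q(x) f(p(x) / q(x))."""
    return sum(qi * f(pi / qi) for pi, qi in zip(p, q))

def variational_bound(p, q, critic, f_conj):
    """Equation (26.25) for discrete x, with the critic t(x) given as one value per outcome."""
    return (sum(pi * t for pi, t in zip(p, critic))
            - sum(qi * f_conj(t) for qi, t in zip(q, critic)))

exact = f_divergence(p_star, q_theta, f_kl)
# For fixed u, u*t - f^dagger(t) is maximized at t = f'(u) = 1 + log u; use u = r(x).
optimal_critic = [1.0 + math.log(pi / qi) for pi, qi in zip(p_star, q_theta)]
weak_critic = [0.5 * t for t in optimal_critic]  # deliberately suboptimal critic

bound_optimal = variational_bound(p_star, q_theta, optimal_critic, f_kl_conjugate)
bound_weak = variational_bound(p_star, q_theta, weak_critic, f_kl_conjugate)
print(f"KL = {exact:.6f}, bound (optimal critic) = {bound_optimal:.6f}, "
      f"bound (weak critic) = {bound_weak:.6f}")
```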

Using bounds on the f-divergence, we obtain an objective (26.25) that allows learning both the generator and critic parameters. We use a critic D with parameters φ to estimate the bound, and then optimize the parameters θ of the generator to minimize the approximation of the f-divergence provided by the critic (we replace t above with D_φ, to retain standard GAN notation):

$$\min_\theta D_f(p^*, q_\theta) \ge \min_\theta \max_\phi \mathbb{E}_{p^*(x)}\left[D_\phi(x)\right] - \mathbb{E}_{q_\theta(x)}\left[f^\dagger(D_\phi(x))\right] \tag{26.26}$$
$$= \min_\theta \max_\phi \mathbb{E}_{p^*(x)}\left[D_\phi(x)\right] - \mathbb{E}_{q(z)}\left[f^\dagger(D_\phi(G_\theta(z)))\right] \tag{26.27}$$

Divergence   | f(u)                              | f†(u)             | Optimal critic
KL           | u log u                           | exp(u − 1)        | 1 + log r(x)
Reverse KL   | −log u                            | −1 − log(−u)      | −1/r(x)
JSD          | u log u − (u + 1) log((u + 1)/2)  | −log(2 − exp(u))  | log(2r(x)/(1 + r(x)))
Pearson χ²   | (u − 1)²                          | u²/4 + u          | 2(r(x) − 1)

Table 26.2: Standard divergences as f-divergences for various choices of f. The optimal critic is written as a function of the density ratio r(x) = p*(x)/q_θ(x).


This approach to train an implicit generative model leads to f-GANs [NCT16c]. It is worth noting 
that there exists an equivalence between the scoring rules in the previous section and bounds on 
f-divergences [RW11]: for each scoring rule we can find an f-divergence that leads to the same 
training criteria and the same min-max game of Equation 26.27. An intuitive way to grasp the 
connection between f-divergences and proper scoring rules is through their use of density ratios: 
in both cases the optimal critic approximates a quantity directly related to the density ratio (see 
Table 26.2 for f-divergences and Equation (26.11) for scoring rules). 
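The conjugate pairs in Table 26.2 can be checked numerically. The sketch below approximates f†(t) = sup_{u>0} [ut − f(u)] by a brute-force grid search over u and compares the result with the closed-form conjugate for each row. The test points t are arbitrary values inside each conjugate's domain, chosen so that the maximizing u falls inside the grid; this only checks the table entries and is not part of how an f-GAN is trained.

```python
import math

def numerical_conjugate(f, t, u_max=10.0, steps=100_000):
    """Approximate f^dagger(t) = sup_{u > 0} [u t - f(u)] by a grid search over (0, u_max]."""
    best = -math.inf
    for i in range(1, steps + 1):
        u = u_max * i / steps
        best = max(best, u * t - f(u))
    return best

# Rows of Table 26.2: (f, closed-form conjugate f^dagger, illustrative test point t).
TABLE = {
    "KL": (lambda u: u * math.log(u), lambda t: math.exp(t - 1.0), 0.5),
    "Reverse KL": (lambda u: -math.log(u), lambda t: -1.0 - math.log(-t), -0.5),
    "JSD": (lambda u: u * math.log(u) - (u + 1.0) * math.log((u + 1.0) / 2.0),
            lambda t: -math.log(2.0 - math.exp(t)), 0.3),
    "Pearson chi^2": (lambda u: (u - 1.0) ** 2, lambda t: t * t / 4.0 + t, 1.0),
}

errors = {}
for name, (f, f_conj, t) in TABLE.items():
    approx = numerical_conjugate(f, t)
    closed_form = f_conj(t)
    errors[name] = abs(approx - closed_form)
    print(f"{name:14s} closed form {closed_form:+.6f}  grid search {approx:+.6f}")
```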


26.2.4 Integral probability metrics 


Instead of comparing distributions by using their ratio as we did in the previous two sections, we can instead study their difference. A general class of measures of difference is given by the Integral Probability Metrics (Section 2.7.2), defined as:


$$I_{\mathcal{F}}(p^*(x), q_\theta(x)) = \sup_{f\in\mathcal{F}} \left| \mathbb{E}_{p^*(x)} f(x) - \mathbb{E}_{q_\theta(x)} f(x) \right| \tag{26.28}$$


The function f is a test or witness function that will take the role of the discriminator or critic. To use IPMs we must define the class of real-valued, measurable functions F over which the supremum is taken, and this choice will lead to different distances, just as choosing different convex functions f leads to different f-divergences. Integral probability metrics are distributional distances: beyond satisfying the conditions for distributional divergences, D(p*, q) ≥ 0 and D(p*, q) = 0 ⟺ p* = q (Equation (26.3)), they are also symmetric, D(p, q) = D(q, p), and satisfy the triangle inequality, D(p, q) ≤ D(p, r) + D(r, q).

Not all function families satisfy these conditions and create a valid distance I_F. To see why, consider the case where F = {z}, where z is the function z(x) = 0. This choice of F entails that, regardless of the two distributions chosen, the value in Equation 26.28 would be 0, violating the requirement that the distance between two distributions be 0 only if the two distributions are the same. A popular choice of F for which I_F satisfies the conditions of a valid distributional distance is the set of 1-Lipschitz functions, which leads to the Wasserstein distance [Vil08]:


$$W_1(p^*(x), q_\theta(x)) = \sup_{f:\|f\|_{\mathrm{Lip}}\le 1} \mathbb{E}_{p^*(x)} f(x) - \mathbb{E}_{q_\theta(x)} f(x) \tag{26.29}$$


Draft of “Probabilistic Machine Learning: Advanced Topics”. August 15, 2022 






(a) Optimal Wasserstein critic. (b) Optimal MMD critic.

Figure 26.4: Optimal critics in Integral Probability Metrics (IPMs). Generated by ipm_divergences.ipynb


We show an example of a Wasserstein critic in Figure 26.4a. The supremum over the set of 1-Lipschitz functions is intractable in most cases, which again suggests the introduction of a learned critic:


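$$W_1(p^*(x), q_\theta(x)) = \sup_{f:\|f\|_{\mathrm{Lip}}\le 1} \mathbb{E}_{p^*(x)} f(x) - \mathbb{E}_{q_\theta(x)} f(x) \tag{26.30}$$
$$\ge \max_{\phi:\|D_\phi\|_{\mathrm{Lip}}\le 1} \mathbb{E}_{p^*(x)} D_\phi(x) - \mathbb{E}_{q_\theta(x)} D_\phi(x) \tag{26.31}$$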


where the critic D_φ has to be regularized to be 1-Lipschitz (various techniques for Lipschitz regularization, such as gradient penalties or spectral normalization, have been used [ACB17; Gul+17]). As was the case with f-divergences, we replace an intractable quantity that requires a supremum over a class of functions with a bound obtained using a subset of this function class, a subset that can be modeled using neural networks.
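In one dimension the Wasserstein-1 distance between two equal-size sets of samples has a closed form: the optimal transport plan pairs the sorted samples. The sketch below (the Gaussian samples and the particular critic are assumptions made for illustration) uses this to compute W_1 and compares it with the value obtained from one fixed 1-Lipschitz critic, f(x) = −tanh(x). As in Equation (26.31), any such restricted critic can only give a lower bound on W_1, and shifting every sample by a constant gives a W_1 equal to that constant.

```python
import math
import random

def wasserstein_1d(xs, ys):
    """W1 between two equal-size empirical distributions on the real line.

    In 1D the optimal transport plan pairs the sorted samples, so W1 is the mean
    absolute difference between order statistics."""
    if len(xs) != len(ys):
        raise ValueError("this sketch assumes equally many samples from each distribution")
    return sum(abs(a - b) for a, b in zip(sorted(xs), sorted(ys))) / len(xs)

def critic_estimate(xs, ys, critic):
    """E_p[f(x)] - E_q[f(x)] for one fixed critic f, i.e. one candidate in Equation (26.31)."""
    return sum(map(critic, xs)) / len(xs) - sum(map(critic, ys)) / len(ys)

rng = random.Random(0)
data = [rng.gauss(0.0, 1.0) for _ in range(1000)]   # samples standing in for p*(x)
model = [rng.gauss(3.0, 1.0) for _ in range(1000)]  # samples standing in for q_theta(x)

w1 = wasserstein_1d(data, model)
# -tanh is 1-Lipschitz because |d/dx tanh(x)| <= 1, so it yields a lower bound on W1.
tanh_bound = critic_estimate(data, model, lambda x: -math.tanh(x))
shifted = [x + 2.5 for x in data]
print(f"W1 = {w1:.3f}, tanh critic bound = {tanh_bound:.3f}, "
      f"W1 to a 2.5 shift = {wasserstein_1d(data, shifted):.3f}")
```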

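To train a generative model, we again introduce a min-max game:

$$\min_\theta W_1(p^*(x), q_\theta(x)) \ge \min_\theta \max_{\phi:\|D_\phi\|_{\mathrm{Lip}}\le 1} \mathbb{E}_{p^*(x)} D_\phi(x) - \mathbb{E}_{q_\theta(x)} D_\phi(x) \tag{26.32}$$
$$= \min_\theta \max_{\phi:\|D_\phi\|_{\mathrm{Lip}}\le 1} \mathbb{E}_{p^*(x)} D_\phi(x) - \mathbb{E}_{q(z)} D_\phi(G_\theta(z)) \tag{26.33}$$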


This leads to the popular Wasserstein GAN [ACB17].
If we replace the function family F with the functions in an RKHS (Section 18.3.7.1) with norm one, we obtain the maximum mean discrepancy (MMD) discussed in Section 2.7.3:


$$\mathrm{MMD}(p^*(x), q_\theta(x)) = \sup_{f:\|f\|_{\mathcal{H}}=1} \mathbb{E}_{p^*(x)} f(x) - \mathbb{E}_{q_\theta(x)} f(x) \tag{26.34}$$


We show an example of an MMD critic in Figure 26.4b. It is often more convenient to use the squared MMD loss [LSZ15; DRG15], which can be evaluated using the kernel K (Section 18.3.7.1):


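$$\mathrm{MMD}^2(p^*, q_\theta) = \mathbb{E}_{p^*(x)}\mathbb{E}_{p^*(x')} K(x, x') - 2\,\mathbb{E}_{p^*(x)}\mathbb{E}_{q_\theta(y)} K(x, y) + \mathbb{E}_{q_\theta(y)}\mathbb{E}_{q_\theta(y')} K(y, y') \tag{26.35}$$
$$= \mathbb{E}_{p^*(x)}\mathbb{E}_{p^*(x')} K(x, x') - 2\,\mathbb{E}_{p^*(x)}\mathbb{E}_{q(z)} K(x, G_\theta(z)) + \mathbb{E}_{q(z)}\mathbb{E}_{q(z')} K(G_\theta(z), G_\theta(z')) \tag{26.36}$$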
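The squared MMD in Equations (26.35)-(26.36) only needs kernel evaluations between samples, so it can be estimated directly by Monte Carlo. The sketch below is a minimal, unoptimized version that assumes scalar data, a Gaussian RBF kernel with bandwidth 1, and a hypothetical one-parameter generator G_θ(z) = θ + z; it uses the standard unbiased estimator, which drops the i = j terms from the within-sample averages.

```python
import math
import random

def rbf_kernel(x, y, bandwidth=1.0):
    """Gaussian (RBF) kernel on scalars; the bandwidth is an illustrative choice."""
    return math.exp(-((x - y) ** 2) / (2.0 * bandwidth ** 2))

def mmd2_unbiased(xs, ys, kernel=rbf_kernel):
    """Unbiased Monte Carlo estimate of Equation (26.35) from xs ~ p* and ys ~ q_theta.

    The within-sample averages skip the i == j terms, which removes the bias
    (so the estimate can be slightly negative when the distributions match)."""
    n, m = len(xs), len(ys)
    k_xx = sum(kernel(xs[i], xs[j]) for i in range(n) for j in range(n) if i != j) / (n * (n - 1))
    k_yy = sum(kernel(ys[i], ys[j]) for i in range(m) for j in range(m) if i != j) / (m * (m - 1))
    k_xy = sum(kernel(x, y) for x in xs for y in ys) / (n * m)
    return k_xx - 2.0 * k_xy + k_yy

rng = random.Random(0)

def generate(theta, n):
    """Model samples G_theta(z) as in Equation (26.36), with the hypothetical G_theta(z) = theta + z."""
    return [theta + rng.gauss(0.0, 1.0) for _ in range(n)]

data = [rng.gauss(0.0, 1.0) for _ in range(200)]
mmd_same = mmd2_unbiased(data, generate(0.0, 200))
mmd_far = mmd2_unbiased(data, generate(2.0, 200))
print(f"MMD^2 at theta = 0: {mmd_same:.4f}; at theta = 2: {mmd_far:.4f}")
```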
The squared MMD can be used directly to learn a generative model; the resulting model is often called a generative moment matching network [LSZ15]:

$$\min_\theta \mathrm{MMD}^2(p^*, q_\theta) \tag{26.37}$$


The choice of kernel is important. Using a fixed or predefined kernel such as a radial basis function (RBF) kernel might not be appropriate for all data modalities, such as high-dimensional images. Thus we are looking for a way to learn a feature function ζ such that K(ζ(x), ζ(x')) is a valid kernel; luckily, for any characteristic kernel K(x, x') and injective function ζ, K(ζ(x), ζ(x')) is also a characteristic kernel. While this tells us that we can use feature functions in the MMD objective, it does not tell us how to learn the features. In order to ensure that the learned features are sensitive to differences between the data distribution p*(x) and the model distribution q_θ(x), the kernel parameters are trained to maximize the squared MMD. This again casts the problem into a familiar min-max objective by learning the projection ζ_φ with parameters φ [Li+17b]:


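$$\min_\theta \mathrm{MMD}_{\zeta}^2(p^*, q_\theta) \tag{26.38}$$
$$= \min_\theta \max_\phi \mathbb{E}_{p^*(x)}\mathbb{E}_{p^*(x')} K(\zeta_\phi(x), \zeta_\phi(x')) - 2\,\mathbb{E}_{p^*(x)}\mathbb{E}_{q_\theta(y)} K(\zeta_\phi(x), \zeta_\phi(y)) + \mathbb{E}_{q_\theta(y)}\mathbb{E}_{q_\theta(y')} K(\zeta_\phi(y), \zeta_\phi(y')) \tag{26.39}$$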


where ζ_φ is regularized to be injective, though this is sometimes relaxed [Bin+18]. Unlike the Wasserstein distance and f-divergences, Equation (26.39) can be estimated using Monte Carlo estimation without requiring a lower bound on the original objective.


26.2.5 Moment matching


More broadly than distances defined by integral probability metrics, for a set of test statistics s, one can define a moment matching criterion [Pea36], also known as the method of moments:

$$\min_\theta \left\| \mathbb{E}_{p^*(x)}\, s(x) - m(\theta) \right\|_2^2 \tag{26.40}$$


where m(θ) = E_{q_θ(x)} s(x) is the moment function. The choice of statistic s(x) is crucial, since, as with distributional divergences and distances, we would like to ensure that if the objective is minimized and reaches the minimal value 0, the two distributions are the same, p*(x) = q_θ(x). To see that not all functions s satisfy this requirement, consider the function s(x) = x: simply matching the means of two distributions is not sufficient to match higher moments (such as the variance). For likelihood-based models the score function s(x) = ∇_θ log q_θ(x) satisfies the above requirement and leads to a consistent estimator [Vaa00], but this choice of s is not available for implicit generative models.
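The failure of the statistic s(x) = x can be seen on a two-point example. In the hypothetical samples below, data and model share the same mean but have different variances, so the empirical version of the moment matching objective in Equation (26.40) is exactly zero for s(x) = x and only becomes positive once the second moment x² is added to s.

```python
def moment_matching_loss(data, samples, statistic):
    """Squared Euclidean distance between empirical means of s(x) (Equation (26.40))."""
    def mean_statistic(xs):
        stats = [statistic(x) for x in xs]
        return [sum(column) / len(stats) for column in zip(*stats)]
    m_data, m_model = mean_statistic(data), mean_statistic(samples)
    return sum((a - b) ** 2 for a, b in zip(m_data, m_model))

# Hypothetical samples: both have mean 0, but their variances are 1 and 9.
data = [-1.0, 1.0]
model = [-3.0, 3.0]

loss_mean_only = moment_matching_loss(data, model, lambda x: [x])
loss_mean_and_square = moment_matching_loss(data, model, lambda x: [x, x * x])
print(loss_mean_only, loss_mean_and_square)  # → 0.0 64.0
```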


This motivates the search for other approaches to integrating the method of moments into implicit models. The MMD can be seen as a moment matching criterion, matching the means of the two distributions after lifting the data into the feature space of an RKHS. But moment matching can go beyond integral probability metrics: Ravuri et al. [Rav+18] show that one can learn useful moments by using as s the set of features containing the gradients of a trained discriminator classifier D_φ together with the features of the learned critic: s_φ(x) = [∇_φ D_φ(x), h_1(x), ..., h_n(x)]
where h_1(x), ..., h_n(x) are the hidden activations of the learned critic. Both features and gradients are needed: the gradients ∇_φ D_φ(x) are required to ensure that the estimator for the parameters θ is consistent, since the number of moments s(x) needs to be larger than the number of parameters θ, which will be true if the critic has more parameters than the model; the features h_i(x) are added since they have been shown empirically to improve performance, showcasing the importance of the choice of test statistics s used to train implicit models.

(a) Failure of the KL divergence to distinguish between distributions with non-overlapping support: D_KL(p* || q_θ1) = D_KL(p* || q_θ2) = ∞, even though one of the two models is closer to p* than the other. (b) The density ratio p*/q_θ used by the KL divergence and a smooth estimate given by an MLP, together with the gradient it provides with respect to the input variable.

Figure 26.5: The KL divergence cannot provide learning signal for distributions without overlapping support (left), while the smooth approximation given by a learned decision surface like an MLP can (right). Generated by ipm_divergences.ipynb


26.2.6 On density ratios and differences 


We have seen how density ratios (Sections 26.2.2 and 26.2.3) and density differences (Section 26.2.4) can be used to define training objectives for implicit generative models. We now explore some of the distinctions between using ratios and differences for learning by comparison, as well as the effect that approximating these objectives with function classes such as neural networks has on these distinctions.

One often-stated downside of using divergences that rely on density ratios (such as f-divergences) is their poor behavior when the distributions p* and q_θ do not have overlapping support. For non-overlapping support, the density ratio p*/q_θ will be ∞ in the parts of the space where p*(x) > 0 but q_θ(x) = 0, and 0 where q_θ(x) > 0 but p*(x) = 0. In that case, D_KL(p* || q_θ) = ∞ and JSD(p*, q_θ) = log 2, regardless of the value of θ. Thus f-divergences cannot distinguish between different model distributions when they do not have overlapping support with the data distribution, as visualized in Figure 26.5a. This is in contrast with difference-based methods such as IPMs (for example, the Wasserstein distance and the MMD), which have smoothness requirements built into their definition by constraining the norm of the critic (Equations (26.29) and (26.34)). We can see the effect of these constraints in Figure 26.4: both the Wasserstein distance and the MMD provide a useful signal in the case of distributions with non-overlapping support.
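The following sketch illustrates the non-overlapping support argument on a discrete grid. The data distribution puts mass on two adjacent grid points and the (hypothetical) model distribution is the same shape shifted by k points, so the supports are disjoint for every k ≥ 2. The KL divergence is then infinite and the JSD equals log 2 for every shift, whereas the Wasserstein-1 distance, computed here as the sum of absolute CDF differences, equals the shift and therefore still indicates which model is closer.

```python
import math

def kl(p, q):
    """KL(p || q) on a shared finite grid; infinite if p puts mass where q has none."""
    total = 0.0
    for pi, qi in zip(p, q):
        if pi > 0:
            if qi == 0:
                return math.inf
            total += pi * math.log(pi / qi)
    return total

def jsd(p, q):
    """Jensen-Shannon divergence via the mixture m = (p + q) / 2."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def wasserstein_1d_grid(p, q):
    """W1 on the integer grid 0, 1, ..., n - 1: the sum of absolute CDF differences."""
    cdf_p = cdf_q = total = 0.0
    for pi, qi in zip(p, q):
        cdf_p += pi
        cdf_q += qi
        total += abs(cdf_p - cdf_q)
    return total

def two_point(k, n=10):
    """Hypothetical distribution with mass 1/2 on grid points k and k + 1."""
    return [0.5 if i in (k, k + 1) else 0.0 for i in range(n)]

p_star = two_point(0)
results = {}
for shift in (2, 4, 6):
    q_theta = two_point(shift)  # support disjoint from that of p_star
    results[shift] = (kl(p_star, q_theta), jsd(p_star, q_theta),
                      wasserstein_1d_grid(p_star, q_theta))
    print(shift, results[shift])
```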

While the definition of f-divergences relies on density ratios (Equation (26.21)), we have seen that to train implicit generative models we use approximations to those divergences obtained using a parametric critic D_φ. If the function family of the critic used to approximate the divergence (via the bound or class probability estimation) contains only smooth functions, it will not be able to model the sharp true density ratio, which jumps from 0 to ∞, but will instead provide a smooth approximation. We show an example in Figure 26.5b, where we show the density ratio for two distributions without overlapping support and an approximation provided by an MLP trained to approximate the KL divergence using Equation 26.25. Here, the smooth decision surface provided by the MLP can be used to train a generative model while the underlying KL divergence cannot be; the learned MLP provides the gradient signal on how to move distribution mass to areas with more density under the data distribution, while the KL divergence provides a zero gradient almost everywhere in the space. This ability of approximations to f-divergences to overcome non-overlapping support issues is a desirable property of generative modeling training criteria, as it allows models to learn the data distribution regardless of initialization [Fed+18]. Thus, while the case of non-overlapping support provides an important theoretical difference between IPMs and f-divergences, it is less significant in practice, since bounds on f-divergences or class probability estimation are used with smooth critics to approximate the underlying divergence.

Some density ratio and density difference based approaches also share commonalities: bounds are used both for f-divergences (variational bounds in Equation 26.25) and for the Wasserstein distance (Equation (26.31)). These bounds on distributional divergences and distances have their own set of challenges: since the generator minimizes a lower bound on the underlying divergence or distance, minimizing this objective provides no guarantee that the divergence will decrease during training. To see this, we can look at Equation 26.26: its RHS can get arbitrarily low without decreasing the LHS, the divergence we are interested in minimizing. This is unlike the variational upper bound on the KL divergence used to train variational autoencoders (Chapter 21).


26.3 Generative Adversarial Networks


We have looked at different learning principles that do not require the use of explicit likelihoods, and thus can be used to train implicit models. These learning principles specify training criteria, but do not tell us how to train models or how to parametrize them. To answer these questions, we now look at algorithms for training implicit models, where the models (both the discriminator and the generator) are deep neural networks; this leads us to Generative Adversarial Networks (GANs). We cover how to turn learning principles into loss functions for training GANs (Section 26.3.1); how to train models using gradient descent (Section 26.3.2); how to improve GAN optimization (Section 26.3.4); and how to assess GAN convergence (Section 26.3.5).


26.3.1 From learning principles to loss functions


In Section 26.2 we discussed learning principles for implicit generative models: class probability estimation, bounds on f-divergences, Integral Probability Metrics and moment matching. These principles can be used to formulate loss functions to train the model parameters θ and the critic parameters φ. Many of these objectives use zero-sum losses via a min-max formulation: the generator's goal is to minimize the same function the discriminator is maximizing. We can formalize this as:

$$\min_\theta \max_\phi V(\phi, \theta) \tag{26.41}$$


As an example, we recover the original GAN with the Bernoulli log-loss (Equation (26.19)) when 


$$V(\phi, \theta) = \frac{1}{2}\mathbb{E}_{p^*(x)}\left[\log D_\phi(x)\right] + \frac{1}{2}\mathbb{E}_{q_\theta(x)}\left[\log(1 - D_\phi(x))\right] \tag{26.42}$$

Most of the learning principles we have discussed lead to zero-sum losses because of their underlying structure: the critic maximizes a quantity in order to approximate a divergence or distance (such as an f-divergence or Integral Probability Metric), and the model minimizes this approximation to the divergence or distance. That need not be the case, however. Intuitively, the discriminator training criterion needs to ensure that the discriminator can distinguish between data and model samples, while the generator loss function needs to ensure that model samples are indistinguishable from data according to the discriminator.

To construct a GAN that is not zero-sum, consider the zero-sum criterion in the original GAN (Equation 26.42), induced by the Bernoulli scoring rule. The discriminator tries to distinguish between data and model samples by classifying the data as real (label 1) and samples as fake (label 0), while the goal of the generator is to minimize the probability that the discriminator classifies its samples as fake: min_θ E_{q_θ(x)} log(1 − D_φ(x)). An equally intuitive goal for the generator is to maximize the probability that the discriminator classifies its samples as real. While the difference might seem subtle, this loss, known as the “non-saturating loss” [Goo+14] and defined as E_{q_θ(x)} −log D_φ(x), enjoys better gradient properties early in training, as shown in Figure 26.6: the non-saturating loss provides a stronger learning signal (via the gradient) when the generator is performing poorly and the discriminator can easily distinguish its samples from data, i.e., when D(G(z)) is low. More on the gradient properties of the saturating and non-saturating losses can be found in [AB17; Fed+18].
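The gradient behaviour described above can be reproduced with a few lines of calculus. Writing the discriminator output on a generated sample as D(G(z)) = σ(a) for a logit a (this assumes a sigmoid output layer), the saturating loss log(1 − σ(a)) has derivative −σ(a) with respect to a, while the non-saturating loss −log σ(a) has derivative −(1 − σ(a)). The sketch below evaluates both at a few illustrative logits; when the discriminator confidently rejects the sample (D(G(z)) close to 0), only the non-saturating loss still provides a gradient of substantial size.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def saturating_grad(a):
    """d/da of the saturating generator loss log(1 - sigmoid(a)), which equals -sigmoid(a)."""
    return -sigmoid(a)

def non_saturating_grad(a):
    """d/da of the non-saturating generator loss -log(sigmoid(a)), i.e. -(1 - sigmoid(a))."""
    return -(1.0 - sigmoid(a))

# a is a discriminator logit on a generated sample, so D(G(z)) = sigmoid(a).
for a in (-6.0, -2.0, 0.0, 2.0):
    print(f"D(G(z)) = {sigmoid(a):.4f}  saturating: {saturating_grad(a):+.4f}  "
          f"non-saturating: {non_saturating_grad(a):+.4f}")
```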

There exist many other GAN losses which are not zero-sum, including formulations of LS- 
GAN [Mao+17], GANs trained using the hinge loss [LY17] and RelativisticGANs [JM18]. We 
can thus generally write a GAN formulation as follows: 


$$\min_\phi L_D(\phi, \theta); \qquad \min_\theta L_G(\phi, \theta) \tag{26.43}$$


We recover the zero-sum formulations if −L_D(φ, θ) = L_G(φ, θ) = V(φ, θ). Despite departing from the zero-sum structure, the general formulation keeps the nested form of the optimization, as we will discuss in Section 26.3.2.

The loss functions for the discriminator and generator, L_D and L_G respectively, follow the general form in Equation 26.5, which allows them to be used to efficiently train implicit generative models. The majority of loss functions considered here can thus be written as follows:


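$$L_D(\phi, \theta) = \mathbb{E}_{p^*(x)}\, g(D_\phi(x)) + \mathbb{E}_{q_\theta(x)}\, h(D_\phi(x)) = \mathbb{E}_{p^*(x)}\, g(D_\phi(x)) + \mathbb{E}_{q(z)}\, h(D_\phi(G_\theta(z))) \tag{26.44}$$
$$L_G(\phi, \theta) = \mathbb{E}_{q_\theta(x)}\, l(D_\phi(x)) = \mathbb{E}_{q(z)}\, l(D_\phi(G_\theta(z))) \tag{26.45}$$

where g, h, l: ℝ → ℝ. We recover the original GAN for g(t) = −log t, h(t) = −log(1 − t) and l(t) = log(1 − t); the non-saturating loss for g(t) = −log t, h(t) = −log(1 − t) and l(t) = −log t; the Wasserstein distance formulation for g(t) = t, h(t) = −t and l(t) = t (this flips the sign of the critic relative to Equation (26.33), which leaves the game unchanged because the set of 1-Lipschitz functions is closed under negation); and the f-divergence formulation of Equation (26.27) for g(t) = −t, h(t) = f†(t) and l(t) = −f†(t).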
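The (g, h, l) parametrization lends itself to a small table-driven sketch. Below, L_D and L_G from Equations (26.44) and (26.45) are estimated from hypothetical discriminator outputs on a few data and model samples, for the original, non-saturating and Wasserstein choices listed above (the Wasserstein entry simply reuses the same numbers as unconstrained critic values). For the original and Wasserstein losses, −L_D and L_G differ only by a term computed from data samples alone, which does not depend on the generator; this is what makes them zero-sum games, and the non-saturating choice breaks the relation.

```python
import math

def mean(values):
    return sum(values) / len(values)

def gan_losses(d_real, d_fake, g, h, l):
    """Monte Carlo estimates of L_D (Equation (26.44)) and L_G (Equation (26.45))
    from discriminator outputs on data samples (d_real) and model samples (d_fake)."""
    loss_d = mean([g(t) for t in d_real]) + mean([h(t) for t in d_fake])
    loss_g = mean([l(t) for t in d_fake])
    return loss_d, loss_g

# (g, h, l) triples as listed after Equation (26.45).
LOSSES = {
    "original": (lambda t: -math.log(t), lambda t: -math.log(1 - t), lambda t: math.log(1 - t)),
    "non-saturating": (lambda t: -math.log(t), lambda t: -math.log(1 - t), lambda t: -math.log(t)),
    "wasserstein": (lambda t: t, lambda t: -t, lambda t: t),
}

# Hypothetical discriminator outputs (probabilities; reused as raw critic values for Wasserstein).
d_real = [0.9, 0.8, 0.7]
d_fake = [0.2, 0.1, 0.3]

losses = {name: gan_losses(d_real, d_fake, *fns) for name, fns in LOSSES.items()}
for name, (loss_d, loss_g) in losses.items():
    print(f"{name:15s} L_D = {loss_d:+.4f}  L_G = {loss_g:+.4f}")
```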
Figure 26.6: Saturating log(1 − D(G(z))) vs non-saturating −log D(G(z)) loss functions. The non-saturating loss provides stronger gradients when the discriminator is easily detecting that generated samples are fake. Generated by gan_loss_types.ipynb


26.3.2 Gradient Descent 


GANs employ the learning principles discussed above in conjunction with gradient-based learning for the parameters of the discriminator and generator. We assume a general formulation with a discriminator loss function L_D(φ, θ) and a generator loss function L_G(φ, θ). Since the discriminator is often introduced to approximate a distance or divergence D(p*, q_θ) (Section 26.2), for the generator to minimize a good approximation of that divergence one should solve the discriminator optimization fully for each generator update. That would entail that for each generator update one would first find the optimal discriminator parameters φ* = argmin_φ L_D(φ, θ) in order to perform a gradient update given by ∇_θ L_G(φ*, θ). Fully solving the inner optimization problem φ* = argmin_φ L_D(φ, θ) for each optimization step of the generator is computationally prohibitive, which motivates the use of alternating updates: performing a few gradient steps to update the discriminator parameters, followed by a generator update. Note that when updating the discriminator, we keep the generator parameters fixed, and when updating the generator, we keep the discriminator parameters fixed. We show a general algorithm for these alternating updates in Algorithm 40.


Algorithm 40: General GAN training algorithm with alternating updates

1 Initialize φ, θ
2 for each training iteration do
3     for K steps do
4         Update the discriminator parameters φ using the gradient ∇_φ L_D(φ, θ);
5     Update the generator parameters θ using the gradient ∇_θ L_G(φ, θ)
6 Return φ, θ


We are thus interested in computing ∇_φ L_D(φ, θ) and ∇_θ L_G(φ, θ). Given that the loss functions for both the discriminator and the generator follow the general form in Equations 26.44 and 26.45, we can compute the gradients that can be used for training. To compute the discriminator gradients, we write:


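$$\nabla_\phi L_D(\phi, \theta) = \nabla_\phi\left[\mathbb{E}_{p^*(x)}\, g(D_\phi(x)) + \mathbb{E}_{q_\theta(x)}\, h(D_\phi(x))\right] \tag{26.46}$$
$$= \mathbb{E}_{p^*(x)} \nabla_\phi g(D_\phi(x)) + \mathbb{E}_{q_\theta(x)} \nabla_\phi h(D_\phi(x)) \tag{26.47}$$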


where ∇_φ g(D_φ(x)) and ∇_φ h(D_φ(x)) can be computed via backpropagation, and each expectation can be estimated using Monte Carlo estimation. For the generator, we would like to compute the gradient:


$$\nabla_\theta L_G(\phi, \theta) = \nabla_\theta \mathbb{E}_{q_\theta(x)}\, l(D_\phi(x)) \tag{26.48}$$


Here we cannot change the order of differentiation and integration, since the distribution under the integral depends on the differentiation parameter θ. Instead, we use the fact that q_θ(x) is the distribution induced by an implicit generative model, so the expectation can be rewritten as an expectation over the prior q(z) (also known as the “reparameterization trick”, see Section 6.5.4):


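$$\nabla_\theta L_G(\phi, \theta) = \nabla_\theta \mathbb{E}_{q_\theta(x)}\, l(D_\phi(x)) = \nabla_\theta \mathbb{E}_{q(z)}\, l(D_\phi(G_\theta(z))) = \mathbb{E}_{q(z)} \nabla_\theta l(D_\phi(G_\theta(z))) \tag{26.49}$$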


and again use Monte Carlo estimation to approximate the gradient using samples from the prior q(z). Substituting the choice of loss functions and Monte Carlo estimation into Algorithm 40 leads to Algorithm 41, which is often used to train GANs. See https://github.com/probml/pyprobml/tree/master/gan for some sample (PyTorch) code which trains various kinds of (convolutional) GANs on the CelebA face dataset.
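The pathwise gradient in Equation (26.49) can be checked on a toy problem where the expectation has a closed form. The sketch below makes hypothetical choices G_θ(z) = θ + z with z ∼ N(0, 1), D_φ(x) = φx² and l(t) = t, so that E_{q(z)} l(D_φ(G_θ(z))) = φ(θ² + 1); the Monte Carlo average of ∇_θ l(D_φ(G_θ(z))) over prior samples should therefore approach 2φθ.

```python
import random

rng = random.Random(0)
phi, theta = 0.7, 1.5   # hypothetical critic and generator parameters
num_samples = 200_000   # Monte Carlo samples from the prior q(z) = N(0, 1)

# Hypothetical toy choices: G_theta(z) = theta + z, D_phi(x) = phi * x**2 and l(t) = t,
# so l(D_phi(G_theta(z))) = phi * (theta + z)**2.
def pathwise_grad(z):
    """d/dtheta of phi * (theta + z)**2 for a fixed noise sample z."""
    return 2.0 * phi * (theta + z)

zs = [rng.gauss(0.0, 1.0) for _ in range(num_samples)]
grad_estimate = sum(pathwise_grad(z) for z in zs) / num_samples

# Closed form for comparison: E_z[phi * (theta + z)**2] = phi * (theta**2 + 1),
# whose derivative with respect to theta is 2 * phi * theta.
grad_exact = 2.0 * phi * theta
print(f"pathwise estimate = {grad_estimate:.4f}, exact = {grad_exact:.4f}")
```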


Algorithm 41: GAN training algorithm

1 Initialize φ, θ
2 for each training iteration do
3     for K steps do
4         Sample minibatch of M noise vectors z_m ∼ q(z)
5         Sample minibatch of M examples x_m ∼ p*(x)
6         Update the discriminator by performing stochastic gradient descent using this gradient:
              ∇_φ (1/M) Σ_{m=1}^{M} [g(D_φ(x_m)) + h(D_φ(G_θ(z_m)))].
7     Sample minibatch of M noise vectors z_m ∼ q(z)
8     Update the generator by performing stochastic gradient descent using this gradient:
          ∇_θ (1/M) Σ_{m=1}^{M} l(D_φ(G_θ(z_m))).
9 Return φ, θ
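The following is a minimal sketch of Algorithm 41 on a one-dimensional toy problem, written in plain Python with hand-derived gradients rather than a deep learning framework. All modelling choices are assumptions for illustration: the data are N(μ, 1) with μ = 2, the generator is a shift G_θ(z) = θ + z of standard normal noise, the discriminator is logistic regression D_φ(x) = σ(wx + b) with φ = (w, b), and the losses are g(t) = −log t, h(t) = −log(1 − t) with the non-saturating l(t) = −log t. Because the log density ratio of two unit-variance Gaussians is linear in x, this discriminator can represent the optimal one, and in this easy setting the generator parameter settles near the data mean; harder problems are far less forgiving, as Section 26.3.3 discusses.

```python
import math
import random

def sigmoid(a):
    """Numerically stable logistic function."""
    if a >= 0:
        return 1.0 / (1.0 + math.exp(-a))
    e = math.exp(a)
    return e / (1.0 + e)

def train_toy_gan(mu=2.0, theta=-2.0, iters=2000, k_steps=5, batch=64,
                  lr_d=0.1, lr_g=0.05, seed=0):
    """Algorithm 41 on a hypothetical 1D problem (see the lead-in for all assumptions)."""
    rng = random.Random(seed)
    w, b = 0.0, 0.0
    for _ in range(iters):
        for _ in range(k_steps):
            xs = [rng.gauss(mu, 1.0) for _ in range(batch)]                # x_m ~ p*(x)
            fakes = [theta + rng.gauss(0.0, 1.0) for _ in range(batch)]    # G_theta(z_m)
            # With logit a = w x + b: d/da[-log sigmoid(a)] = -(1 - D) and
            # d/da[-log(1 - sigmoid(a))] = D; the chain rule adds a factor x for w.
            grad_w = grad_b = 0.0
            for x in xs:
                d = sigmoid(w * x + b)
                grad_w -= (1.0 - d) * x
                grad_b -= 1.0 - d
            for x in fakes:
                d = sigmoid(w * x + b)
                grad_w += d * x
                grad_b += d
            w -= lr_d * grad_w / batch
            b -= lr_d * grad_b / batch
        # Generator step: d/dtheta[-log sigmoid(w (theta + z) + b)] = -(1 - D) * w.
        grad_theta = 0.0
        for _ in range(batch):
            d = sigmoid(w * (theta + rng.gauss(0.0, 1.0)) + b)
            grad_theta -= (1.0 - d) * w
        theta -= lr_g * grad_theta / batch
    return theta, w, b

theta, w, b = train_toy_gan()
print(f"theta = {theta:.3f} (data mean 2.0), w = {w:.3f}, b = {b:.3f}")
```

Raising k_steps brings the discriminator closer to optimal before each generator update, at a proportional increase in cost.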


26.3.3 Challenges with GAN training 


Due to the adversarial game nature of GANs, their optimization dynamics are both hard to study in theory and hard to stabilize in practice. GANs are known to suffer from mode collapse, a phenomenon where the generator converges to a distribution which does not cover all the modes (peaks) of the data distribution, so that the model underfits the distribution. We show an example in Figure 26.7: while the data is a mixture of Gaussians with 16 modes, the model converges only to a few modes. Alternatively, another problematic behavior is mode hopping, where the generator “hops” between generating different modes of the data distribution. An intuitive explanation for this behavior is as follows: if the generator becomes good at generating data from one mode, it will generate more from that mode. If the discriminator cannot learn to distinguish between real and generated data in this mode, the generator has no incentive to expand its support and generate data from other modes. On the other hand, if the discriminator eventually learns to distinguish between the real and generated data inside this mode, the generator can simply move (hop) to a new mode, and this game of cat and mouse can continue.

Figure 26.7: Illustration of mode collapse and mode hopping in GAN training. (a) The dataset, a mixture of 16 Gaussians in 2 dimensions. (b-f) Samples from the model after various amounts of training. Generated by gan_mixture_of_gaussians.ipynb.


While mode collapse and mode hopping are often associated with GANs, many improvements have made GAN training more stable and these behaviors rarer. These improvements include using large batch sizes, increasing the discriminator's neural capacity, using discriminator and generator regularization, as well as more complex optimization methods.


26.3.4 Improving GAN optimization


Hyperparameter choices such as the choice of momentum can be crucial when training GANs, with lower momentum values being preferred compared to the usual high momentum used in supervised learning. Algorithms such as Adam [KB14a] provide a great boost in performance [RMC16a]. Many other optimization methods have been successfully applied to GANs, such as those which target variance reduction [Cha+19b]; those which backpropagate through gradient steps, thus ensuring that the generator does well against the discriminator after it has been updated [Met+16]; or those using a local bilinear approximation of the two-player game [SA19]. While promising, these advanced optimization methods tend to have a higher computational cost, making them harder to scale to large models or large datasets than simpler optimization methods.


26.3.5 Convergence of GAN training 


The challenges with GAN optimization make it hard to quantify when convergence has occurred. In 
Section 26.2 we saw how global convergence guarantees can be provided under optimality conditions for 
multiple objectives constructed starting with different distributional divergences and distances: if the 
discriminator is optimal, the generator is minimising a distributional divergence or distance between 
the data and model distribution, and thus under infinite capacity and perfect optimization can learn 
the data distribution. This type of argument has been used since the original GAN paper [Goo+14] 
to connect GANs to standard objectives in generative models, and obtain the associated theoretical 
guarantees. From a game theory perspective, this type of convergence guarantee provides an existence 
proof of a global Nash equilibrium for the GAN game, though under strong assumptions. A Nash 
equilibrium is achieved when both players (the discriminator and generator) would incur a loss if 
they decide to act by changing their parameters. Consider the original GAN defined by the objective 


in Equation 26.19; then gg = p* and Dg(a) = es =i 
*(æ) 


for a given qọ, the ratio OAO the optimal discriminator (Equation 26.11), and given an 
optimal discriminator, the data distribution is the optimal generator as it is the minimizer of the 
Jensen-Shannon divergence (Equation 26.15). 

While these global theoretical guarantees provide useful insights about the GAN game, they do 
not account for optimization challenges that arise with accounting for the optimization trajectories 
of the two players, or for neural network parametrization since they assume infinite capacity both for 
the discriminator and generator. In practice GANs do not decrease a distance or divergence at every 
optimization step [Fed+18] and global guarantees are difficult to obtain when using optimization 
methods such as gradient descent. Instead, the focus shifts towards local convergence guarantees, 
such as reaching a local Nash equilibrium. A local Nash equilibrium requires that both players are at 
a local, not global minimum: a local Nash equilibrium is a stationary point (the gradients of the 
two loss functions are zero, i.e VgLp(¢, 0) = 0 and VeLa(¢, 8) = 0), and the eigenvalues of the 
Hessian of each player (V¢V¢Llp(¢, 9) and VeVeLa(¢, @)) are non-negative; for a longer discussion 
on Nash equilibria in continuous games see [RBS16]. For the general GAN game, it is not guaranteed 
that a local Nash equilibrium always exists [FO20], and weaker conditions such as stationarity or 
locally stable stationarity have been studied [Ber+19]; other equilibrium definitions inspired by game 
theory have also been used [JNJ20; HLC19]. 

To motivate why convergence analysis is important in the case of GANs, we visualize an example of a GAN trained with gradient descent that does not converge. In DiracGAN [MGN18a], the data distribution $p^*(x)$ is the Dirac delta distribution with mass at zero. The generator models a Dirac delta distribution with parameter $\theta$, $G_\theta(z) = \theta$, and the discriminator is a linear function of the input with learned parameter $\phi$, $D_\phi(x) = \phi x$. We also assume a GAN formulation where $g = h = -l$ in the general loss functions $\mathcal{L}_D$ and $\mathcal{L}_G$ defined above (see Equations (26.44) and (26.45)). This results in the zero-sum game given by:

$$\mathcal{L}_D = \mathbb{E}_{p^*(x)}[-l(D_\phi(x))] + \mathbb{E}_{q_\theta(x)}[-l(D_\phi(x))] = -l(0) - l(\theta\phi) \qquad (26.50)$$
$$\mathcal{L}_G = \mathbb{E}_{p^*(x)}[l(D_\phi(x))] + \mathbb{E}_{q_\theta(x)}[l(D_\phi(x))] = l(0) + l(\theta\phi) \qquad (26.51)$$

Here $l$ depends on the GAN formulation used (for instance, $l(z) = -\log(1 + e^{-z})$). The unique equilibrium point is $\theta = \phi = 0$. We visualize the DiracGAN problem in Figure 26.8. The figure shows that DiracGANs trained with alternating gradient descent (Algorithm 40) do not reach the equilibrium point, but instead take a circular trajectory around it.

(a) The DiracGAN problem. (b) Alternating gradient descent on DiracGAN.
Figure 26.8: Visualizing divergence using a simple GAN: DiracGAN. Generated by dirac_gan.ipynb
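The DiracGAN losses are simple enough to write in closed form. The plain-Python sketch below assumes the example choice $l(t) = -\log(1 + e^{-t})$, for which $l'(t) = 1/(1 + e^{t})$. It implements Equations 26.50 and 26.51 together with each player's gradient, $\nabla_\phi \mathcal{L}_D = -\theta\, l'(\theta\phi)$ and $\nabla_\theta \mathcal{L}_G = \phi\, l'(\theta\phi)$. The function names are hypothetical.

```python
import math

def l(t):
    # Example choice from the text: l(t) = -log(1 + exp(-t))
    return -math.log1p(math.exp(-t))

def l_prime(t):
    # l'(t) = exp(-t) / (1 + exp(-t)) = 1 / (1 + exp(t))
    return 1.0 / (1.0 + math.exp(t))

def loss_d(phi, theta):
    # Equation 26.50: real data sits at x = 0 (so D_phi(0) = 0),
    # generated data sits at x = theta (so D_phi(theta) = phi * theta).
    return -l(0.0) - l(theta * phi)

def loss_g(phi, theta):
    # Equation 26.51
    return l(0.0) + l(theta * phi)

def grad_phi_loss_d(phi, theta):
    # Derivative of Equation 26.50 with respect to the discriminator parameter
    return -theta * l_prime(theta * phi)

def grad_theta_loss_g(phi, theta):
    # Derivative of Equation 26.51 with respect to the generator parameter
    return phi * l_prime(theta * phi)
```

For any parameters, `loss_d(phi, theta) + loss_g(phi, theta)` is zero, since the game is zero-sum. Both gradients vanish at $\theta = \phi = 0$.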
There are two main theoretical approaches to understanding GAN convergence behavior around an equilibrium. One analyzes the discrete dynamics of gradient descent; the other analyzes the underlying continuous dynamics of the game, using approaches such as stability analysis. To understand the difference between the two approaches, consider the discrete dynamics defined by gradient descent with learning rates $\alpha h$ and $\lambda h$. These can use alternating updates (as we have seen in Algorithm 40):

$$\phi_t = \phi_{t-1} - \alpha h \nabla_\phi \mathcal{L}_D(\phi_{t-1}, \theta_{t-1}), \qquad (26.52)$$
$$\theta_t = \theta_{t-1} - \lambda h \nabla_\theta \mathcal{L}_G(\phi_t, \theta_{t-1}) \qquad (26.53)$$

Alternatively, they can use simultaneous updates, where both players are updated at once instead of alternating the gradient updates between them:

$$\phi_t = \phi_{t-1} - \alpha h \nabla_\phi \mathcal{L}_D(\phi_{t-1}, \theta_{t-1}), \qquad (26.54)$$
$$\theta_t = \theta_{t-1} - \lambda h \nabla_\theta \mathcal{L}_G(\phi_{t-1}, \theta_{t-1}) \qquad (26.55)$$

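The above gradient descent dynamics are obtained by Euler numerical integration, with step size $h$, of the ODEs that describe the game dynamics of the two players. Simultaneous updates correspond exactly to explicit Euler steps when $\alpha = \lambda = 1$. Alternating updates are a variant in which the second player uses the first player's already-updated parameters. The ODEs are:

$$\dot{\phi} = -\nabla_\phi \mathcal{L}_D(\phi, \theta), \qquad (26.56)$$
$$\dot{\theta} = -\nabla_\theta \mathcal{L}_G(\phi, \theta) \qquad (26.57)$$

[Figure 26.9 panels: Continuous dynamics (left) and Simultaneous gradient descent (right); each panel marks the equilibrium (0, 0).]
Figure 26.9: Continuous (left) and discrete dynamics (right) take different trajectories in DiracGAN. Generated by dirac_gan.ipynb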
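To make the link between the updates and the ODEs concrete, the plain-Python sketch below integrates the DiracGAN version of Equations 26.56 and 26.57 with explicit Euler steps of size $h$. It assumes $l(t) = -\log(1 + e^{-t})$, which gives $\dot\phi = \theta\, l'(\theta\phi)$ and $\dot\theta = -\phi\, l'(\theta\phi)$. Each Euler step is exactly one simultaneous gradient descent step with learning rate $h$. The function names are hypothetical.

```python
import math

def l_prime(t):
    # l'(t) for l(t) = -log(1 + exp(-t))
    return 1.0 / (1.0 + math.exp(t))

def vector_field(phi, theta):
    # DiracGAN form of Equations 26.56-26.57:
    # phi_dot = -grad_phi L_D = theta * l'(theta * phi)
    # theta_dot = -grad_theta L_G = -phi * l'(theta * phi)
    s = l_prime(theta * phi)
    return theta * s, -phi * s

def euler_integrate(phi, theta, h, total_time):
    # Each explicit Euler step of size h is one simultaneous gradient descent
    # step (Equations 26.54-26.55) with learning rate h and alpha = lambda = 1.
    for _ in range(round(total_time / h)):
        d_phi, d_theta = vector_field(phi, theta)
        phi, theta = phi + h * d_phi, theta + h * d_theta
    return phi, theta
```

Consider `euler_integrate(1.0, 0.0, h, 10.0)`. With a small step such as $h = 0.001$, it keeps the radius within about half a percent of 1. With $h = 0.5$, the trajectory visibly spirals outwards. The smaller the step, the more closely the discrete trajectory follows the continuous one.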


One approach to understanding the behavior of GANs is to study these underlying ODEs, which result in the gradient descent updates above when discretized, rather than to study the discrete updates directly. These ODEs can be used for stability analysis to study the behavior around an equilibrium. This entails finding the eigenvalues of the Jacobian of the game

$$J = \begin{bmatrix} -\nabla_\phi\nabla_\phi \mathcal{L}_D(\phi, \theta) & -\nabla_\theta\nabla_\phi \mathcal{L}_D(\phi, \theta) \\ -\nabla_\phi\nabla_\theta \mathcal{L}_G(\phi, \theta) & -\nabla_\theta\nabla_\theta \mathcal{L}_G(\phi, \theta) \end{bmatrix}$$

evaluated at a stationary point (i.e., where $\nabla_\phi \mathcal{L}_D(\phi, \theta) = 0$ and $\nabla_\theta \mathcal{L}_G(\phi, \theta) = 0$). If the eigenvalues of the Jacobian all have negative real parts, then the system is asymptotically stable around the equilibrium. If at least one eigenvalue has a positive real part, the system is unstable around the equilibrium. For the DiracGAN, the Jacobian evaluated at the equilibrium $\theta = \phi = 0$ is:

$$J = \begin{bmatrix} \nabla_\phi\nabla_\phi \big(l(\theta\phi) + l(0)\big) & \nabla_\theta\nabla_\phi \big(l(\theta\phi) + l(0)\big) \\ -\nabla_\phi\nabla_\theta \big(l(\theta\phi) + l(0)\big) & -\nabla_\theta\nabla_\theta \big(l(\theta\phi) + l(0)\big) \end{bmatrix} \qquad (26.58)$$
$$= \begin{bmatrix} 0 & l'(0) \\ -l'(0) & 0 \end{bmatrix} \qquad (26.59)$$

The eigenvalues of this Jacobian are $\lambda_\pm = \pm i\, l'(0)$. This is interesting, as the real parts of both eigenvalues are 0. The result tells us that there is no linear convergence to the equilibrium, although sublinear convergence could still occur. In this simple case, we can conclude that convergence does not occur, because the system has a conserved quantity: $\theta^2 + \phi^2$ does not change in time (Figure 26.9, left):

$$\frac{d(\theta^2 + \phi^2)}{dt} = 2\theta\frac{d\theta}{dt} + 2\phi\frac{d\phi}{dt} = -2\theta\phi\, l'(\theta\phi) + 2\phi\theta\, l'(\theta\phi) = 0.$$
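The Jacobian in Equation 26.59 can also be checked numerically. The NumPy sketch below evaluates the DiracGAN vector field at the equilibrium and computes a finite-difference Jacobian, whose step size is an arbitrary choice. It then computes the eigenvalues and checks that the vector field is orthogonal to $(\phi, \theta)$, which is exactly the statement that $\theta^2 + \phi^2$ is conserved. The function names are illustrative.

```python
import numpy as np

def l_prime(t):
    # l'(t) for l(t) = -log(1 + exp(-t))
    return 1.0 / (1.0 + np.exp(t))

def vector_field(x):
    # x = (phi, theta); returns (phi_dot, theta_dot) from Equations 26.56-26.57
    phi, theta = x
    s = l_prime(theta * phi)
    return np.array([theta * s, -phi * s])

def numerical_jacobian(f, x, eps=1e-6):
    # J[i, j] = d f_i / d x_j, estimated by central differences
    x = np.asarray(x, dtype=float)
    columns = []
    for j in range(x.size):
        e = np.zeros_like(x)
        e[j] = eps
        columns.append((f(x + e) - f(x - e)) / (2 * eps))
    return np.stack(columns, axis=1)

def radial_component(x):
    # (1/2) d/dt (phi^2 + theta^2) = phi * phi_dot + theta * theta_dot
    return float(np.dot(x, vector_field(x)))

J = numerical_jacobian(vector_field, [0.0, 0.0])
eigenvalues = np.linalg.eigvals(J)
```

Since $l'(0) = 1/2$, the eigenvalues come out as $\pm 0.5i$, matching $\lambda_\pm = \pm i\, l'(0)$.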
Stability analysis of the underlying continuous dynamics of GANs around an equilibrium has been used to show that explicit regularization can help convergence [NK17; Bal+18]. Alternatively, one can directly study the updates of simultaneous gradient descent shown in Equations 26.54 and 26.55. Under certain conditions, [MNG17b] prove that GANs trained with simultaneous gradient descent reach a local Nash equilibrium. Their approach relies on assessing the convergence of series of the form $F^k(x)$, which result from repeatedly applying a gradient descent update of the form $F(x) = x + hG(x)$, where $h$ is the learning rate. Since the function $F$ depends on the learning rate $h$, their convergence results depend on the size of the learning rate, which is not the case for continuous time approaches.
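The learning-rate dependence of the discrete analysis can be seen on DiracGAN itself. For simultaneous gradient descent with $\alpha = \lambda = 1$, the update is $F(x) = x + hG(x)$, with $G$ the vector field of Equations 26.56 and 26.57. Near the equilibrium, $F$ therefore behaves like the linear map $I + hJ$, with $J$ from Equation 26.59. A spectral radius of $I + hJ$ below 1 guarantees local convergence of the iterates $F^k(x)$, while a radius above 1 means they move away. The NumPy sketch below computes this radius for a few learning rates, assuming $l'(0) = 1/2$.

```python
import numpy as np

L_PRIME_AT_0 = 0.5  # l'(0) for l(t) = -log(1 + exp(-t))
J = np.array([[0.0, L_PRIME_AT_0],
              [-L_PRIME_AT_0, 0.0]])  # DiracGAN Jacobian, Equation 26.59

def update_spectral_radius(h):
    # The linearization of F(x) = x + h G(x) at the equilibrium is I + h J.
    # A radius below 1 guarantees local convergence of F^k(x); above 1,
    # the iterates move away from the equilibrium.
    return float(np.max(np.abs(np.linalg.eigvals(np.eye(2) + h * J))))
```

Here the radius equals $\sqrt{1 + h^2/4}$, which exceeds 1 for every $h > 0$. Simultaneous gradient descent therefore moves away from the DiracGAN equilibrium whatever the learning rate. For games whose Jacobian eigenvalues have negative real parts, the same computation gives a radius below 1 for small $h$ and above 1 for large $h$. This is why discrete-time results depend on the learning rate.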

Both continuous and discrete approaches have been useful in understanding and improving GAN training. However, both still leave a gap between our theoretical understanding and the algorithms most commonly used to train GANs in practice, such as alternating gradient descent or more complex optimizers like Adam. Far from only providing different proof techniques, these approaches can reach different conclusions about the convergence of a GAN. Figure 26.9 shows an example: simultaneous gradient descent and the continuous dynamics behave differently when a large enough learning rate is used. In this case, the discretization error makes the analysis of gradient descent using continuous dynamics reach the wrong conclusion about DiracGAN [Ros+21]. This error is the difference between the behavior of the continuous dynamics in Equations 26.56 and 26.57 and that of the gradient descent dynamics in Equations 26.54 and 26.55. This difference in behavior has motivated training GANs with higher order numerical integrators such as Runge-Kutta 4 (RK4), which follow the underlying continuous system more closely than gradient descent [Qin+20].
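To illustrate the discretization gap with a higher order integrator, the plain-Python sketch below compares two integrators on the DiracGAN vector field. One is an explicit Euler step, which coincides with simultaneous gradient descent for $\alpha = \lambda = 1$. The other is the classical fourth-order Runge-Kutta step. Both use the same fairly large step size $h = 0.5$. This is only an illustration of the integrators, not the training procedure of [Qin+20]; names and step sizes are illustrative.

```python
import math

def l_prime(t):
    # l'(t) for l(t) = -log(1 + exp(-t))
    return 1.0 / (1.0 + math.exp(t))

def field(state):
    # DiracGAN vector field (phi_dot, theta_dot), Equations 26.56-26.57
    phi, theta = state
    s = l_prime(theta * phi)
    return (theta * s, -phi * s)

def add(state, delta, scale):
    return tuple(x + scale * d for x, d in zip(state, delta))

def euler_step(state, h):
    # Same as one simultaneous gradient descent step with learning rate h
    return add(state, field(state), h)

def rk4_step(state, h):
    # Classical fourth-order Runge-Kutta step
    k1 = field(state)
    k2 = field(add(state, k1, h / 2))
    k3 = field(add(state, k2, h / 2))
    k4 = field(add(state, k3, h))
    return tuple(x + h / 6 * (a + 2 * b + 2 * c + d)
                 for x, a, b, c, d in zip(state, k1, k2, k3, k4))

def radius_after(step, h, num_steps, state=(1.0, 0.0)):
    # Distance from the equilibrium after num_steps integration steps
    for _ in range(num_steps):
        state = step(state, h)
    return math.hypot(*state)
```

Starting from $(\phi, \theta) = (1, 0)$, 20 Euler steps push the radius noticeably above 1. Twenty RK4 steps keep it within 1% of its conserved value, much closer to the continuous dynamics.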

While optimization convergence analysis is an indispensable step in understanding GAN training and has led to significant practical improvements, ensuring convergence to an equilibrium does not ensure that the model has learned a good fit of the data distribution. The loss landscape determined by the choice of $\mathcal{L}_D$ and $\mathcal{L}_G$, together with the parametrization of the discriminator and generator, can lead to equilibria that do not capture the data distribution. Game equilibria provide no distributional guarantees. This shows the need to complement convergence analysis with work on how gradient-based learning in this game setting affects the learned distribution.


26.4 Conditional GANs


We have thus far discussed how to use implicit generative models to learn a true unconditional distribution $p^*(x)$ from which we only have samples. It is often useful, however, to be able to learn conditional distributions of the form $p^*(x|y)$. This requires paired data, where each input $x_n$ is paired with a corresponding set of covariates $y_n$, such as a class label, or a set of attributes or words. The dataset is then $\mathcal{D} = \{(x_n, y_n) : n = 1 : N\}$, as in standard supervised learning. The conditioning variable can be discrete (like a class label) or continuous (such as an embedding encoding information about past experience). Conditional generative models are appealing since we can specify that we want the generated sample to be associated with conditioning information $y$. This makes them very amenable to real world applications (see Section 26.7).

To be able to learn implicit conditional distributions $q_\theta(x|y)$, we require datasets that specify the conditioning information associated with the data. We also need to adapt model architectures and loss functions


Draft of “Probabilistic Machine Learning: Advanced Topics”. August 15, 2022 




26.5. INFERENCE WITH GANS 


to learn conditional distributions. In the GAN case, the loss function of the generative model can be changed by changing the critic, since the critic is part of the loss function of the generator. The critic must provide a learning signal that accounts for the conditioning information, by penalizing a generator that produces realistic samples but ignores the provided conditioning.

If we do not change the form of the min-max game, but provide the conditioning information to the two players, a conditional GAN can be created from the original GAN game [MO14]:

$$\min_\theta \max_\phi \mathcal{L}(\theta, \phi) = \frac{1}{2}\mathbb{E}_{p(y)}\mathbb{E}_{p^*(x|y)}[\log D_\phi(x, y)] + \frac{1}{2}\mathbb{E}_{p(y)}\mathbb{E}_{q_\theta(x|y)}[\log(1 - D_\phi(x, y))] \qquad (26.60)$$


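In the case of implicit latent variable models, the conditioning information becomes an additional input to the generator, together with the latent variable $z$:

$$\min_\theta \max_\phi \mathcal{L}(\theta, \phi) = \frac{1}{2}\mathbb{E}_{p(y)}\mathbb{E}_{p^*(x|y)}[\log D_\phi(x, y)] + \frac{1}{2}\mathbb{E}_{p(y)}\mathbb{E}_{q(z)}[\log(1 - D_\phi(G_\theta(z, y), y))] \qquad (26.61)$$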
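As a minimal sketch of Equation 26.61, the NumPy code below estimates the conditional GAN objective by Monte Carlo. It uses a hypothetical linear generator and a hypothetical logistic discriminator. A discrete label $y$ is provided to both players by concatenating its one-hot encoding to the generator's latent $z$ and to the discriminator's input. The dimensions, weights and function names are all illustrative assumptions, not part of the original model.

```python
import numpy as np

rng = np.random.default_rng(0)
K, LATENT_DIM, DATA_DIM = 3, 4, 2   # toy sizes (illustrative)

def one_hot(y, num_classes=K):
    return np.eye(num_classes)[y]

def generator(z, y, W_g):
    # hypothetical linear G_theta(z, y): the label is concatenated to the latent
    return np.concatenate([z, one_hot(y)], axis=1) @ W_g

def discriminator(x, y, w_d):
    # hypothetical logistic D_phi(x, y): the label is concatenated to the input
    logits = np.concatenate([x, one_hot(y)], axis=1) @ w_d
    return 1.0 / (1.0 + np.exp(-logits))

def conditional_gan_value(x_real, y_real, z, y_fake, W_g, w_d):
    # Monte Carlo estimate of the objective in Equation 26.61, with
    # (x_real, y_real) ~ p(y) p*(x|y), z ~ q(z) and y_fake ~ p(y)
    d_real = discriminator(x_real, y_real, w_d)
    d_fake = discriminator(generator(z, y_fake, W_g), y_fake, w_d)
    return 0.5 * np.mean(np.log(d_real)) + 0.5 * np.mean(np.log(1.0 - d_fake))
```

With an all-zero discriminator weight vector, $D_\phi = 1/2$ everywhere and the estimate is exactly $\log(1/2)$. This is the value of the objective when the discriminator cannot tell real pairs from generated ones.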


For discrete conditioning information such as labels, one can also add a new loss function. This is done by training a critic that not only learns to distinguish between real and fake data, but also learns to classify both data and generated samples as belonging to one of the $K$ classes provided in the dataset [OOS17]:

$$\mathcal{L}_c(\theta, \phi) = -\frac{1}{2}\mathbb{E}_{p(y)}\mathbb{E}_{p^*(x|y)}[\log D_\phi(y|x)] - \frac{1}{2}\mathbb{E}_{p(y)}\mathbb{E}_{q_\theta(x|y)}[\log D_\phi(y|x)] \qquad (26.62)$$

We could have two critics: an unsupervised critic, and a supervised critic that minimizes the classification loss above. In practice, however, the same critic is used for both, to help shape the features used in both decision surfaces. Unlike the adversarial nature of the unsupervised game, it is in the interest of both players to minimize the classification loss $\mathcal{L}_c$. Thus, together with the adversarial dynamics provided by $\mathcal{L}$, the two players are trained as follows:

$$\min_\phi \; -\mathcal{L}(\theta, \phi) + \mathcal{L}_c(\theta, \phi), \qquad \min_\theta \; \mathcal{L}(\theta, \phi) + \mathcal{L}_c(\theta, \phi) \qquad (26.63)$$
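The classification loss of Equation 26.62 is an ordinary cross-entropy on the critic's class predictions $D_\phi(y|x)$, applied to real and generated samples alike. The minimal NumPy sketch below assumes the critic outputs one logit per class, and also forms the two objectives of Equation 26.63. The function names are hypothetical.

```python
import numpy as np

def log_softmax(logits):
    shifted = logits - logits.max(axis=1, keepdims=True)   # for numerical stability
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

def classification_loss(logits_real, y_real, logits_fake, y_fake):
    # Equation 26.62: -1/2 E[log D_phi(y|x)] on real data, plus the same on generated data
    nll_real = -log_softmax(logits_real)[np.arange(len(y_real)), y_real].mean()
    nll_fake = -log_softmax(logits_fake)[np.arange(len(y_fake)), y_fake].mean()
    return 0.5 * nll_real + 0.5 * nll_fake

def acgan_losses(adversarial_value, cls_loss):
    # Equation 26.63: the critic minimizes -L + L_c, the generator minimizes L + L_c
    critic_loss = -adversarial_value + cls_loss
    generator_loss = adversarial_value + cls_loss
    return critic_loss, generator_loss
```

A critic that predicts the uniform distribution over $K$ classes gets $\mathcal{L}_c = \log K$, whatever the labels.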


In the case of conditional latent variable models, the latent variable controls the sample variability inside the mode specified by the conditioning information. In early conditional GANs, the conditioning information was provided as additional input to the discriminator and generator. In the generator, for example, it was concatenated to the latent variable $z$. It has since been observed that it is important to provide the conditioning information at various layers of both the generator and the discriminator [DV+17; DSK16], or to use a projection discriminator [MK18].


26.5 Inference with GANs 


Unlike other latent variable models such as Variational Autoencoders, GANs do not define an inference procedure associated with the generative model. Multiple approaches have been taken to deploy the principles behind GANs to find a posterior distribution $p(z|x)$. These range from hybrid methods that combine GANs and Variational Autoencoders [MNG17a; Sri+17; Lar+16; Mak+15b] to inference methods tailored to implicit models [Dum+16; DKD16; DS19]. An overview of these methods can be found in [Hus17b].

Figure 26.10: Learning an implicit posterior using an adversarial approach, as done in BiGAN. From Figure 1 of [DKD16]. Used with kind permission of Jeff Donahue.

GAN-based methods that perform inference and learn an implicit posterior distribution $p(z|x)$ introduce changes to the GAN algorithm to do so. Examples of such methods are BiGAN (bidirectional GAN) [DKD16] and ALI (adversarially learned inference) [Dum+16]. These train an implicit parametrized encoder $E_\zeta$ to map input $x$ to latent variables $z$. To ensure consistency between the encoder $E_\zeta$ and the generator $G_\theta$, an adversarial approach is introduced, with a discriminator $D_\phi$ learning to distinguish between pairs of data and latent samples. $D_\phi$ learns to consider pairs $(x, E_\zeta(x))$ with $x \sim p^*$ as real, while pairs $(G_\theta(z), z)$ with $z \sim q(z)$ are considered fake. This approach, shown in Figure 26.10, ensures that the joint distributions are matched. Thus the marginal distribution $q_\theta(x)$ given by $G_\theta$ should learn $p^*(x)$, while the conditional distribution $p_\zeta(z|x)$ given by $E_\zeta$ should learn $q_\theta(z|x) = \frac{q_\theta(x|z)\, q(z)}{q_\theta(x)}$. This joint GAN loss can be used to train both the generator $G_\theta$ and the encoder $E_\zeta$, without requiring the reconstruction loss common in other inference methods. Even without a reconstruction loss, this objective retains the property that under global optimality conditions the encoder and decoder are inverses of each other: $E_\zeta(G_\theta(z)) = z$ and $G_\theta(E_\zeta(x)) = x$. (See also Section 21.2.7 for a discussion of how VAEs learn to ensure that $p^*(x)\, q(z|x)$ matches $p(z)\, p_\theta(x|z)$ using an explicit model of the data.)
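The BiGAN/ALI discriminator sees concatenated (data, latent) pairs. The minimal NumPy sketch below computes its binary cross-entropy loss, labeling real pairs 1 and fake pairs 0. It assumes hypothetical callables `encode` (playing $E_\zeta$), `generate` (playing $G_\theta$) and `disc_logit` (returning logits of $D_\phi$).

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def bigan_discriminator_loss(x_real, z_prior, encode, generate, disc_logit):
    # Real pairs (x, E(x)) with x ~ p*(x); fake pairs (G(z), z) with z ~ q(z).
    real_pairs = np.concatenate([x_real, encode(x_real)], axis=1)
    fake_pairs = np.concatenate([generate(z_prior), z_prior], axis=1)
    d_real = sigmoid(disc_logit(real_pairs))
    d_fake = sigmoid(disc_logit(fake_pairs))
    # Binary cross-entropy with label 1 for real pairs and 0 for fake pairs.
    return -np.mean(np.log(d_real)) - np.mean(np.log(1.0 - d_fake))
```

A discriminator that outputs logit 0 for every pair ($D_\phi = 1/2$) incurs a loss of $2\log 2$. The generator and encoder are trained adversarially against this loss, with no reconstruction term.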


26.6 Neural architectures in GANs


We have so far discussed the learning principles, algorithms, and optimization methods that can be used to train implicit generative models parametrized by deep neural networks. We have not discussed, however, the importance of the choice of neural network architectures for the model and the critic. These choices have fueled the progress in GAN generation since GANs were introduced. We will look at a few case studies that show the importance of several factors:

- incorporating information about data modalities into the critic and the generator (Section 26.6.1);
- employing the right inductive biases (Section 26.6.2);
- incorporating attention in GAN models (Section 26.6.3);
- progressive generation (Section 26.6.4);
- regularization (Section 26.6.5);
- using large scale architectures (Section 26.6.6).






26.6. NEURAL ARCHITECTURES IN GANS 


26.6.1 The importance of discriminator architectures 


The discriminator or critic is rarely optimal, either because of the use of alternating gradient descent or because of the limited capacity of the neural discriminator. As a result, GANs do not perform distance or divergence minimization in practice. Instead, the critic acts as part of a learned loss function for the model (the generator). Every time the critic is updated, the loss function for the generative model changes. This is in stark contrast with divergence minimization methods such as maximum likelihood estimation, where the loss function stays the same throughout training. Learning features of data instead of handcrafting them is one reason for the success of deep learning methods; in the same way, learning loss functions has advanced the state of the art of generative modeling. Critics that take data modalities into account become part of data-modality-dependent loss functions. Examples are convolutional critics for images, and recurrent critics for sequential data such as text or audio. Such critics in turn provide a modality-specific learning signal to the model. For example, the convolutional parametrization of an image critic leads it to penalize blurry images and encourage sharp edges. Even within the same data modality, changes to critic architectures and regularization have been among the main drivers in obtaining better GANs. They affect the generator's loss function, and thus also the gradients of the generator, and so have a strong effect on optimization.


26.6.2 Architectural inductive biases 


While the original GAN paper used convolutions only sparingly, Deep Convolutional GAN (DCGAN) [RMC15] performed an extensive study of which architectures are most useful for GAN training. The result was a set of useful guidelines that led to a substantial boost in performance. Without changing the learning principles behind GANs, DCGAN obtained better results on image data through the following choices:

- convolutional generators (Figure 26.11) and critics (Figure 26.12);
- batch normalization in both the generator and the critic;
- strided convolutions in place of pooling layers;
- ReLU activations in the generator and LeakyReLU activations in the discriminator.

Many of these principles are still in use today, for larger architectures and with various loss functions. Since DCGAN, residual convolutional layers have become a key staple of both models and critics for image data [Gul+17], and recurrent architectures are used for sequence data such as text [SSG18b; Md+19].
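The resolutions in the DCGAN generator of Figure 26.11, from 4 × 4 × 1024 up to 64 × 64 × 3, follow from simple size arithmetic for strided and transposed convolutions. The sketch below assumes kernel size 4, stride 2 and padding 1, a common choice in DCGAN reimplementations. The original paper's exact hyperparameters may differ, and the helper names are hypothetical.

```python
def conv_transpose_out(size, kernel=4, stride=2, padding=1):
    # Spatial size after a transposed (fractionally-strided) convolution.
    return (size - 1) * stride - 2 * padding + kernel

def conv_out(size, kernel=4, stride=2, padding=1):
    # Spatial size after a strided convolution (used instead of pooling).
    return (size + 2 * padding - kernel) // stride + 1

# Generator: z is projected and reshaped to 4 x 4 x 1024, then each
# transposed convolution (ReLU in hidden layers) doubles the resolution.
generator_channels = [1024, 512, 256, 128, 3]   # channel widths from Figure 26.11
generator_sizes = [4]
for _ in generator_channels[1:]:
    generator_sizes.append(conv_transpose_out(generator_sizes[-1]))

# Critic: strided convolutions (LeakyReLU activations) halve the resolution.
critic_sizes = [64]
for _ in range(4):
    critic_sizes.append(conv_out(critic_sizes[-1]))
```

Here `generator_sizes` runs from 4 up to 64. The critic mirrors this path with strided convolutions in place of pooling, taking a 64 × 64 image back down to 4 × 4.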


26.6.3 Attention in GANs 


Attention mechanisms are explained in detail in Section 16.2.7. In this section, we discuss how to use them in both the GAN generator and discriminator; this is called the Self Attention GAN, or SAGAN, model [Zha+19c]. The advantage of self attention is that, unlike convolutional layers, it gives both the discriminator and the generator a global view of the other units of the same layer. This is illustrated in Figure 26.13, which visualizes the global span of attention: query points can attend to various other areas in the image.

The self-attention mechanism is defined for convolutional features reshaped to $h \in \mathbb{R}^{C \times N}$, where $N$ is the number of spatial locations. It computes $f = W_f h$, $g = W_g h$ and $S = f^\top g$, where $W_f \in \mathbb{R}^{C' \times C}$ and $W_g \in \mathbb{R}^{C' \times C}$, and $C' < C$ is a hyperparameter. Applying the softmax operator to each row of $S \in \mathbb{R}^{N \times N}$ gives a probability row matrix $\beta$. This matrix is used to attend to a linear transformation of the features, $o = W_v (W_h h) \beta^\top \in \mathbb{R}^{C \times N}$, with learned operators $W_h \in \mathbb{R}^{C' \times C}$ and $W_v \in \mathbb{R}^{C \times C'}$. The output is then $y = \gamma o + h$, where $\gamma \in \mathbb{R}$ is a learned parameter.
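The SAGAN attention layer above translates almost line by line into NumPy. In this sketch the random weights, the sizes $C$, $C'$ and $N$, and the function names are hypothetical.

```python
import numpy as np

def softmax_rows(s):
    s = s - s.max(axis=1, keepdims=True)   # for numerical stability
    e = np.exp(s)
    return e / e.sum(axis=1, keepdims=True)

def sagan_attention(h, W_f, W_g, W_h, W_v, gamma):
    # h: (C, N) features, one column per spatial position
    f = W_f @ h                    # (C', N)
    g = W_g @ h                    # (C', N)
    S = f.T @ g                    # (N, N), S[i, j] = f_i . g_j
    beta = softmax_rows(S)         # each row is a distribution over positions
    o = W_v @ (W_h @ h) @ beta.T   # (C, N); column i mixes positions j with weights beta[i, j]
    return gamma * o + h, beta

rng = np.random.default_rng(0)
C, C_prime, N = 8, 2, 6
h = rng.normal(size=(C, N))
W_f, W_g, W_h = (rng.normal(size=(C_prime, C)) for _ in range(3))
W_v = rng.normal(size=(C, C_prime))
y, beta = sagan_attention(h, W_f, W_g, W_h, W_v, gamma=0.5)
```

Setting `gamma=0.0` returns `h` unchanged. In this sense the layer can be switched off.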
Figure 26.11: DCGAN convolutional generator. From Figure 1 of [RMC15]. Used with kind permission of
Alec Radford. 
Figure 26.12: DCGAN convolutional discriminator. From Figure 1 of [RMC15]. Used with kind permission
Beyond providing a global signal to the players, the self attention mechanism is notably flexible. The learned parameter $\gamma$ lets the model decide not to use the attention layer, so adding self attention does not restrict the set of possible models an architecture can learn. Moreover, self attention significantly increases the number of parameters of the model: each attention layer introduces 4 learned matrices, $W_f$, $W_g$, $W_h$ and $W_v$. This increase has been observed to be a fruitful way to improve GAN training.


26.6.4 Progressive generation


One of the first successful approaches to generating higher resolution, color images from a GAN is an iterative process: first generate a lower dimensional sample, then use it as conditioning information to generate a higher dimensional sample, and repeat the process until the


Figure 26.13: Attention queries used by a SAGAN model, showcasing the global span of attention. Each row
first shows the input image and a set of color coded query locations in the image. The subsequent images 
show the attention maps corresponding to each query location in the first image, with the query color coded 
location being shown, and arrows from it to the attention map are used to highlight the most attended regions. 
From Figure 1 of [Zha+19c]. Used with kind permission of Han Zhang. 


Figure 26.14: LapGAN generation algorithm: the generation process starts with a low dimension sample, 
which gets upscaled and residually added to the output of a generator at a higher resolution. The process gets 
repeated multiple times. From Figure 1 of [DCF+15]. Used with kind permission of Emily Denton.


desired resolution is reached. LapGAN [DCF+15] uses a Laplacian pyramid as the iterative building block. It first upsamples the lower dimensional sample using a simple upsampling operation, such as smoothed upsampling. A conditional generator then produces a residual, which is added to the upsampled version to produce the higher resolution sample. This higher resolution sample can in turn be provided to another LapGAN layer to produce an even higher resolution sample, and so on; this process is shown in Figure 26.14. In LapGAN, a different generator and critic are trained for each iterative block of the model. In ProgressiveGAN [Kar+18], by contrast, the lower resolution generator and critic are “grown”: they become part of the generator and critic used to learn to generate higher resolution samples. The higher resolution generator is obtained by adding new layers on top of the last layer of the lower resolution generator. A residual connection is added between an upscaled version of the lower dimensional sample and the output of the newly created higher resolution generator. Its blending weight is annealed from 0 to 1 during training, transitioning from using the
Figure 26.15: ProgressiveGAN training algorithm. The input to the discriminator at the bottom of the figure is either a generated image or a real image (denoted as ‘Reals’ in the figure) at the corresponding resolution. From Figure 1 of [Kar+18]. Used with kind permission of Tero Karras.


upscaled version of the lower dimensional sample early in training to only using the sample of the higher resolution generator at the end of training. A similar change is made to the discriminator. Figure 26.15 shows the growing generator and discriminator in ProgressiveGAN training.
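The fade-in used when growing ProgressiveGAN can be written in a couple of lines. This NumPy sketch assumes nearest-neighbor 2× upsampling and images stored as (height, width, channels) arrays. The blending weight `alpha` plays the role of the annealed weight described above, and the function names are hypothetical.

```python
import numpy as np

def upsample_nearest(img, factor=2):
    # img: (H, W, C); each pixel is repeated factor x factor times
    return img.repeat(factor, axis=0).repeat(factor, axis=1)

def faded_output(low_res, high_res_new, alpha):
    # alpha is annealed from 0 to 1 while the new block is faded in:
    # alpha = 0 uses only the upscaled low-resolution sample,
    # alpha = 1 uses only the output of the new higher-resolution layers.
    return (1.0 - alpha) * upsample_nearest(low_res) + alpha * high_res_new
```

LapGAN's building block is similar in spirit. Instead of blending, it adds a generated residual to the (smoothed) upsampled image.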


26.6.5 Regularization 


Regularizing both the discriminator and the generator has by now a long tradition in GAN training. Regularizing GANs can be justified from multiple perspectives:

- theoretically, as it has been shown to be tied to convergence analysis [MGN18b];
- empirically, as it has been shown to help performance and stability in practice [RMC15; Miy+18c; Zha+19c; BDS18];
- intuitively, as it can be used to avoid overfitting in the discriminator and generator.

Regularization approaches include:

- adding noise to the discriminator input [AB17];
- adding noise to the discriminator and generator hidden features [ZML16];
- using BatchNorm for the two players [RMC15];
- adding dropout in the discriminator [RMC15];
- Spectral Normalization [Miy+18c; Zha+19c; BDS18];
- gradient penalties, which penalize the norm of the discriminator gradient with respect to its input, $\|\nabla_x D_\phi(x)\|^2$, by adding a regularization term to the loss function [Arb+18; Fed+18; ACB17; Gul+17].

Regularization methods often help training regardless of the type of loss function used. They have been shown to affect both training performance and the stability of the GAN game. However, improving stability and improving performance in GAN training can be at odds with each other, since too much regularization can make the models very stable but reduce performance [BDS18].
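A gradient penalty adds a weighted multiple of $\|\nabla_x D_\phi(x)\|^2$, averaged over a batch of inputs, to the discriminator loss. The sketch below uses central finite differences so that it runs with NumPy alone. In practice the input gradient would come from automatic differentiation (e.g. `jax.grad`). The cited methods differ in where the penalty is evaluated (real data, generated data, or interpolates between them) and in its weight; both are left as assumptions here, and the function names are hypothetical.

```python
import numpy as np

def input_gradient(disc, x, eps=1e-5):
    # Central finite-difference estimate of grad_x D(x) for a scalar-valued D.
    grad = np.zeros_like(x, dtype=float)
    for i in range(x.size):
        e = np.zeros_like(x, dtype=float)
        e[i] = eps
        grad[i] = (disc(x + e) - disc(x - e)) / (2 * eps)
    return grad

def gradient_penalty(disc, batch, eps=1e-5):
    # Mean over the batch of ||grad_x D(x)||^2; in training this would be added,
    # with some weight, to the discriminator loss.
    return float(np.mean([np.sum(input_gradient(disc, x, eps) ** 2) for x in batch]))
```

For a linear discriminator $D(x) = w^\top x$, the penalty is simply $\|w\|^2$, which the finite-difference estimate reproduces.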


26.6.6 Scaling up GAN models


By combining many of the architectural tricks discussed thus far, one can train GANs to generate diverse, high quality data, as done with BigGAN [BDS18], StyleGAN [Kar+20c], and Alias-Free GAN [Kar+21]. These tricks include very large residual networks, self attention, spectral normalization in both the discriminator and the generator, and batch normalization in the generator. Beyond combining carefully
Figure 26.16: Increasingly realistic synthetic faces generated by different kinds of GAN, specifically (from left to right): original GAN [Goo+14], DCGAN [RMC15], CoupledGAN [LT16], ProgressiveGAN [Kar+18], StyleGAN [KLA19]. Used with kind permission of Ian Goodfellow. An online demo, which randomly generates face images using StyleGAN, can be found at https://thispersondoesnotexist.com.


chosen architectures and regularization, creating large scale GANs also requires changes in optimization, with large batch sizes being a key component. This supports the view that the key components of the GAN game (the losses, the parametrization of the models, and the optimization) have to be viewed collectively rather than in isolation.


26.7 Applications 


The ability to generate new plausible data enables a wide range of applications for GANs. This section looks at a set of applications that demonstrate the breadth of GANs across different data modalities: images (Section 26.7.1), video (Section 26.7.2), audio (Section 26.7.3) and text (Section 26.7.4). It also covers applications such as imitation learning (Section 26.7.5), domain adaptation (Section 26.7.6) and art (Section 26.7.7).


26.7.1 GANs for image generation


The most widely studied application area is image generation. Image generation can take various forms; of these, we cover the translation of one image to another using either paired or unpaired data sets. There are many other topics related to image GANs that we do not cover. A more complete overview can be found in other sources, such as [Goo16] for the theory and [Bro19] for the practice. A JAX notebook that uses a small pre-trained GAN to generate some face images can be found at GAN_JAX_CelebA_demo.ipynb. PyTorch libraries for fitting more advanced GAN models can be found at https://github.com/open-mmlab/mmgeneration and https://github.com/POSTECH-CVLab/PyTorch-StudioGAN. We show the progression of quality in sample generation of faces using GANs in Figure 26.16. There is also an increasing need to consider the potential risks that generated images pose when used in other domains. This involves discussions of synthetic media and deep fakes; sources for discussion include [Bru+18; Wit].
26.7.1.1 Conditional image generation


Class-conditional image generation using GANs has become a very fruitful endeavor. BigGAN [BDS18] carries out class-conditional generation of ImageNet samples across a variety of categories, from dogs and cats to volcanoes and hamburgers. StyleGAN [KLA19] generates high quality images of faces at high resolution by learning a conditioning style vector and building on the ProgressiveGAN architecture discussed in Section 26.6.4. Because the conditioning vector is learned, StyleGAN can generate samples that interpolate between the styles of other samples. For example, it can preserve coarser style elements such as pose or face shape from one sample, and smaller scale style elements such as hair style from another. This provides fine-grained control over the style of the generated images.


26.7.1.2 Paired image-to-image generation 


We have discussed in Section 26.4 how paired data of the form $(x_n, y_n)$ can be used to build conditional generative models of $p(x|y)$. In some cases, the conditioning variable $y$ has the same size and shape as the output variable $x$. The resulting model $p_\theta(x|y)$ can then be used to perform image to image translation, as illustrated in Figure 26.17, where $y$ is drawn from the source domain and $x$ from the target domain. Collecting paired data of this form can be expensive, but in some cases we can acquire it automatically. One such example is image colorization, where a paired dataset can easily be obtained by processing color images into grayscale images (see e.g., [Jas]).

A conditional GAN for paired image-to-image translation, known as the pix2pix model, was proposed in [Iso+17]. It uses a U-net style architecture for the generator, as used for semantic segmentation tasks. However, the authors replace the batch normalization layers with instance normalization, as in neural style transfer.

For the discriminator, pix2pix uses a PatchGAN model, which classifies local patches as real or fake rather than classifying the whole image. Since the patches are local, the discriminator is forced to focus on the style of the generated patches, and to ensure they match the statistics of the target domain. A patch-level discriminator is also faster to train than a whole-image discriminator, and gives a denser feedback signal. This can produce results similar to Figure 26.17, depending on the dataset.


26.7.1.3 Unpaired image-to-image generation


A major drawback of conditional GANs is the need to collect paired data. It is often much easier to collect unpaired data of the form $\mathcal{D}_x = \{x_n : n = 1 : N_x\}$ and $\mathcal{D}_y = \{y_n : n = 1 : N_y\}$. For example, $\mathcal{D}_x$ might be a set of daytime images, and $\mathcal{D}_y$ a set of night-time images. It would be impossible to collect a paired dataset in which exactly the same scene is recorded during the day and at night (except with a computer graphics engine, but then we would not need to learn a generator). We assume that the datasets $\mathcal{D}_x$ and $\mathcal{D}_y$ come from the marginal distributions $p(x)$ and $p(y)$ respectively. We would then like to fit a joint model of the form $p(x, y)$, so that we can compute the conditionals $p(x|y)$ and $p(y|x)$ and thus translate from one domain to another. This is called unsupervised domain translation.

In general, this is an ill-posed problem, since there are infinitely many different joint distributions consistent with a set of marginals $p(x)$ and $p(y)$. We can try, however, to learn a joint distribution such that samples from it satisfy additional constraints. For example, if $G$ is a






26.7. APPLICATIONS 


[Figure 26.17 panels: BW to Color, Edges to Photo, Labels to Facade, Labels to Street Scene; each panel shows an input and the corresponding output.]


Figure 26.17: Example results on several image-to-image translation problems as generated by the pix2pix 
conditional GAN. From Figure 1 of [Iso+17]. Used with kind permission of Philip Isola. 


conditional generator that maps a sample from $\mathcal{X}$ to $\mathcal{Y}$, and $F$ maps a sample from $\mathcal{Y}$ to $\mathcal{X}$, it is reasonable to require that these be inverses of each other, i.e., $F(G(x)) = x$ and $G(F(y)) = y$. This is called a cycle consistency loss [Zhu+17]. We can encourage $G$ and $F$ to satisfy this constraint with a penalty term on the difference between the starting image and the image we get after going through this cycle:

$$\mathcal{L}_{\text{cycle}} = \mathbb{E}_{p(x)}\|F(G(x)) - x\|_1 + \mathbb{E}_{p(y)}\|G(F(y)) - y\|_1 \qquad (26.64)$$

To ensure that the outputs of $G$ are samples from $p(y)$ and those of $F$ are samples from $p(x)$, we use a standard GAN approach. This introduces discriminators $D_x$ and $D_y$, trained with any choice of GAN loss $\mathcal{L}_{\text{GAN}}$, as visualized in Figure 26.18. Finally, we can optionally check that applying each conditional generator to images from its own output domain does not change them:

$$\mathcal{L}_{\text{identity}} = \mathbb{E}_{p(x)}\|x - F(x)\|_1 + \mathbb{E}_{p(y)}\|y - G(y)\|_1 \qquad (26.65)$$

We can combine all three of these losses to train the translation mappings $F$ and $G$, using hyperparameters $\lambda_1$ and $\lambda_2$:

$$\mathcal{L} = \mathcal{L}_{\text{GAN}} + \lambda_1 \mathcal{L}_{\text{cycle}} + \lambda_2 \mathcal{L}_{\text{identity}} \qquad (26.66)$$


CycleGAN results on various datasets are shown in Figure 26.19. The bottom row shows how 
CycleGAN can be used for style transfer. 
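Equations 26.64 to 26.66 translate directly into code once $G$ and $F$ are available as functions. In this NumPy sketch the batches are arrays with one sample per row, and the expectations are batch means. `gan_loss` is a precomputed scalar standing in for $\mathcal{L}_{\text{GAN}}$. The default values of $\lambda_1$ and $\lambda_2$ are illustrative only, and the function names are hypothetical.

```python
import numpy as np

def l1_distance(a, b):
    # Batch mean of the per-sample L1 norm ||a - b||_1 (one sample per row).
    return float(np.mean(np.sum(np.abs(a - b), axis=1)))

def cycle_loss(G, F, xs, ys):
    # Equation 26.64
    return l1_distance(F(G(xs)), xs) + l1_distance(G(F(ys)), ys)

def identity_loss(G, F, xs, ys):
    # Equation 26.65: each map should leave samples of its own output domain unchanged.
    return l1_distance(xs, F(xs)) + l1_distance(ys, G(ys))

def cyclegan_loss(gan_loss, G, F, xs, ys, lambda1=10.0, lambda2=5.0):
    # Equation 26.66; the lambda defaults are illustrative only.
    return (gan_loss + lambda1 * cycle_loss(G, F, xs, ys)
            + lambda2 * identity_loss(G, F, xs, ys))
```

If $F$ is the exact inverse of $G$ (for instance, a pair of inverse linear maps), the cycle loss is zero up to rounding error.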


26.7.2 Video generation 


The GAN framework can be expanded from individual images (frames) to videos; the techniques 
used to generate realistic images can also be applied to generate videos, with additional techniques 
required to ensure spatio-temporal consistency. Spatio-temporal consistency is obtained by ensuring 
that the discriminator has access to the real data and generated sequences in order, thus penalizing 
the generator when generating realistic individual frames without respecting temporal order [SMS17; 
Figure 26.18: Illustration of the CycleGAN training scheme. (a) Illustration of the 4 functions that are trained. (b) Forward cycle consistency from X back to X. (c) Backwards cycle consistency from Y back to Y. From Figure 3 of [Zhu+17]. Used with kind permission of Jun-Yan Zhu.
Figure 26.19: Some examples of unpaired image-to-image translation generated by the CycleGAN model. From Figure 1 of [Zhu+17]. Used with kind permission of Jun-Yan Zhu.


Sai+20; CDS19; Tul+18]. Another discriminator can be employed to additionally ensure each frame is realistic [Tul+18; CDS19]. The generator itself needs to have a temporal element, which is often implemented through a recurrent component. As with images, the generation framework can be expanded to video-to-video translation [Ban+18; Wan+18], encompassing applications such as motion transfer [Cha+19a].


26.7.3 Audio generation


Generative models have been demonstrated on the task of generating audio waveforms, as well
as on the task of text-to-speech (TTS) generation. Other types of generative models, such as the
autoregressive models WaveNet [oor+16] and WaveRNN [Kal+18b], have been developed for
these applications. However, autoregressive models are difficult to parallelize over time, since they
predict each time step of the audio sequentially; this makes them computationally expensive and often
too slow to be used in practice. GANs provide an alternative approach to these tasks and another way
of addressing these concerns.
Many different GAN architectures have been developed for audio-only generation, including the
generation of single-note instrument recordings by GANSynth, a vocoder model that uses
GANs to generate magnitude spectrograms from mel-spectrograms [Eng+18]; voice conversion
using a modified version of the CycleGAN discussed above [Kan+20]; and the direct generation of
raw audio in WaveGAN [DMP18].

Initial work on GANs for TTS [Yan+17] used an approach similar to conditional GANs for
image generation (see Section 26.7.1.2), but with 1d convolutions instead of 2d. More recent
GANs such as GAN-TTS [Bin+19] use more advanced architectures, together with discriminators that
operate at multiple frequency scales, and achieve performance that now matches the best-performing
autoregressive models when assessed using mean opinion scores. In both direct audio generation
and TTS, the main advantage of GANs over other models is that they allow faster generation and
can incorporate different types of context.


26.7.4 Text generation 


Similar to the image and audio domains, there are several tasks involving text data for which GAN-based
approaches have been developed, including conditional text generation and text-style transfer. Text
data are often represented as discrete values, at either the character level or the word level, indicating
membership within a vocabulary of a particular size (the alphabet size or the number of words). Due to
the discrete nature of text, GAN models trained on text are explicit: they explicitly model the
probability distribution of the output, rather than modeling the sampling path. This is unlike most of the
GAN models of continuous data, such as images, that we have discussed in this chapter so far, though
explicit GANs of continuous data do exist [Die+19b].

The discrete nature of text is why maximum likelihood is one of the most common methods of
learning generative models of text. However, models trained with maximum likelihood are often
limited to autoregressive models, whereas, as in the audio case, GANs make it possible to generate
text in a non-autoregressive manner, which enables other tasks such as one-shot feedforward
generation [Gul+17].

The difficulty of generating discrete data such as text using GANs can be seen by looking at their
loss functions, for example Equations (26.19), (26.21) and (26.28). GAN losses contain terms of
the form $\mathbb{E}_{p_\theta(x)} f(x)$, which we not only need to evaluate, but also backpropagate through, by
computing $\nabla_\theta \mathbb{E}_{p_\theta(x)} f(x)$. In the case of implicit distributions given by latent variable models, we
used the reparametrization trick to compute this gradient (Equation 26.49). In the discrete case, the
reparametrization trick is not available, and we have to look for other ways to estimate the desired
gradient. One approach is to use the score function estimator, discussed in Section 6.5.3. However,
the score function estimator exhibits high gradient variance, which can destabilize training. One
common approach to avoid this issue is to pre-train the language model generator using maximum
likelihood, and then to fine-tune it with a GAN loss which gets backpropagated into the generator
using the score function estimator, as done by Sequence GAN [Yu+17], MaliGAN [Che+17], and
RankGAN [Lin+17a]. While these methods spearheaded the use of GANs for text, they do not
address the inherent instabilities of score function estimation, and thus have to limit adversarial
fine-tuning to a small number of epochs, often with a small learning rate, keeping their
performance close to that of the maximum likelihood solution [SSG18a; Cac+18].
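To make the score function estimator concrete, here is a minimal NumPy sketch for a single categorical "token" with softmax logits θ, where the fixed vector `f` is a hypothetical stand-in for a discriminator-based reward. It uses the identity $\nabla_\theta \mathbb{E}_{p_\theta(x)} f(x) = \mathbb{E}_{p_\theta(x)}[f(x) \nabla_\theta \log p_\theta(x)]$ and compares the Monte Carlo estimate with the exact gradient; the optional constant baseline is a standard variance-reduction device that leaves the estimate unbiased. This is an illustration of the estimator, not the code of any of the cited models.

```python
import numpy as np

def softmax(logits):
    z = np.exp(logits - logits.max())
    return z / z.sum()

def score_function_grad(theta, f, num_samples, rng, baseline=0.0):
    """Monte Carlo estimate of grad_theta E_{x ~ Cat(softmax(theta))}[f(x)].

    For a softmax-parameterized categorical, grad_theta log p_theta(x) = onehot(x) - pi.
    Subtracting a constant baseline keeps the estimate unbiased but can reduce variance.
    """
    pi = softmax(theta)
    K = len(theta)
    xs = rng.choice(K, size=num_samples, p=pi)  # discrete samples: no reparameterization
    score = np.eye(K)[xs] - pi                  # (num_samples, K)
    weights = (f[xs] - baseline)[:, None]       # reward of each sampled outcome
    return np.mean(weights * score, axis=0)

def exact_grad(theta, f):
    # d/dtheta_j sum_k pi_k f_k = pi_j (f_j - sum_k pi_k f_k)
    pi = softmax(theta)
    return pi * (f - pi @ f)

rng = np.random.default_rng(0)
theta = np.array([0.1, -0.2, 0.3])
f = np.array([1.0, 2.0, 3.0])
est = score_function_grad(theta, f, 200_000, rng)
est_b = score_function_grad(theta, f, 200_000, rng, baseline=softmax(theta) @ f)
print(exact_grad(theta, f), est, est_b)
```

Even in this three-outcome example, many samples are needed for an accurate estimate; with long token sequences and a changing discriminator, this variance is what makes score-function training of text GANs fragile.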

An alternative to maximum likelihood pretraining is to use other approaches to stabilize the score
function estimator, or to use continuous relaxations for backpropagation. ScratchGAN is a word-level
model that uses large batch sizes and discriminator regularization to stabilize score function training
(the same techniques we have seen used to stabilize the training of image GANs) [Md+19].
[Pre+17b] completely avoid the score function estimator and develop a character-level model without
pre-training, by using continuous relaxations and curriculum learning. These training approaches
can also benefit from other architectural advances; e.g., [NNP19] showed that language GANs can
benefit from complex architectures such as Relation Networks [San+17].
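One widely used continuous relaxation of discrete sampling is the Gumbel-softmax (also called Concrete) distribution; the sketch below illustrates that general idea and is not necessarily the relaxation used in [Pre+17b]. Gumbel noise is added to the logits and a temperature-τ softmax replaces the non-differentiable argmax, so the relaxed "one-hot" sample is a smooth function of the logits that a discriminator can consume and gradients can flow through. The function name and parameters are hypothetical.

```python
import numpy as np

def gumbel_softmax(logits, tau, rng, size):
    """Draw `size` relaxed one-hot samples from Cat(softmax(logits)).

    Adds i.i.d. Gumbel(0, 1) noise to the logits and applies a softmax with
    temperature tau. Lower tau gives outputs closer to exact one-hot vectors.
    """
    g = rng.gumbel(size=(size, len(logits)))
    y = (logits + g) / tau
    y = y - y.max(axis=1, keepdims=True)  # numerical stability
    e = np.exp(y)
    return e / e.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
logits = np.log(np.array([0.2, 0.5, 0.3]))
soft = gumbel_softmax(logits, tau=0.5, rng=rng, size=100_000)
# The argmax of each relaxed sample is an exact categorical sample (the Gumbel-max trick).
freqs = np.bincount(soft.argmax(axis=1), minlength=3) / len(soft)
print(freqs)  # close to [0.2, 0.5, 0.3]
```

The temperature trades bias for variance: small τ gives nearly discrete samples but high-variance gradients, while large τ gives smooth, low-variance but biased gradients.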

Finally, methods for unsupervised text style transfer, mimicking image style transfer, have been
proposed, such as [She+17; Fu+17], which use adversarial classifiers to decode to a different
style/language, and [Pra+18], which trains different encoders, one per style, by combining the
encoder of a pre-trained NMT model with style classifiers, among other approaches.


26.7.5 Imitation Learning 


Imitation learning takes advantage of observations of expert demonstrations to learn action policies 
and reward functions of unknown environments by minimizing some form of discrepancy between 
the learned and the expert behaviors. There are many approaches available, including behavioral 
cloning [PPG91], which treats this problem as one of supervised learning, and inverse reinforcement
learning [NR00b]. GANs are appealing for imitation learning since they provide a way to avoid the
difficulty of designing good discrepancy functions for behaviors, and instead learn these discrepancy 
functions using a discriminator between trajectories generated by a learned agent and observed 
demonstrations. 

This approach, known as Generative Adversarial Imitation Learning (GAIL) [HE16a], demonstrates
the ability to use GANs for complex behaviors in high-dimensional environments. GAIL jointly
learns a generator, which forms a stochastic policy, along with a discriminator that acts as a reward
signal. As we saw in the probabilistic development of GANs in the earlier sections, GAIL can
also be generalized to the broader family of f-divergences, rather than the Jensen-Shannon divergence
used in the standard GAN loss. This has led to a family of other GAIL variants that use
other f-divergences [Ke+19a; Fin+16; Bro+20c], including f-GAIL, which aims to also learn the
best f-divergence to use [Zha+20e], as well as new analytical insight into the computation and
generalization of such approaches [Che+20b].


26.7.6 Domain Adaptation 


An important task in machine learning is to correct for shifts in the data distribution over time,
minimizing some measure of domain shift, as we discuss in Section 19.5.3. As with the other
applications, GANs are popular here because they avoid having to choose a distance or measure of
shift. Both the supervised and unsupervised approaches for image generation we reviewed earlier looked
at pixel-level domain adaptation models that perform distribution alignment in raw pixel space, translating
source data to the style of a target domain, as with pix2pix and CycleGAN. Extensions of these
approaches for the general problem of domain adaptation seek to do this not only in the observed
data space (e.g., with pixels), but also at the feature level. One general approach is domain-adversarial
training of neural networks [Gan+16b] or adversarial discriminative domain adaptation
(ADDA) [Tze+17]. The CyCADA approach of [Hof+18] extends CycleGAN by enforcing both
structural and semantic consistency during adaptation, using a cycle-consistency loss and semantic
losses based on a particular visual recognition task. There are also many extensions that include
class conditional information [Tsa+18; Lon+18] or adaptation when the modes to be matched have
different frequencies in the source and target domains [BHC19].


26.7.7 Design, Art and Creativity 


Generative models, particularly of images, have added to approaches in the more general area of 
algorithmic art. The applications in image and audio generation with transfer can also be considered 
aspects of artistic image generation. In these cases, the goal of training is not generalization, but to 
create appealing images across different types of visual aesthetics [Sar18]. One example takes style 
transfer GANs to create visual experiences, in which objects placed under a video are re-rendered 
using other visual styles in real time [AFG19]. The generative ability of GANs has been used to explore
alternative designs and fabrics in fashion [Kat+19], and GANs have now also become part of major drawing
software, providing new tools to support designers [Ado]. Beyond images, creative and artistic
expression using GANs includes music, voice, dance, and typography [AI 19].


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 


PART V 


Discovery 


27 Discovery methods: an overview


27.1 Introduction 


We have seen in Part III how to create probabilistic models that can make predictions about outputs 
given inputs, using supervised learning methods (conditional likelihood maximization). And we have 
seen in Part IV how to create probabilistic models that can generate outputs unconditionally, using 
unsupervised learning methods (unconditional likelihood maximization). However, in some settings, 
our goal is to try to understand a given dataset. That is, we want to discover something “interesting”, 
and possibly “actionable”. Prediction and generation are useful subroutines for discovery, but are not 
sufficient on their own. In particular, although neural networks often implicitly learn useful features 
from data, they are often hard to interpret, and the results can be unstable and sensitive to arbitrary 
details of the training protocol (e.g., SGD learning rates, or random seeds). 

In this part of the book, we focus on learning models that create an interpretable representation of 
high dimensional data. A common approach is to use a latent variable model, in which we make 
the assumption that the observed data x was caused by, or generated by, some underlying (often 
low dimensional) latent factors z, which represent the “true” state of the world. Crucially, these
latent variables are assumed to be meaningful to the end user of the model. (Thus evaluating such 
models will generally require domain expertise.) 

For example, suppose we want to interpret an image x in terms of an underlying 3d scene, z,
which is represented in terms of objects and surfaces. The forwards mapping from z to x is often
many-to-one, i.e., different latent values, say z and z', may give rise to the same observation x,
due to limitations of the sensor. (This is called perceptual aliasing.) Consequently the inverse
mapping, from x to z, is ill-posed. In such cases, we need to impose a prior, p(z), to make our
estimate well-defined. In simple settings, we can use a point estimate, such as the MAP estimate


$$\hat{z}(x) = \arg\max_z p(z|x) = \arg\max_z \left[ \log p(z) + \log p(x|z) \right] \tag{27.1}$$


In the context of computer vision, this approach is known as vision as inverse graphics or analysis 
by synthesis [KMY04; YK06; Doy+07; MC19]. See Figure 27.1 for an illustration. 

This approach to inverse modeling is widely used in science and engineering, where z represents the
underlying state of the world which we want to estimate, and x is just a noisy or partial manifestation
of this true state. In some cases, we know both the prior p(z|θ) and the likelihood p(x|z, θ), and we
just need to solve the inference problem for z. But more commonly, the model parameters θ are
also (partially) unknown, and need to be inferred from observable samples $\mathcal{D} = \{x_n : n = 1 : N\}$. In
some cases, the structure of the model itself is unknown and needs to be learned.
Figure 27.1: Vision as inverse graphics. The agent (here represented by a human head) has to infer the scene
z given the image x using an estimator. From Figure 1 of [Rao99]. Used with kind permission of Rajesh Rao. 


27.2 Overview of Part V 


In Chapter 28, we discuss simple latent variable models where typically the observed data is a
fixed-dimensional vector such as $x \in \mathbb{R}^D$. In Section 29.2 and Chapter 29 we extend these models to
work with sequences of correlated vectors, $x = x_{1:T}$, such as speech, video, genomics data, etc. It is
straightforward to make parts of these models nonlinear (“deep”), as we discuss. These models can
also be extended to the spatio-temporal setting.

The models in Chapters 28 to 29 can all be interpreted as probabilistic graphical models with
different kinds of CPDs. In Chapter 30, we discuss how to learn the structure of PGMs from data.
In Chapter 31, we discuss non-parametric Bayesian models, which allow us to represent uncertainty
about many aspects of a model, such as the number of hidden states, the structure of the model,
the form of a functional dependency, etc. Thus the complexity of the learned representation can
grow dynamically, depending on the quantity and quality (informativeness) of the data. This is
important when performing discovery tasks, and helps us maintain flexibility while still retaining
interpretability.

In Chapter 32, we discuss representation learning using neural networks. This can be tackled
using latent variable modeling, but there are also a variety of other estimation methods one can
use. Finally, in Chapter 33, we discuss how to interpret the behavior of a learned (prediction) model
(typically a neural network).
28 Latent factor models


28.1 Introduction 


A latent variable model (LVM) is any probabilistic model in which some variables are always
latent or hidden. A simple example is a mixture model (Section 28.2), which has the form
$p(x) = \sum_k p(x|z = k) p(z = k)$, where z is an indicator variable that specifies which mixture component to
use for generating x. However, we can also use continuous latent variables, or a mixture of discrete
and continuous ones. And we can also have multiple latent variables, which are interconnected in complex
ways.

In this chapter, we discuss a very simple kind of LVM that has the following form: 


$$z \sim p(z) \tag{28.1}$$
$$x|z \sim \text{Expfam}(x|f(z)) \tag{28.2}$$


where f(z) is known as the decoder, and p(z) is some kind of prior. We assume that z is a single 
“layer” of hidden random variables, corresponding to a set of “latent factors”. We call these latent 
factor models. In this chapter, we assume the decoder f is a simple linear model; we consider 
nonlinear extensions in Chapter 21. Thus the overall model is similar to a GLM (Section 15.1), 
except the input to the model is hidden. 

We can create a large variety of different “classical” models by changing the form of the prior p(z)
and/or the likelihood p(x|z), as we show in Table 28.1. We will give the details in the following
sections. (Note that, although we are discussing generative models, our focus is on posterior inference 
of meaningful latents (discovery), rather than generating realistic samples of data.) 


28.2 Mixture models 


One way to create more complex probability models is to take a convex combination of simple 
distributions. This is called a mixture model. This has the form 


$$p(x|\theta) = \sum_{k=1}^K \pi_k p_k(x) \tag{28.3}$$

where $p_k$ is the k'th mixture component, and $\pi_k$ are the mixture weights, which satisfy $0 \le \pi_k \le 1$
and $\sum_{k=1}^K \pi_k = 1$.

Model | $p(z)$ | $p(x \mid z)$ | Section
FA/PCA | $\mathcal{N}(z \mid 0, I)$ | $\mathcal{N}(x \mid Wz, \Psi)$ | Section 28.3.1
GMM | $\sum_c \text{Cat}(c \mid \pi) \mathcal{N}(z \mid \mu_c, \Sigma_c)$ | $\mathcal{N}(x \mid Wz, \Psi)$ | Section 28.2.4
MixFA | $\text{Cat}(c \mid \pi) \mathcal{N}(z \mid 0, I)$ | $\mathcal{N}(x \mid W_c z + \mu_c, \Psi_c)$ | Section 28.3.3.5
NMF | $\prod_k \text{Ga}(z_k \mid \alpha_k, \beta_k)$ | $\prod_d \text{Poi}(x_d \mid \exp(w_d^\top z))$ | Section 28.4.1
Simplex FA (mPCA) | $\text{Dir}(z \mid \alpha)$ | $\prod_d \text{Cat}(x_d \mid W_d z)$ | Section 28.4.2
LDA | $\text{Dir}(z \mid \alpha)$ | $\prod_d \text{Cat}(x_d \mid W z)$ | Section 28.5
ICA | $\prod_d \text{Laplace}(z_d \mid \lambda)$ | $\prod_d \delta(x_d - w_d^\top z)$ | Section 28.6
Sparse coding | $\prod_k \text{Laplace}(z_k \mid \lambda)$ | $\prod_d \mathcal{N}(x_d \mid w_d^\top z, \sigma^2)$ | Section 28.6.5

Table 28.1: Some popular “shallow” latent factor models. Abbreviations: FA = factor analysis, PCA =
principal components analysis, GMM = Gaussian mixture model, NMF = non-negative matrix factorization,
mPCA = multinomial PCA, LDA = latent Dirichlet allocation, ICA = independent components analysis.
k = 1 : L ranges over latent dimensions, d = 1 : D ranges over observed dimensions. (For ICA, we have the
constraint that L = D.)

We can re-express this model as a hierarchical model, in which we introduce the discrete latent
variable $z \in \{1, \ldots, K\}$, which specifies which distribution to use for generating the output x. The
prior on this latent variable is $p(z = k) = \pi_k$, and the conditional is $p(x|z = k) = p_k(x) = p(x|\theta_k)$.
That is, we define the following joint model:

$$p(z|\theta) = \text{Cat}(z|\pi) \tag{28.4}$$
$$p(x|z = k, \theta) = p(x|\theta_k) \tag{28.5}$$


The “generative story” for the data is that we first sample a specific component z, and then we
generate the observations x using the parameters chosen according to the value of z. By marginalizing
out z, we recover Equation (28.3):

$$p(x|\theta) = \sum_{k=1}^K p(z = k|\theta) p(x|z = k, \theta) = \sum_{k=1}^K \pi_k p(x|\theta_k) \tag{28.6}$$
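The generative story and Equation (28.6) can be sketched in a few lines of NumPy, assuming (for illustration only) a hypothetical two-component mixture of 1d Gaussians: ancestral sampling first draws z and then x given z, while the marginal density sums the weighted component densities.

```python
import numpy as np

def sample_mixture(pis, mus, sigmas, n, rng):
    # Generative story: first z ~ Cat(pi), then x ~ p(x | theta_z).
    z = rng.choice(len(pis), size=n, p=pis)
    x = rng.normal(loc=mus[z], scale=sigmas[z])
    return z, x

def mixture_density(x, pis, mus, sigmas):
    # Equation (28.6): p(x) = sum_k pi_k p(x | theta_k), here with 1d Gaussian components.
    x = np.asarray(x)[..., None]
    comp = np.exp(-0.5 * ((x - mus) / sigmas) ** 2) / (np.sqrt(2 * np.pi) * sigmas)
    return comp @ pis

rng = np.random.default_rng(0)
pis = np.array([0.3, 0.7])
mus = np.array([-2.0, 1.5])
sigmas = np.array([0.5, 1.0])
z, x = sample_mixture(pis, mus, sigmas, 50_000, rng)
grid = np.linspace(-8, 8, 4001)
print(np.bincount(z) / len(z))  # close to pis
print(np.sum(mixture_density(grid, pis, mus, sigmas)) * (grid[1] - grid[0]))  # close to 1
```

Discarding the sampled z and keeping only x produces samples from the marginal p(x), which is why the two functions describe the same model.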


We can create different kinds of mixture models by varying the base distribution $p_k$, as we illustrate
below.


28.2.1 Gaussian mixture models (GMMs)


A Gaussian mixture model or GMM, also called a mixture of Gaussians (MoG), is defined
as follows:

$$p(x) = \sum_{k=1}^K \pi_k \mathcal{N}(x|\mu_k, \Sigma_k) \tag{28.7}$$


In Figure 28.1 we show the density defined by a mixture of 3 Gaussians in 2d. Each mixture
component is represented by a different set of elliptical contours. If we let the number of mixture
components grow sufficiently large, a GMM can approximate any smooth distribution over $\mathbb{R}^D$.

Draft of “Probabilistic Machine Learning: Advanced Topics”. August 15, 2022

Figure 28.1: A mixture of 3 Gaussians in 2d. (a) We show the contours of constant probability for each
component in the mixture. (b) A surface plot of the overall density. Adapted from Figure 2.23 of [Bis06].
Generated by gmm_plot_demo.ipynb.

Figure 28.2: (a) Some data in 2d. (b) A possible clustering using K = 5 clusters computed using a GMM.
Generated by gmm_2d.ipynb.

GMMs are often used for unsupervised clustering of real-valued data samples $x_n \in \mathbb{R}^D$. This
works in two stages. First we fit the model, e.g., by computing the MLE $\hat{\theta} = \arg\max_\theta \log p(\mathcal{D}|\theta)$, where
$\mathcal{D} = \{x_n : n = 1 : N\}$ (e.g., using EM or SGD). Then we associate each data point $x_n$ with a discrete
latent or hidden variable $z_n \in \{1, \ldots, K\}$, which specifies the identity of the mixture component or
cluster that was used to generate $x_n$. These latent identities are unknown, but we can compute a
posterior over them using Bayes rule:

$$r_{nk} = p(z_n = k|x_n, \theta) = \frac{p(z_n = k|\theta) p(x_n|z_n = k, \theta)}{\sum_{k'=1}^K p(z_n = k'|\theta) p(x_n|z_n = k', \theta)} \tag{28.8}$$


The quantity rnk is called the responsibility of cluster k for data point n. Given the responsibilities, 
we can compute the most probable cluster assignment as follows: 


$$\hat{z}_n = \arg\max_k r_{nk} = \arg\max_k \left[ \log p(x_n|z_n = k, \theta) + \log p(z_n = k|\theta) \right] \tag{28.9}$$


This is known as hard clustering. (If we use the responsibilities to fractionally assign each data 
point to different clusters, it is called soft clustering.) See Figure 28.2 for an example. 
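Here is a minimal NumPy sketch of Equations (28.8) and (28.9), assuming full-covariance Gaussian components with known parameters. It works in log space and subtracts the per-row maximum (the log-sum-exp trick) so that the responsibilities do not underflow in high dimensions. The function names and toy parameters are hypothetical.

```python
import numpy as np

def gauss_logpdf(X, mu, Sigma):
    # log N(x | mu, Sigma) for each row of X, via a Cholesky factorization.
    D = X.shape[1]
    L = np.linalg.cholesky(Sigma)
    sol = np.linalg.solve(L, (X - mu).T)  # L^{-1}(x - mu), shape (D, N)
    maha = np.sum(sol ** 2, axis=0)
    logdet = 2 * np.sum(np.log(np.diag(L)))
    return -0.5 * (D * np.log(2 * np.pi) + logdet + maha)

def responsibilities(X, pis, mus, Sigmas):
    # Equation (28.8) in log space: r_nk is proportional to pi_k N(x_n | mu_k, Sigma_k).
    log_joint = np.stack([np.log(pis[k]) + gauss_logpdf(X, mus[k], Sigmas[k])
                          for k in range(len(pis))], axis=1)  # (N, K)
    log_joint -= log_joint.max(axis=1, keepdims=True)         # log-sum-exp shift
    R = np.exp(log_joint)
    return R / R.sum(axis=1, keepdims=True)

pis = np.array([0.5, 0.5])
mus = np.array([[0.0, 0.0], [4.0, 4.0]])
Sigmas = np.array([np.eye(2), np.eye(2)])
X = np.array([[0.1, -0.2], [3.9, 4.2], [2.0, 2.0]])
R = responsibilities(X, pis, mus, Sigmas)  # soft clustering
z_hat = R.argmax(axis=1)                   # hard clustering, Equation (28.9)
print(R.round(3), z_hat)
```

The third point is equidistant from both centroids, so its responsibilities are split evenly; hard clustering has to break such ties arbitrarily, which is one reason soft assignments can be more informative.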
Figure 28.3: We fit a mixture of 20 Bernoullis to the binarized MNIST digit data. We visualize the estimated
cluster means $\hat{\mu}_k$. The numbers on top of each image represent the estimated mixing weights $\hat{\pi}_k$. No labels
were used when training the model. Generated by mix_bernoulli_em_mnist.ipynb.


If we have a uniform prior over $z_n$, and we use spherical Gaussians with $\Sigma_k = I$, the hard clustering
problem reduces to

$$\hat{z}_n = \arg\min_k \lVert x_n - \hat{\mu}_k \rVert_2^2 \tag{28.10}$$


In other words, we assign each data point to its closest centroid, as measured by Euclidean distance. 
This is the basis of the K-means clustering algorithm (see the prequel to this book). 
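The reduction from Equation (28.9) to Equation (28.10) can be checked numerically. This sketch (hypothetical function names, random toy data) compares nearest-centroid assignment with GMM hard assignment under a uniform prior and identity covariances; the Gaussian normalizing constant is omitted because it is the same for every k.

```python
import numpy as np

def kmeans_assign(X, centroids):
    # Equation (28.10): index of the nearest centroid in squared Euclidean distance.
    d2 = np.sum((X[:, None, :] - centroids[None, :, :]) ** 2, axis=2)  # (N, K)
    return d2.argmin(axis=1)

def gmm_hard_assign_isotropic(X, centroids):
    # Equation (28.9) with a uniform prior and Sigma_k = I; the shared
    # -D/2 log(2 pi) term is dropped since it does not depend on k.
    log_lik = -0.5 * np.sum((X[:, None, :] - centroids[None, :, :]) ** 2, axis=2)
    log_prior = np.log(np.full(len(centroids), 1.0 / len(centroids)))
    return (log_lik + log_prior).argmax(axis=1)

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
centroids = rng.normal(size=(5, 2))
print(np.array_equal(kmeans_assign(X, centroids), gmm_hard_assign_isotropic(X, centroids)))
```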


28.2.2 Bernoulli mixture models 


If the data is binary valued, we can use a Bernoulli mixture model (BMM), also called a mixture
of Bernoullis, where each mixture component has the following form:

$$p(x|z = k, \theta) = \prod_{d=1}^D \text{Ber}(x_d|\mu_{dk}) = \prod_{d=1}^D \mu_{dk}^{x_d} (1 - \mu_{dk})^{1 - x_d} \tag{28.11}$$

Here $\mu_{dk}$ is the probability that bit d turns on in cluster k.
For example, consider fitting a mixture of Bernoullis using K = 20 components to the MNIST
dataset. The resulting parameters for each mixture component (i.e., $\mu_k$ and $\pi_k$) are shown in
Figure 28.3. We see that the model has “discovered” a representation of each type of digit. (Some
digits are represented multiple times, since the model does not know the “true” number of classes.
See Section 3.9.1 for information on how to choose the number K of mixture components.)
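A small sketch of the per-component log likelihood in Equation (28.11), vectorized over data points and clusters. It assumes binary data stored as floats and stores the parameters as a (K, D) array `mu` with `mu[k, d]` playing the role of $\mu_{dk}$; the clipping constant `eps` is a hypothetical safeguard against log(0), not part of the model.

```python
import numpy as np

def bernoulli_log_lik(X, mu, eps=1e-10):
    """log p(x_n | z_n = k) from Equation (28.11) for all n, k.

    X: (N, D) binary data; mu: (K, D), mu[k, d] = probability that bit d is on in cluster k.
    """
    mu = np.clip(mu, eps, 1 - eps)
    return X @ np.log(mu).T + (1 - X) @ np.log(1 - mu).T  # (N, K)

X = np.array([[1, 1, 0, 0], [0, 0, 1, 1]], dtype=float)
mu = np.array([[0.9, 0.9, 0.1, 0.1], [0.1, 0.1, 0.9, 0.9]])
pis = np.array([0.5, 0.5])
log_joint = bernoulli_log_lik(X, mu) + np.log(pis)
print(log_joint.argmax(axis=1))  # each binary vector goes to the matching prototype
```

Working with sums of logs rather than products of D probabilities matters for MNIST-sized inputs (D = 784), where the raw product would underflow.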


28.2.3 Gaussian scale mixtures (GSMs)

A Gaussian scale mixture or GSM [AM74; Wes87] is like an “infinite” mixture of Gaussians, each
with a different scale (variance). More precisely, let $x = \epsilon z$, where $z \sim \mathcal{N}(0, \sigma_0^2)$ and $\epsilon \sim p(\epsilon)$. We can
think of this as multiplicative noise being applied to the Gaussian rv z. We have $x|\epsilon \sim \mathcal{N}(0, \epsilon^2 \sigma_0^2)$.
Marginalizing out the scale $\epsilon$ gives

$$p(x) = \int \mathcal{N}(x|0, \epsilon^2 \sigma_0^2) p(\epsilon) d\epsilon \tag{28.12}$$


By changing the prior p(e), we can create various interesting distributions. We give some examples 
below. 

The main advantage of this approach is that it is often computationally more convenient to 
work with the expanded parameterization, in which we explicitly include the scale term $\epsilon$, since,
conditional on that, the distribution is Gaussian. We use this formulation in Section 6.6.5, where we 
discuss robust regression. 


28.2.3.1 Student t distribution as GSM 


We can represent the Student distribution as a GSM as follows:

$$\mathcal{T}(x|0, \sigma^2, \nu) = \int_0^\infty \mathcal{N}(x|0, z\sigma^2) \, \text{IG}(z|\nu/2, \nu/2) \, dz = \int_0^\infty \mathcal{N}(x|0, z\sigma^2) \, \chi^{-2}(z|\nu, 1) \, dz \tag{28.13}$$

where IG is the inverse Gamma distribution (Section 2.2.3.4). Thus we can think of the Student as an
infinite superposition of Gaussians of different widths; marginalizing this out induces a distribution
with wider tails than a Gaussian with the same variance. This result also explains why the Student
distribution approaches a Gaussian as the dof gets large, since as $\nu \to \infty$, the inverse Gamma
distribution becomes a delta function.
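Equation (28.13) can be checked by Monte Carlo. The sketch below assumes the shape/rate parameterization of IG, so that $z \sim \text{IG}(\nu/2, \nu/2)$ is the reciprocal of a Gamma draw with shape ν/2 and rate ν/2 (NumPy's `gamma` takes a scale, i.e., 1/rate). It compares the sampled variance with the Student variance $\sigma^2 \nu / (\nu - 2)$ and the 97.5% quantile with the tabulated value of about 2.306 for ν = 8.

```python
import numpy as np

def sample_student_as_gsm(nu, sigma, n, rng):
    # z ~ IG(nu/2, nu/2): the reciprocal of a Gamma(shape=nu/2, rate=nu/2) draw.
    z = 1.0 / rng.gamma(shape=nu / 2, scale=2 / nu, size=n)
    # x | z ~ N(0, z * sigma^2): a Gaussian whose variance is itself random.
    return rng.normal(size=n) * np.sqrt(z) * sigma

rng = np.random.default_rng(0)
nu, sigma = 8.0, 1.0
x = sample_student_as_gsm(nu, sigma, 200_000, rng)
print(x.var(), sigma**2 * nu / (nu - 2))  # Student variance is sigma^2 nu / (nu - 2)
print(np.quantile(x, 0.975))              # the t_8 distribution has a 97.5% quantile of about 2.306
```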


28.2.3.2 Laplace distribution as GSM 


Similarly one can show that the Laplace distribution is an infinite weighted sum of Gaussians, where 
the precision comes from a Gamma distribution: 


$$\text{Laplace}(x|0, \lambda) = \int \mathcal{N}(x|0, \tau^2) \, \text{Ga}(\tau^2|1, \lambda^2/2) \, d\tau^2 \tag{28.14}$$


28.2.3.3 Spike and slab distribution 


Suppose $\epsilon \sim \text{Ber}(\pi)$. (Note that $\epsilon^2 = \epsilon$, since $\epsilon \in \{0, 1\}$.) In this case we have

$$p(x) = \sum_{\epsilon \in \{0, 1\}} \mathcal{N}(x|0, \sigma_0^2 \epsilon) p(\epsilon) = \pi \mathcal{N}(x|0, \sigma_0^2) + (1 - \pi) \delta_0(x) \tag{28.15}$$

This is known as the spike and slab distribution, since $\delta_0(x)$ is a “spike” at 0, and $\mathcal{N}(x|0, \sigma_0^2)$
acts like a uniform “slab” for large enough $\sigma_0$. This distribution is useful in sparse modeling.
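A short sampling sketch of Equation (28.15), using the GSM construction $x = \epsilon z$ with $\epsilon \sim \text{Ber}(\pi)$ and $z \sim \mathcal{N}(0, \sigma_0^2)$. The parameter values are hypothetical. A fraction of about 1 − π of the samples are exactly zero (the spike), and the nonzero ones are spread with standard deviation about $\sigma_0$ (the slab).

```python
import numpy as np

def sample_spike_and_slab(pi, sigma0, n, rng):
    # x = eps * z with eps ~ Ber(pi) and z ~ N(0, sigma0^2).
    eps = rng.random(n) < pi
    z = rng.normal(scale=sigma0, size=n)
    return np.where(eps, z, 0.0)

rng = np.random.default_rng(0)
x = sample_spike_and_slab(pi=0.2, sigma0=3.0, n=100_000, rng=rng)
print(np.mean(x == 0.0))  # fraction in the spike, about 1 - pi
print(x[x != 0].std())    # spread of the slab, about sigma0
```

Exact zeros are what distinguish this prior from continuous shrinkage priors such as the Laplace, which make small values likely but never exactly zero.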
28.2.3.4 Horseshoe distribution 


Suppose $\epsilon \sim \mathcal{C}_+(1)$, which is the half-Cauchy distribution (see Section 2.2.2.4). Then the induced
distribution p(x) is called the horseshoe distribution [CPS10]. This has a spike at 0, like the
Student and Laplace distributions, but has heavy tails that do not asymptote to zero. This makes it
useful as a sparsity-promoting prior that “kills off” small parameters but does not overregularize
large parameters.

Figure 28.4: Example of recovering a clean image (right) from a corrupted version (left) using MAP estimation
with a GMM patch prior and Gaussian likelihood. First row: image denoising. Second row: image deblurring.
Third row: image inpainting. From [RW15] and [ZW11]. Used with kind permission of Dan Rosenbaum and
Daniel Zoran.


28.2.4 Using GMMs as a prior for inverse imaging problems

In this section, we consider using GMMs as a blackbox density model to regularize the inversion of a
many-to-one mapping. Specifically, we consider the problem of inferring a “clean” image x from a
corrupted version y. We use a linear-Gaussian forwards model of the form

$$p(y|x) = \mathcal{N}(y|Wx, \sigma^2 I) \tag{28.16}$$

where $\sigma^2$ is the variance of the measurement noise. The form of the matrix W depends on the nature
of the corruption, which we assume is known, for simplicity. Here are some common examples of
different kinds of corruption we can model in our approach:


• If the corruption is due to additive noise (as in Figure 28.4a), we can set W = I. The resulting
MAP estimate can be used for image denoising, as in Figure 28.4b.

• If the corruption is due to blurring (as in Figure 28.4c), we can set W to be a fixed convolutional
kernel [KF09b]. The resulting MAP estimate can be used for image deblurring, as in
Figure 28.4d.

• If the corruption is due to occlusion (as in Figure 28.4e), we can set W to be a diagonal matrix,
with 0s in the locations corresponding to the occluders. The resulting MAP estimate can be used
for image inpainting, as in Figure 28.4f.

• If the corruption is due to downsampling, we can set W to a convolutional kernel. The resulting
MAP estimate can be used for image super-resolution.
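To make the corruption models above concrete, here is a sketch that builds each W as an explicit matrix for a tiny 1d "image" of D pixels. The 3-tap blur kernel, the occlusion pattern and the block-averaging downsampler (a box blur followed by subsampling) are hypothetical choices for illustration; real images would use 2d kernels and usually apply W as a convolution rather than a dense matrix.

```python
import numpy as np

def blur_matrix(D, kernel=(0.25, 0.5, 0.25)):
    # Each row applies a small 1d smoothing kernel (zero padding at the edges).
    W = np.zeros((D, D))
    half = len(kernel) // 2
    for i in range(D):
        for j, w in enumerate(kernel):
            col = i + j - half
            if 0 <= col < D:
                W[i, col] = w
    return W

def mask_matrix(visible):
    # Diagonal matrix with 0s at occluded pixels and 1s at visible ones.
    return np.diag(np.asarray(visible, dtype=float))

def downsample_matrix(D, factor=2):
    # Average non-overlapping blocks of `factor` pixels (box blur followed by subsampling).
    W = np.zeros((D // factor, D))
    for i in range(D // factor):
        W[i, i * factor:(i + 1) * factor] = 1.0 / factor
    return W

D = 8
x = np.arange(D, dtype=float)
W_denoise = np.eye(D)
W_blur = blur_matrix(D)
W_inpaint = mask_matrix([1, 1, 0, 0, 1, 1, 1, 1])
W_down = downsample_matrix(D)
rng = np.random.default_rng(0)
sigma = 0.1
y = W_inpaint @ x + sigma * rng.normal(size=D)  # one observation from Equation (28.16)
print(W_blur @ x, W_inpaint @ x, W_down @ x)
```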


Thus we see that the linear-Gaussian likelihood model is surprisingly flexible. Given the model,
our goal is to invert it, by computing the MAP estimate $\hat{x} = \arg\max_x p(x|y)$. However, the problem
of inverting this model is ill-posed, since there are many possible latent images x that map to the
same observed image y. Therefore we need to use a prior to regularize the inversion process.

In [ZW11], they propose to partition the image into patches, and to use a GMM prior of the form
$p(x_i) = \sum_k p(c_i = k) \mathcal{N}(x_i|\mu_k, \Sigma_k)$ for each patch i. They use K = 200 mixture components, and
they fit the GMM on a dataset of 2M clean image patches.

To compute the MAP mixture component, $c_i^*$, we can marginalize out $x_i$ and use Equation (2.90)
to compute the marginal likelihood:

$$c_i^* = \arg\max_c p(c) p(y_i|c) = \arg\max_c p(c) \mathcal{N}(y_i|W\mu_c, \sigma^2 I + W\Sigma_c W^\top) \tag{28.17}$$

We can then approximate the MAP for the latent patch $x_i$ by using the approximation

$$p(x_i|y_i) \approx p(x_i|y_i, c_i^*) \propto \mathcal{N}(x_i|\mu_{c_i^*}, \Sigma_{c_i^*}) \mathcal{N}(y_i|Wx_i, \sigma^2 I) \tag{28.18}$$

If we know $c_i^*$, we can compute the above using Bayes rule for Gaussians in Equation (2.82).
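The two-step procedure of Equations (28.17) and (28.18) can be sketched for a single patch as follows. This is a minimal NumPy illustration under the stated linear-Gaussian model, not the code of [ZW11]: it scores each component by its marginal likelihood, then computes the Gaussian posterior for the winning component using the standard precision form $(\Sigma_c^{-1} + W^\top W / \sigma^2)^{-1}$. The 3-pixel patch, component parameters and inpainting mask are hypothetical.

```python
import numpy as np

def gauss_logpdf(y, mu, S):
    # log N(y | mu, S) for a single vector, via a Cholesky factor of S.
    L = np.linalg.cholesky(S)
    r = np.linalg.solve(L, y - mu)
    return -0.5 * (len(y) * np.log(2 * np.pi) + 2 * np.sum(np.log(np.diag(L))) + r @ r)

def map_patch(y, W, sigma2, pis, mus, Sigmas):
    """Approximate MAP estimate of one clean patch x from its corrupted version y."""
    I = np.eye(len(y))
    # Equation (28.17): score each component by p(c) N(y | W mu_c, sigma2 I + W Sigma_c W^T).
    scores = [np.log(pis[c]) + gauss_logpdf(y, W @ mus[c], sigma2 * I + W @ Sigmas[c] @ W.T)
              for c in range(len(pis))]
    c = int(np.argmax(scores))
    # Equation (28.18): with c fixed, prior and likelihood are both Gaussian, so the
    # posterior over x is Gaussian and its mode equals its mean.
    post_prec = np.linalg.inv(Sigmas[c]) + W.T @ W / sigma2
    rhs = np.linalg.solve(Sigmas[c], mus[c]) + W.T @ y / sigma2
    return c, np.linalg.solve(post_prec, rhs)

# Hypothetical 3-pixel patches: two components that differ in mean brightness
# and share a smooth (strongly correlated) covariance.
rho = 0.9
Sigma = np.array([[1, rho, rho**2], [rho, 1, rho], [rho**2, rho, 1]])
pis = np.array([0.5, 0.5])
mus = np.array([np.zeros(3), 5 * np.ones(3)])
Sigmas = np.array([Sigma, Sigma])
W = np.diag([1.0, 0.0, 1.0])  # inpainting: the middle pixel is occluded
y = np.array([5.1, 0.0, 4.9])
c_star, x_hat = map_patch(y, W, 0.01, pis, mus, Sigmas)
print(c_star, x_hat)  # the occluded middle pixel is filled in from its neighbors
```

The correlated covariance is what lets the prior "fill in" the occluded pixel: the observed neighbors select the bright component, and the correlation propagates their values to the middle.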
To apply this method to full images, [ZW11] optimize the following objective:

$$E(x|y) = \frac{1}{2\sigma^2} \lVert Wx - y \rVert_2^2 - \text{EPLL}(x) \tag{28.19}$$

where EPLL is the “expected patch log likelihood”, given by

$$\text{EPLL}(x) = \sum_i \log p(P_i x) \tag{28.20}$$

where $x_i = P_i x$ is the i'th patch computed by projection matrix $P_i$. Because these patches overlap,
this is not a valid likelihood, as it overcounts the pixels. Nevertheless, optimizing this objective
(using a method called “half quadratic splitting”) works well empirically. See Figure 28.4 for some
examples of this process in action.
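A minimal sketch of how Equations (28.19) and (28.20) are evaluated, assuming a 1d signal whose "patches" $P_i x$ are all contiguous windows of a fixed length (so each $P_i$ is a slicing operation) and a GMM patch prior with known parameters. It only evaluates the objective; it does not implement the half quadratic splitting optimizer. All names are hypothetical.

```python
import numpy as np

def gmm_logpdf(v, pis, mus, Sigmas):
    # log sum_k pi_k N(v | mu_k, Sigma_k), computed with a log-sum-exp.
    terms = []
    for pi, mu, S in zip(pis, mus, Sigmas):
        L = np.linalg.cholesky(S)
        r = np.linalg.solve(L, v - mu)
        terms.append(np.log(pi) - 0.5 * (len(v) * np.log(2 * np.pi)
                                         + 2 * np.sum(np.log(np.diag(L))) + r @ r))
    terms = np.array(terms)
    m = terms.max()
    return m + np.log(np.sum(np.exp(terms - m)))

def epll(x, patch_size, pis, mus, Sigmas):
    # Equation (28.20): sum over all overlapping patches P_i x (contiguous windows of x).
    return sum(gmm_logpdf(x[i:i + patch_size], pis, mus, Sigmas)
               for i in range(len(x) - patch_size + 1))

def objective(x, y, W, sigma2, patch_size, pis, mus, Sigmas):
    # Equation (28.19): data-fit term minus the expected patch log likelihood.
    return 0.5 / sigma2 * np.sum((W @ x - y) ** 2) - epll(x, patch_size, pis, mus, Sigmas)

pis, mus, Sigmas = [1.0], [np.zeros(2)], [np.eye(2)]  # a single standard-normal "GMM" on 2-pixel patches
x = np.zeros(4)
print(epll(x, 2, pis, mus, Sigmas))  # 3 overlapping patches, each contributing -log(2*pi)
print(objective(x, x, np.eye(4), 1.0, 2, pis, mus, Sigmas))
```

The example makes the overcounting visible: a 4-pixel signal contributes three patch terms, so interior pixels are scored twice.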




Figure 28.5: Illustration of the parameters learned by a GMM applied to image patches. Each of the 3 panels
corresponds to a different mixture component k. Within each panel, we show the eigenvectors (reshaped as
images) of the covariance matrix $\Sigma_k$, in decreasing order of eigenvalue. We see various kinds of patterns,
including ones that look like the ones learned from PCA (see Figure 28.34), but also ones that look like edges
and texture. From Figure 6 of [ZW11]. Used with kind permission of Daniel Zoran.


A more principled solution to the overlapping patch problem is to use a multiscale model, as
proposed in [PE16]. Another approach, proposed in [FW21], uses Gibbs sampling to combine samples
from overlapping patches. This approach has the additional advantage of computing posterior samples
from p(x|y), which can look much better than the posterior mean or mode computed by optimization
methods. (For example, if the corruption process removes the color from the latent image x to
create a gray scale y, then the posterior MAP estimate of x will also be a gray scale image, whereas
posterior samples will be color images.) See also Section 28.3.3.5, where we show how to extend the
GMM model to a mixture of low rank Gaussians, which lets us directly model images instead of
image patches.


28.2.4.1 Why does the method work? 


To understand why such a simple model of image patches works so well, note that the log prior for a
single latent image patch $x_i$ using mixture component k can be written as follows:

$$\log p(x_i|c_i = k) = \log \mathcal{N}(x_i|0, \Sigma_k) = -\frac{1}{2} x_i^\top \Sigma_k^{-1} x_i + a_k \tag{28.21}$$

where $a_k$ is a constant that depends on k but is independent of $x_i$. Let $\Sigma_k = V_k \Lambda_k V_k^\top$ be an
eigendecomposition of $\Sigma_k$, where $\lambda_{k,d}$ is the d'th eigenvalue of $\Sigma_k$, and $v_{k,d}$ is the d'th eigenvector.
Then we can rewrite the above as follows:

$$\log p(x_i|c_i = k) = -\frac{1}{2} \sum_{d=1}^D \frac{1}{\lambda_{k,d}} (v_{k,d}^\top x_i)^2 + a_k \tag{28.22}$$

Thus we see that the eigenvectors are acting like templates. Each mixture component has a different
set of templates, each with its own weight (the inverse of its eigenvalue), as illustrated in Figure 28.5. By mixing
these together, we get a powerful model for the statistics of natural image patches. (See [ZW12] for
more analysis of why this simple model works so well, based on the “dead leaves” model of image
formation.)
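The step from Equation (28.21) to Equation (28.22) can be verified numerically: the Mahalanobis term equals a weighted sum of squared template responses. This sketch uses a random covariance and patch as hypothetical stand-ins for a learned $\Sigma_k$ and an image patch.

```python
import numpy as np

rng = np.random.default_rng(0)
D = 5
A = rng.normal(size=(D, D))
Sigma = A @ A.T + 0.1 * np.eye(D)  # a random covariance for one mixture component
x = rng.normal(size=D)             # a (zero-mean) patch

quad_direct = x @ np.linalg.solve(Sigma, x)  # x^T Sigma^{-1} x, as in Equation (28.21)
lam, V = np.linalg.eigh(Sigma)               # Sigma = V diag(lam) V^T; columns of V are eigenvectors
proj = V.T @ x                               # template responses v_d^T x
quad_templates = np.sum(proj ** 2 / lam)     # sum_d (v_d^T x)^2 / lambda_d, as in Equation (28.22)
print(quad_direct, quad_templates)
```

Templates with small eigenvalues receive large weights, so patches that strongly activate low-variance directions are penalized most.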
28.2.4.2 Speeding up inference using discriminative models


Although simple and effective, computing $f(y) = \arg\max_x p(x|y)$ for each image patch can be slow
if the image is large. However, every time we solve this problem, we can store the result, and build
up a dataset of $(y, f(y))$ pairs. We can then train an amortized inference network (Section 10.3.6)
to learn this $y \mapsto f(y)$ mapping, to speed up future inferences, as proposed in [RW15]. (See also
[Par+19] for further speedup tricks.) 

An alternative approach is to dispense with the generative model, and to train on an artificially
created dataset of the form (y, x), where x is a clean natural image, and y = C(x) is an artificial
corruption of it. We can then train a discriminative model f(y) directly from (y, x) pairs. This
technique works very well (see e.g., [Luc+18]), but is limited by the form of corruptions C it is trained
on. This means we need to train a different network for every linear operator W, and sometimes
even for every different noise level $\sigma^2$.


28.2.4.3 Blind inverse problems 


In the discussion above, we assumed the forward model had the form $p(y|x, \theta) = \mathcal{N}(y|Wx, \sigma^2 I)$, where $W$ is known. If $W$ is not known, then computing $p(x|y)$ is known as a blind inverse problem.

Such problems are much harder to solve. One approach is to estimate the parameters of the forward model, $W$, and the latent image, $x$, using an EM-like method from a set of images coming from the same likelihood function. That is, we alternate between estimating $\hat{x} = \arg\max_x p(x|y, W)$ in the E step, and estimating $\hat{W} = \arg\max_W p(y|\hat{x}, W)$ in the M step. Some encouraging results of this approach are shown in [Ani+18]. (They use a GAN prior for $p(x)$ rather than a GMM.)
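A toy version of this alternating scheme fits in a few lines of NumPy. The sketch makes simplifying assumptions that are not in the text: a Gaussian prior $x_n \sim \mathcal{N}(0, I)$ (instead of a GAN or GMM prior), so the E step has a closed form, and an unknown linear operator $W$ shared by all $N$ observations, estimated by least squares in the M step. The function names are hypothetical. Because each step maximizes the joint log density over one block of variables, the recorded objective should never decrease.

```python
import numpy as np

def joint_logpdf(Y, X, W, sigma2):
    """log p(Y, X | W) up to a constant, with x_n ~ N(0, I) and y_n ~ N(W x_n, sigma2 I)."""
    return -0.5 * np.sum((Y - X @ W.T) ** 2) / sigma2 - 0.5 * np.sum(X ** 2)

def blind_inverse_hard_em(Y, L, sigma2, n_iter=50, seed=0):
    N, D = Y.shape
    W = np.random.default_rng(seed).normal(size=(D, L))   # random initial forward model
    history = []
    for _ in range(n_iter):
        # "E step": x_n = argmax_x p(x | y_n, W), closed form under the Gaussian prior
        A = W.T @ W / sigma2 + np.eye(L)
        X = np.linalg.solve(A, W.T @ Y.T / sigma2).T
        # "M step": W = argmax_W p(Y | X, W), i.e. least-squares regression of Y on X
        W = np.linalg.lstsq(X, Y, rcond=None)[0].T
        history.append(joint_logpdf(Y, X, W, sigma2))
    return W, X, history

rng = np.random.default_rng(1)
D, L, N, sigma2 = 6, 3, 200, 0.1
W_true = rng.normal(size=(D, L))
Y = rng.normal(size=(N, L)) @ W_true.T + np.sqrt(sigma2) * rng.normal(size=(N, D))
W_hat, X_hat, history = blind_inverse_hard_em(Y, L, sigma2)
```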

In cases where we get two independent noisy samples, $y_1$ and $y_2$, generated from the same underlying image $x$, we can avoid having to explicitly learn an image prior $p(x)$, and can instead directly learn an estimator for the posterior mode, $f(y) = \arg\max_x p(x|y)$, without needing access to the latent image $x$, by exploiting a form of cycle consistency; see [XC19] for details.


28.2.5 Using mixture models for classification problems 


It is possible to use mixture models to define the class-conditional density $p(x|y = c)$ in a generative classifier. We can then derive the class posterior using Bayes rule:

$$p(y = c|x) = \frac{p(y = c) p(x|y = c)}{\sum_{c'} p(y = c') p(x|y = c')} = \frac{p(y = c) p(x|y = c)}{Z} \tag{28.23}$$

where $p(y = c) = \pi_c$ is the prior on class label $c$, $Z$ is the normalization constant, and the form of $p(x|y = c)$ depends on the kind of data we have. For real-valued features, it is common to use a GMM:


$$p(x|y = c) = \sum_{k=1}^{K_c} \alpha_{c,k} \mathcal{N}(x|\mu_{c,k}, \Sigma_{c,k}) \tag{28.24}$$

Using a generative model to perform classification can be useful when we have missing data, since we can compute $p(x^v|y = c) = \sum_{x^m} p(x^m, x^v|y = c)$ to compute the marginal likelihood of the visible features $x^v$. It is also useful for semi-supervised learning, since we can optimize the model to fit $\sum_n \log p(x_n^l, y_n^l)$ on the labeled data and $\sum_n \log p(x_n^u)$ on the unlabeled data.
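The NumPy sketch below illustrates Eqs. (28.23) and (28.24) together with the missing-data use case. It assumes full-covariance Gaussian components with hand-picked, hypothetical parameters; for Gaussian components, marginalizing out missing features amounts to dropping the corresponding rows and columns of each mean and covariance.

```python
import numpy as np

def logsumexp(a):
    m = np.max(a)
    return m + np.log(np.sum(np.exp(a - m)))

def gauss_logpdf(x, mu, Sigma):
    chol = np.linalg.cholesky(Sigma)
    r = np.linalg.solve(chol, x - mu)         # r^T r = (x - mu)^T Sigma^{-1} (x - mu)
    return -0.5 * (x.shape[0] * np.log(2 * np.pi) + r @ r) - np.sum(np.log(np.diag(chol)))

def class_posterior(x, log_prior, gmms, visible=None):
    """p(y = c | x_v) for a generative classifier with GMM class-conditionals.
    gmms[c] = (alphas, mus, Sigmas); features outside `visible` are marginalized out."""
    v = np.arange(x.shape[0]) if visible is None else np.asarray(visible)
    log_joint = []
    for lp, (alphas, mus, Sigmas) in zip(log_prior, gmms):
        comps = [np.log(a) + gauss_logpdf(x[v], m[v], S[np.ix_(v, v)])
                 for a, m, S in zip(alphas, mus, Sigmas)]
        log_joint.append(lp + logsumexp(np.array(comps)))   # log p(y = c) + log p(x_v | y = c)
    log_joint = np.array(log_joint)
    return np.exp(log_joint - logsumexp(log_joint))         # Bayes rule, Eq. (28.23)

I3 = np.eye(3)
gmms = [
    (np.array([0.5, 0.5]), [np.zeros(3), np.ones(3)], [I3, I3]),
    (np.array([0.3, 0.7]), [np.full(3, 4.0), np.full(3, 5.0)], [I3, 2 * I3]),
]
log_prior = np.log(np.array([0.5, 0.5]))
x = np.array([0.2, 0.1, 10.0])
p_full = class_posterior(x, log_prior, gmms)                    # uses all three features
p_vis = class_posterior(x, log_prior, gmms, visible=[0, 1])     # third feature treated as missing
```

Here the third feature pulls the full-data posterior towards class 1, while the first two features alone point firmly to class 0.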
28.2.5.1 Hybrid generative / discriminative training


Unfortunately the classification accuracy of generative models of the form $p(x, y)$ can be much worse than discriminative (conditional) models of the form $p(y|x)$, since the latter are directly optimized to predict the labels given the features, and don’t “waste” capacity on modeling irrelevant details of the inputs. (For a more in-depth discussion of generative vs discriminative classifiers, see e.g., [Mur22, Sec 9.4].)

Fortunately it is possible to train generative models in a discriminative fashion, which can close the performance gap with conditional models, while maintaining the advantages of generative models. In particular, we can optimize the following hybrid objective, proposed in [BT04; Rot+18]:

$$\mathcal{L}(\theta) = -\lambda \underbrace{\sum_{n=1}^{N} \log p(x_n, y_n|\theta)}_{\mathcal{L}_{\text{gen}}(\theta)} - (1 - \lambda) \underbrace{\sum_{n=1}^{N} \log p(y_n|x_n, \theta)}_{\mathcal{L}_{\text{dis}}(\theta)} \tag{28.25}$$

where $0 < \lambda < 1$ controls the tradeoff between generative and discriminative modeling.
If we have unlabeled data, we can modify the generative loss as shown below:

$$\mathcal{L}_{\text{gen}}(\theta) = \kappa \sum_{n=1}^{N_l} \log p(x_n^l, y_n^l|\theta) + (1 - \kappa) \sum_{n=1}^{N_u} \log p(x_n^u|\theta) \tag{28.26}$$

Here we have introduced an extra trade-off parameter $0 < \kappa < 1$ to prevent the unlabeled data from overwhelming the labeled data (if $N_u \gg N_l$), as proposed in [Nig+00].
An alternative to changing the objective function is to change the model itself, so that we parameterize the joint using $p(x, y) = p(y|x, \theta) p(x|\tilde{\theta})$, and then define different kinds of joint priors $p(\theta, \tilde{\theta})$; see [LBM06; BL07a] for details.
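A minimal NumPy sketch of the hybrid objective in Eq. (28.25), assuming (for illustration only) 1d inputs and a single Gaussian per class, with hypothetical parameter values. The key identity is $\log p(y|x) = \log p(x, y) - \log p(x)$, where $\log p(x)$ is a log-sum-exp over classes, so both terms come from the same generative model.

```python
import numpy as np

def log_joint(x, mus, sds, log_prior):
    """log p(x_n, y = c) for all n and c, with 1d Gaussian class-conditionals."""
    z = (x[:, None] - mus) / sds
    return log_prior - 0.5 * z ** 2 - np.log(sds) - 0.5 * np.log(2 * np.pi)

def hybrid_loss(x, y, mus, sds, log_prior, lam):
    """Eq. (28.25): -lam * L_gen - (1 - lam) * L_dis."""
    lj = log_joint(x, mus, sds, log_prior)                       # shape (N, C)
    m = lj.max(axis=1)
    log_px = m + np.log(np.exp(lj - m[:, None]).sum(axis=1))     # log p(x_n) via log-sum-exp
    ll_xy = lj[np.arange(len(y)), y]
    L_gen = ll_xy.sum()                                          # sum_n log p(x_n, y_n)
    L_dis = (ll_xy - log_px).sum()                               # sum_n log p(y_n | x_n)
    return -lam * L_gen - (1 - lam) * L_dis, L_gen, L_dis

x = np.array([-1.0, -0.5, 0.7, 1.2])
y = np.array([0, 0, 1, 1])
mus, sds = np.array([-1.0, 1.0]), np.array([1.0, 1.0])
log_prior = np.log(np.array([0.5, 0.5]))
loss, L_gen, L_dis = hybrid_loss(x, y, mus, sds, log_prior, lam=0.5)
```

In practice one would differentiate this loss with respect to the (unconstrained) parameters, for example with JAX, as discussed next.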


28.2.5.2 Optimization issues 


To optimize the loss, we need to reparameterize the model into unconstrained form. For the class prior, we can use $\pi_{1:C} = \text{softmax}(\tilde{\pi}_{1:C})$, and optimize wrt the logits $\tilde{\pi}_{1:C}$. Similarly for the mixture weights $\alpha_{c,1:K}$. The means $\mu_{c,k}$ are already unconstrained. For the covariance matrices, we will use a diagonal plus low-rank representation, to reduce the number of parameters:


$$\Sigma_{c,k} = \text{diag}(d_{c,k}) + S_{c,k} S_{c,k}^T \tag{28.27}$$


where $S_{c,k}$ is an unconstrained $D \times R$ matrix, where $R < D$ is the rank of the approximation. (For numerical stability, we usually add $\epsilon I$ to the above expression, to ensure $\Sigma_{c,k}$ is positive definite for all parameter settings.) To ensure positivity of the diagonal term, we can use the softplus transform, $d_{c,k} = \log(1 + \exp(\tilde{d}_{c,k}))$, and optimize wrt the $\tilde{d}_{c,k}$ terms.
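The reparameterization can be sketched in NumPy as follows (function names are hypothetical). Whatever real values the optimizer proposes, the outputs are a valid probability vector and a positive definite covariance.

```python
import numpy as np

def softplus(u):
    return np.logaddexp(0.0, u)               # log(1 + exp(u)), without overflow

def softmax(u):
    e = np.exp(u - np.max(u))
    return e / e.sum()

def to_constrained(prior_logits, d_tilde, S, eps=1e-6):
    """Map unconstrained parameters to a class prior pi and a covariance
    Sigma = diag(softplus(d_tilde)) + S S^T + eps * I (Eq. 28.27 plus jitter)."""
    pi = softmax(prior_logits)
    Sigma = np.diag(softplus(d_tilde)) + S @ S.T + eps * np.eye(S.shape[0])
    return pi, Sigma

pi, Sigma = to_constrained(np.array([0.3, -1.2, 2.0]),
                           np.array([-8.0, 0.0, 3.0, -2.0, 1.0]),   # any real values are allowed
                           np.ones((5, 2)))                          # D = 5, rank R = 2
```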


28.2.5.3 Numerical issues




To compute the class conditional log likelihood, $\ell_c = \log p(x|y = c)$, we can use the log-sum-exp trick to avoid numerical underflow. Define $\tilde{\alpha}_{ck} = \log \alpha_{ck}$ and $\ell_{ck} = \log \mathcal{N}(x|\mu_{ck}, \Sigma_{ck})$. Then we have

$$\ell_c = \log p(x|y = c) = \log \sum_k e^{\tilde{\alpha}_{ck} + \ell_{ck}} \tag{28.28}$$
$$= M + \log \sum_k e^{\tilde{\alpha}_{ck} + \ell_{ck} - M} \triangleq \text{logsumexp}(\{\tilde{\alpha}_{ck} + \ell_{ck}\}_k) \tag{28.29}$$

where $M = \max_k \tilde{\alpha}_{ck} + \ell_{ck}$.
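We can use a similar method to compute the posterior over classes. We have

$$p(y = c|x) = \frac{\pi_c e^{\ell_c}}{\sum_{c'} \pi_{c'} e^{\ell_{c'}}} = \frac{\pi_c e^{\tilde{\ell}_c}}{\sum_{c'} \pi_{c'} e^{\tilde{\ell}_{c'}}} = \frac{\pi_c e^{\tilde{\ell}_c}}{\tilde{Z}} \tag{28.30}$$

where $L = \max_c \ell_c$, $\tilde{\ell}_c = \ell_c - L$, and $\tilde{Z} = \sum_c \pi_c e^{\tilde{\ell}_c}$. This lets us combine the class prior probability $\pi_c$ with the scaled class conditional log likelihood $\tilde{\ell}_c$ to get the class posterior in a stable way. (We can also compute the log normalization constant, $\log p(x) = \log Z = \log(\tilde{Z}) + L$.)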
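A minimal NumPy sketch of Eq. (28.30), with a hypothetical function name. With log likelihoods around $-1000$ (common for high-dimensional $x$), exponentiating directly underflows to zero, while the shifted version stays exact.

```python
import numpy as np

def class_posterior_stable(ell, log_pi):
    """p(y = c | x) and log p(x) from class-conditional log likelihoods ell[c] and
    log priors log_pi[c], shifting by L = max_c ell_c as in Eq. (28.30)."""
    L = np.max(ell)
    w = np.exp(log_pi + ell - L)              # pi_c * exp(ell_c - L); every exponent is <= 0
    Z_tilde = w.sum()
    return w / Z_tilde, np.log(Z_tilde) + L   # posterior, and log p(x) = log Z~ + L

ell = np.array([-1000.0, -1001.0])            # exp(ell) underflows to 0.0 in float64
log_pi = np.log(np.array([0.5, 0.5]))
post, log_px = class_posterior_stable(ell, log_pi)
naive = np.exp(log_pi) * np.exp(ell)          # all zeros, so normalizing would give 0/0
```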
To compute a single Gaussian log density, $\ell_{ck} = \log \mathcal{N}(x|\mu_{ck}, \Sigma_{ck})$, we need to evaluate $\log\det(\Sigma_{ck})$ and $\Sigma_{ck}^{-1}$. To make this efficient, we can use the matrix determinant lemma to compute

$$\det(A + SS^T) = \det(I + S^T A^{-1} S) \det(A) \tag{28.31}$$

where $A = \text{diag}(d) + \epsilon I$, and the matrix inversion lemma to compute

$$(A + SS^T)^{-1} = A^{-1} - A^{-1} S (I + S^T A^{-1} S)^{-1} S^T A^{-1} \tag{28.32}$$


(See also the discussion of mixture of factor analysers in Section 28.3.3.) 
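Putting Eqs. (28.31) and (28.32) together gives a log density that only ever solves $R \times R$ systems. The NumPy sketch below (hypothetical function name) is checked against the dense $D \times D$ computation.

```python
import numpy as np

def lowrank_gauss_logpdf(x, mu, d, S, eps=1e-6):
    """log N(x | mu, A + S S^T) with A = diag(d) + eps * I (d > 0), never forming a D x D matrix."""
    a = d + eps                               # diagonal of A
    r = x - mu
    Ainv_S = S / a[:, None]                   # A^{-1} S, shape (D, R)
    K = np.eye(S.shape[1]) + S.T @ Ainv_S     # I + S^T A^{-1} S, shape (R, R)
    _, logdet_K = np.linalg.slogdet(K)
    logdet = logdet_K + np.sum(np.log(a))     # matrix determinant lemma, Eq. (28.31)
    Ainv_r = r / a
    t = S.T @ Ainv_r                          # S^T A^{-1} r
    maha = r @ Ainv_r - t @ np.linalg.solve(K, t)   # r^T (A + S S^T)^{-1} r via Eq. (28.32)
    return -0.5 * (x.shape[0] * np.log(2 * np.pi) + logdet + maha)

rng = np.random.default_rng(0)
D, R = 50, 3
d = np.exp(rng.normal(size=D))
S = rng.normal(size=(D, R))
mu = rng.normal(size=D)
x = rng.normal(size=D)
fast = lowrank_gauss_logpdf(x, mu, d, S)
Sigma = np.diag(d + 1e-6) + S @ S.T           # dense covariance, built only for checking
_, logdet = np.linalg.slogdet(Sigma)
dense = -0.5 * (D * np.log(2 * np.pi) + logdet + (x - mu) @ np.linalg.solve(Sigma, x - mu))
```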


28.3 Factor analysis 


In this section, we discuss a simple latent factor model in which the prior $p(z)$ is Gaussian, and the likelihood $p(x|z)$ is also Gaussian, using a linear decoder for the mean. This family includes many important special cases, such as PCA, as we discuss below. We also briefly discuss some simple extensions.


28.3.1 Factor analysis: the basics 


Factor analysis corresponds to the following linear-Gaussian latent variable generative model: 


$$p(z) = \mathcal{N}(z|\mu_0, \Sigma_0) \tag{28.33}$$
$$p(x|z, \theta) = \mathcal{N}(x|Wz + \mu, \Psi) \tag{28.34}$$

where $W$ is a $D \times L$ matrix, known as the factor loading matrix, and $\Psi$ is a diagonal $D \times D$ covariance matrix.
28.3.1.1 FA as a Gaussian with low-rank plus diagonal covariance


FA can be thought of as a low-rank version of a Gaussian distribution. To see this, note that the induced marginal distribution $p(x|\theta)$ is a Gaussian (see Equation (2.90) for the derivation):

$$p(x|\theta) = \int \mathcal{N}(x|Wz + \mu, \Psi) \mathcal{N}(z|\mu_0, \Sigma_0) dz \tag{28.35}$$
$$= \mathcal{N}(x|W\mu_0 + \mu, \Psi + W\Sigma_0 W^T) \tag{28.36}$$

The first and second moments can be derived as follows:

$$\mathbb{E}[x] = W\mu_0 + \mu, \qquad \text{Cov}[x] = W \text{Cov}[z] W^T + \Psi = W\Sigma_0 W^T + \Psi \tag{28.37}$$


From this, we see that we can set $\mu_0 = 0$ without loss of generality, since we can always absorb $W\mu_0$ into $\mu$. Similarly, we can set $\Sigma_0 = I$ without loss of generality, since we can always absorb a correlated prior by using a new weight matrix, $\tilde{W} = W\Sigma_0^{1/2}$, since then

$$\text{Cov}[x] = W\Sigma_0 W^T + \Psi = \tilde{W}\tilde{W}^T + \Psi \tag{28.38}$$

Finally, we see that we should restrict $\Psi$ to be diagonal, otherwise we could set $W = 0$, thus ignoring the latent factors, while still being able to model any covariance. After these simplifications we have the final model:


$$p(z) = \mathcal{N}(z|0, I) \tag{28.39}$$
$$p(x|z) = \mathcal{N}(x|Wz + \mu, \Psi) \tag{28.40}$$

from which we get

$$p(x) = \mathcal{N}(x|\mu, WW^T + \Psi) \tag{28.41}$$


For example, suppose $L = 1$, $D = 2$ and $\Psi = \sigma^2 I$. We illustrate the generative process in this case in Figure 28.6. We can think of this as taking an isotropic Gaussian “spray can”, representing the likelihood $p(x|z)$, and “sliding it along” the 1d line defined by $wz + \mu$ as we vary the 1d latent variable $z$. This induces an elongated (and hence correlated) Gaussian in 2d. That is, the induced distribution has the form $p(x) = \mathcal{N}(x|\mu, ww^T + \sigma^2 I)$.

In general, FA approximates the covariance matrix of the visible vector using a low-rank decomposition:

$$C = \text{Cov}[x] = WW^T + \Psi \tag{28.42}$$

This only uses $O(LD)$ parameters, which allows a flexible compromise between a full covariance Gaussian, with $O(D^2)$ parameters, and a diagonal covariance, with $O(D)$ parameters.
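The generative process can be simulated directly. The NumPy sketch below (hypothetical names, random parameters) draws $z \sim \mathcal{N}(0, I)$, maps it through $W$, adds diagonal noise, and compares the empirical covariance of the samples with $WW^T + \Psi$ from Eq. (28.42).

```python
import numpy as np

def sample_fa(rng, W, mu, psi, n):
    """Ancestral sampling from FA: z ~ N(0, I), then x | z ~ N(W z + mu, diag(psi))."""
    D, L = W.shape
    z = rng.normal(size=(n, L))
    noise = rng.normal(size=(n, D)) * np.sqrt(psi)
    return z @ W.T + mu + noise

rng = np.random.default_rng(0)
D, L = 5, 2
W = rng.normal(size=(D, L))
mu = np.arange(D, dtype=float)
psi = np.full(D, 0.3)
X = sample_fa(rng, W, mu, psi, 200_000)
C_model = W @ W.T + np.diag(psi)              # Eq. (28.42)
C_emp = np.cov(X, rowvar=False)
```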
28.3. FACTOR ANALYSIS


Figure 28.6: Illustration of the FA generative process, where we have $L = 1$ latent dimension generating $D = 2$ observed dimensions; we assume $\Psi = \sigma^2 I$. The latent factor has value $z \in \mathbb{R}$, sampled from $p(z)$; this gets mapped to a 2d offset $\delta = zw$, where $w \in \mathbb{R}^2$, which gets added to $\mu$ to define a Gaussian $p(x|z) = \mathcal{N}(x|\mu + \delta, \sigma^2 I)$. By integrating over $z$, we “slide” this circular Gaussian “spray can” along the principal component axis $w$, which induces elliptical Gaussian contours in $x$ space centered on $\mu$. Adapted from Figure 12.9 of [Bis06].


28.3.1.2 Computing the posterior 


We can compute the posterior over the latent codes, $p(z|x)$, using Bayes rule for Gaussians. In particular, from Equation (2.82), we have

$$p(z|x) = \mathcal{N}(z|\mu_{z|x}, \Sigma_{z|x}) \tag{28.43}$$
$$\Sigma_{z|x} = (I + W^T \Psi^{-1} W)^{-1} = I - W^T (WW^T + \Psi)^{-1} W \tag{28.44}$$
$$\mu_{z|x} = \Sigma_{z|x} [W^T \Psi^{-1} (x - \mu)] = W^T (WW^T + \Psi)^{-1} (x - \mu) \tag{28.45}$$


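We can avoid inverting the $D \times D$ matrix $C = WW^T + \Psi$ by using the matrix inversion lemma:

$$C^{-1} = (WW^T + \Psi)^{-1} \tag{28.46}$$
$$= \Psi^{-1} - \Psi^{-1} W \underbrace{(I + W^T \Psi^{-1} W)^{-1}}_{L^{-1}} W^T \Psi^{-1} \tag{28.47}$$

where $L = I + W^T \Psi^{-1} W$ is $L \times L$.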
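As a numerical sanity check of Eqs. (28.44) and (28.45), the NumPy sketch below (hypothetical function names, random parameters) computes $p(z|x)$ once with only $L \times L$ solves and once through the $D \times D$ matrix $C$; the two must agree.

```python
import numpy as np

def fa_posterior_small(x, W, mu, psi):
    """p(z|x) with L x L solves only: Sigma = (I + W^T Psi^{-1} W)^{-1}, mean = Sigma W^T Psi^{-1} (x - mu)."""
    Wt_Psiinv = W.T / psi                         # W^T Psi^{-1}, shape (L, D)
    prec = np.eye(W.shape[1]) + Wt_Psiinv @ W     # the L x L matrix called L in Eq. (28.47)
    Sigma = np.linalg.inv(prec)
    return Sigma @ (Wt_Psiinv @ (x - mu)), Sigma

def fa_posterior_big(x, W, mu, psi):
    """Same posterior via the D x D matrix C = W W^T + Psi (right-hand sides of 28.44, 28.45)."""
    C = W @ W.T + np.diag(psi)
    Cinv_W = np.linalg.solve(C, W)                # C^{-1} W
    return Cinv_W.T @ (x - mu), np.eye(W.shape[1]) - W.T @ Cinv_W

rng = np.random.default_rng(0)
D, L = 30, 4
W = rng.normal(size=(D, L))
mu = rng.normal(size=D)
psi = rng.uniform(0.5, 2.0, size=D)
x = rng.normal(size=D)
m_small, S_small = fa_posterior_small(x, W, mu, psi)
m_big, S_big = fa_posterior_big(x, W, mu, psi)
```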


28.3.1.3 Computing the likelihood 


In this section, we discuss how to efficiently compute the log (marginal) likelihood, which is given by

$$\log p(x|\mu, C) = -\frac{1}{2}\left[D \log(2\pi) + \log\det(C) + \tilde{x}^T C^{-1} \tilde{x}\right] \tag{28.48}$$

where $\tilde{x} = x - \mu$, and $C = WW^T + \Psi$. Using Equation (28.47), the Mahalanobis distance can be computed using

$$\tilde{x}^T C^{-1} \tilde{x} = \tilde{x}^T \left[\Psi^{-1} \tilde{x} - \Psi^{-1} W L^{-1} (W^T \Psi^{-1} \tilde{x})\right] \tag{28.49}$$

which takes $O(L^3 + LD)$ to compute. From the matrix determinant lemma, the log determinant is given by

$$\log\det(C) = \log\det(L) + \log\det(\Psi) \tag{28.50}$$

which takes $O(L^3 + D)$ to compute. (See also Section 28.2.5, where we discuss fitting low-rank GMM classifiers discriminatively, which requires similar computations.)


28.3.1.4 Model fitting using EM 


We can compute the MLE for the FA model either by performing gradient ascent on the log likelihood in Equation (28.48), or by using the EM algorithm [RT82; GH96b]. The latter can converge faster, and automatically satisfies positivity constraints on $\Psi$. We give the details below, assuming that the observed data is standardized, so $\mu = 0$ for notational simplicity.

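In the E step, we compute the following expected sufficient statistics:

$$C_{xz} = \sum_{n=1}^N x_n \mathbb{E}[z|x_n]^T \tag{28.51}$$
$$C_{zz} = \sum_{n=1}^N \mathbb{E}[zz^T|x_n] \tag{28.52}$$
$$C_{xx} = \sum_{n=1}^N x_n x_n^T \tag{28.53}$$

where

$$\mathbb{E}[z|x] = Bx \tag{28.54}$$
$$\mathbb{E}[zz^T|x] = \text{Cov}[z|x] + \mathbb{E}[z|x]\mathbb{E}[z|x]^T = I - BW + Bxx^T B^T \tag{28.55}$$
$$B \triangleq W^T (\Psi + WW^T)^{-1} = W^T C^{-1} \tag{28.56}$$

In the M step, we have

$$W^{\text{new}} = C_{xz} C_{zz}^{-1} \tag{28.57}$$
$$\Psi^{\text{new}} = \frac{1}{N} \text{diag}\left(C_{xx} - W^{\text{new}} C_{xz}^T\right) \tag{28.58}$$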
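The E and M steps above translate almost line for line into NumPy. This is a minimal sketch under the same assumption as the text ($\mu = 0$, so the data are centered first); the synthetic data and function names are hypothetical. Since EM never decreases the likelihood, the recorded log likelihoods should be non-decreasing.

```python
import numpy as np

def fa_loglik(X, W, psi):
    """Average log N(x_n | 0, W W^T + diag(psi)) over the rows of a centered data matrix X."""
    D = X.shape[1]
    C = W @ W.T + np.diag(psi)
    _, logdet = np.linalg.slogdet(C)
    maha = np.sum(X * np.linalg.solve(C, X.T).T, axis=1)
    return np.mean(-0.5 * (D * np.log(2 * np.pi) + logdet + maha))

def fa_em(X, L, n_iter=100, seed=0):
    """EM for factor analysis with mu = 0, following Eqs. (28.51)-(28.58)."""
    N, D = X.shape
    W = np.random.default_rng(seed).normal(size=(D, L))
    psi = X.var(axis=0)
    Cxx = X.T @ X                                     # Eq. (28.53)
    history = []
    for _ in range(n_iter):
        # E step
        C = W @ W.T + np.diag(psi)
        B = np.linalg.solve(C, W).T                   # B = W^T C^{-1}, Eq. (28.56)
        Ez = X @ B.T                                  # row n is E[z | x_n], Eq. (28.54)
        Cxz = X.T @ Ez                                # Eq. (28.51)
        Czz = N * (np.eye(L) - B @ W) + Ez.T @ Ez     # Eq. (28.52), summing (28.55) over n
        # M step
        W = np.linalg.solve(Czz, Cxz.T).T             # W = Cxz Czz^{-1}, Eq. (28.57)
        psi = np.diag(Cxx - W @ Cxz.T) / N            # Eq. (28.58)
        history.append(fa_loglik(X, W, psi))
    return W, psi, history

rng = np.random.default_rng(1)
W_true = rng.normal(size=(6, 2))
noise_sd = np.sqrt(rng.uniform(0.1, 0.5, size=6))
X = rng.normal(size=(500, 2)) @ W_true.T + rng.normal(size=(500, 6)) * noise_sd
X = X - X.mean(axis=0)                                # center, so mu = 0
W_hat, psi_hat, history = fa_em(X, L=2)
```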


28.3.1.5 Handling missing data


We can also perform posterior inference in the presence of missing data (if we make the missing at random assumption; see Section 21.3.4 for discussion). In particular, let us partition $x = (x_1, x_2)$, $W = [W_1; W_2]$, and $\mu = [\mu_1; \mu_2]$, and suppose $x_2$ is missing (unknown). From Supplementary Section 2.1.1, we have

$$p(z|x_1) = \mathcal{N}(z|\mu_{z|1}, \Sigma_{z|1}) \tag{28.59}$$
$$\Sigma_{z|1}^{-1} = I + W_1^T \Psi_{11}^{-1} W_1 \tag{28.60}$$
$$\mu_{z|1} = \Sigma_{z|1} [W_1^T \Psi_{11}^{-1} (x_1 - \mu_1)] \tag{28.61}$$

where $\Psi_{11}$ is the top left block of $\Psi$.
We can modify the EM algorithm to fit the model in the presence of missing data in the obvious way.


28.3.1.6 Unidentifiability of the parameters 


The parameters of a FA model are unidentifiable. To see this, consider a model with weights $W$ and observation covariance $\Psi$. We have

$$\text{Cov}[x] = W \mathbb{E}[zz^T] W^T + \mathbb{E}[\epsilon\epsilon^T] = WW^T + \Psi \tag{28.62}$$

where $\epsilon \sim \mathcal{N}(0, \Psi)$ is the observation noise. Now consider a different model with weights $\tilde{W} = WR$, where $R$ is an arbitrary orthogonal rotation matrix, satisfying $RR^T = I$. This has the same likelihood, since

$$\text{Cov}[x] = \tilde{W} \mathbb{E}[zz^T] \tilde{W}^T + \mathbb{E}[\epsilon\epsilon^T] = WRR^T W^T + \Psi = WW^T + \Psi \tag{28.63}$$


Geometrically, multiplying W by an orthogonal matrix is like rotating z before generating x; but 
since z is drawn from an isotropic Gaussian, this makes no difference to the likelihood. Consequently, 
we cannot uniquely identify W, and therefore cannot uniquely identify the latent factors, either. 
This is called the “factor rotations problem” (see e.g., [Dar80]). 

To break this symmetry, several solutions can be used, as we discuss below. 

• Forcing W to have orthogonal columns. Perhaps the simplest solution to the identifiability problem is to force $W$ to have orthogonal columns. This is the approach adopted by PCA. The resulting posterior estimate will then be unique, up to permutation of the latent dimensions. (In PCA, this ordering ambiguity is resolved by sorting the dimensions in order of decreasing eigenvalues of $WW^T$.)

• Forcing W to be lower triangular. One way to resolve permutation unidentifiability, which is popular in the Bayesian community (e.g., [LW04]), is to ensure that the first visible feature is only generated by the first latent factor, the second visible feature is only generated by the first two latent factors, and so on. For example, if $L = 3$ and $D = 4$, the corresponding factor loading matrix is given by

$$W = \begin{pmatrix} w_{11} & 0 & 0 \\ w_{21} & w_{22} & 0 \\ w_{31} & w_{32} & w_{33} \\ w_{41} & w_{42} & w_{43} \end{pmatrix} \tag{28.64}$$

We also require that $w_{kk} > 0$ for $k = 1:L$. The total number of parameters in this constrained matrix is $D + DL - L(L-1)/2$, which is equal to the number of uniquely identifiable parameters in FA (excluding the mean).¹ The disadvantage of this method is that the first $L$ visible variables, known as the founder variables, affect the interpretation of the latent factors, and so must be chosen carefully.

• Sparsity promoting priors on the weights. Instead of pre-specifying which entries in $W$ are zero, we can encourage the entries to be zero, using $\ell_1$ regularization [ZHT06], ARD [Bis99; AB08], or spike-and-slab priors [Rat+09]. This is called sparse factor analysis. This does not necessarily ensure a unique MAP estimate, but it does encourage interpretable solutions.

• Choosing an informative rotation matrix. There are a variety of heuristic methods that try to find rotation matrices $R$ which can be used to modify $W$ (and hence the latent factors) so as to try to increase the interpretability, typically by encouraging them to be (approximately) sparse. One popular method is known as varimax [Kai58].

• Use of non-Gaussian priors for the latent factors. If we replace the prior on the latent variables, $p(z)$, with a non-Gaussian distribution, we can sometimes uniquely identify $W$, as well as the latent factors. See e.g., [KIKH20] for details.
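To illustrate the rotation idea, here is a short NumPy sketch of raw varimax (criterion weight $\gamma = 1$) using a commonly used SVD-based iteration; it is our own illustration, not the exact procedure of [Kai58]. The loading matrix is a hypothetical “simple structure” rotated by 30 degrees. The returned $R$ is orthogonal, so $WW^T$ and hence the likelihood are unchanged, while the varimax criterion (the summed variance of squared loadings) goes up.

```python
import numpy as np

def varimax_criterion(Lam):
    """Sum over factors of the variance of the squared loadings (raw varimax)."""
    return np.sum(np.var(Lam ** 2, axis=0))

def varimax(W, max_iter=100, tol=1e-10):
    D, L = W.shape
    R = np.eye(L)
    prev = 0.0
    for _ in range(max_iter):
        Lam = W @ R
        # gradient of the varimax criterion wrt R (up to a positive constant)
        G = W.T @ (Lam ** 3 - Lam @ np.diag(np.sum(Lam ** 2, axis=0)) / D)
        U, s, Vt = np.linalg.svd(G)
        R = U @ Vt                            # closest orthogonal matrix to G
        if s.sum() < prev * (1 + tol):
            break
        prev = s.sum()
    return W @ R, R

W_simple = np.array([[1.0, 0.0], [0.9, 0.0], [0.8, 0.0],
                     [0.0, 1.0], [0.0, 0.9], [0.0, 0.7]])
theta = np.pi / 6                             # hide the simple structure by a 30 degree rotation
Q = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
W = W_simple @ Q
W_rot, R = varimax(W)
```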


28.3.2 Probabilistic PCA 


In this section, we consider a special case of the factor analysis model in which $W$ has orthogonal columns and $\Psi = \sigma^2 I$, so $p(x) = \mathcal{N}(x|\mu, C)$ where $C = WW^T + \sigma^2 I$. This model is called probabilistic principal components analysis (PPCA) [TB99], or sensible PCA [Row97].

The advantage of PPCA over factor analysis is that the MLE has a closed form solution, as we 
show in Section 28.3.2.2. The advantage of PPCA over non-probabilistic PCA is that the model 
defines a proper likelihood function, which makes it easier to extend in various ways e.g., by creating 
mixtures of PPCA models (see Section 28.3.3). 


28.3.2.1 Derivation of the MLE


The log likelihood for PPCA is given by

$$\log p(X|\mu, W, \sigma^2) = -\frac{ND}{2}\log(2\pi) - \frac{N}{2}\log|C| - \frac{1}{2}\sum_{n=1}^N (x_n - \mu)^T C^{-1} (x_n - \mu) \tag{28.65}$$

The MLE for $\mu$ is $\bar{x}$. Plugging in gives

$$\log p(X|\mu, W, \sigma^2) = -\frac{N}{2}\left[D\log(2\pi) + \log|C| + \text{tr}(C^{-1}S)\right] \tag{28.66}$$

where $S = \frac{1}{N}\sum_{n=1}^N (x_n - \bar{x})(x_n - \bar{x})^T$ is the empirical covariance matrix.
In [TB99; Row97] they show that the maximum of this objective must satisfy

$$\hat{W} = U_L (\Lambda_L - \sigma^2 I)^{1/2} R \tag{28.67}$$


1. We get $D$ parameters for $\Psi$ and $DL$ for $W$, but we need to remove $L(L-1)/2$ degrees of freedom coming from $R$, since that is the dimensionality of the space of orthogonal matrices of size $L \times L$. To see this, note that there are $L - 1$ free parameters in $R$ in the first column (since the column vector must be normalized to unit length), there are $L - 2$ free parameters in the second column (which must be orthogonal to the first), and so on.
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28.3. FACTOR ANALYSIS 


where $U_L$ is a $D \times L$ matrix whose columns are given by the $L$ eigenvectors of $S$ with largest eigenvalues, $\Lambda_L$ is the $L \times L$ diagonal matrix of corresponding eigenvalues, and $R$ is an arbitrary $L \times L$ orthogonal matrix, which (WLOG) we can take to be $R = I$.

If we plug in the MLE for $W$, we find the covariance for the predictive distribution to be

$$\hat{C} = \hat{W}\hat{W}^T + \sigma^2 I = U_L (\Lambda_L - \sigma^2 I) U_L^T + \sigma^2 I \tag{28.68}$$

The MLE for the observation variance is

$$\hat{\sigma}^2 = \frac{1}{D - L}\sum_{i=L+1}^D \lambda_i \tag{28.69}$$


which is the average distortion associated with the discarded dimensions. If L = D, then the 
estimated noise is 0, since the model collapses to z = a. 
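The closed-form solution is easy to code. This NumPy sketch (hypothetical names, random data) implements Eqs. (28.67) and (28.69) with $R = I$ and checks two consequences: the fitted covariance keeps the top $L$ eigenvalues of $S$ and replaces the rest by $\hat{\sigma}^2$, and perturbing $\hat{W}$ cannot increase the log likelihood of Eq. (28.66).

```python
import numpy as np

def ppca_mle(X, L):
    """Closed-form PPCA MLE, Eqs. (28.67) and (28.69), taking R = I."""
    N, D = X.shape
    mu = X.mean(axis=0)
    S = (X - mu).T @ (X - mu) / N                 # empirical covariance
    lam, U = np.linalg.eigh(S)
    lam, U = lam[::-1], U[:, ::-1]                # decreasing eigenvalue order
    sigma2 = lam[L:].mean()                       # average of the D - L discarded eigenvalues
    W = U[:, :L] * np.sqrt(lam[:L] - sigma2)      # U_L (Lambda_L - sigma2 I)^{1/2}
    return mu, W, sigma2, S, lam

def ppca_loglik(S, N, W, sigma2):
    """Eq. (28.66), evaluated at mu = xbar."""
    D = S.shape[0]
    C = W @ W.T + sigma2 * np.eye(D)
    _, logdet = np.linalg.slogdet(C)
    return -0.5 * N * (D * np.log(2 * np.pi) + logdet + np.trace(np.linalg.solve(C, S)))

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 6)) @ rng.normal(size=(6, 6))
mu, W, sigma2, S, lam = ppca_mle(X, L=2)
C = W @ W.T + sigma2 * np.eye(6)
W_perturbed = W + 0.1 * rng.normal(size=W.shape)
```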


28.3.2.2 PCA is recovered in the noise-free limit 


In the noise-free limit, where $\sigma^2 = 0$, we see that the MLE (for $R = I$) is

$$\hat{W} = U_L \Lambda_L^{1/2} \tag{28.70}$$

so

$$\hat{C} = \hat{W}\hat{W}^T = U_L \Lambda_L^{1/2} \Lambda_L^{1/2} U_L^T = S_L \tag{28.71}$$

where $S_L$ is the rank $L$ approximation to $S$. This is the same as standard PCA.


28.3.2.3 Computing the posterior 


To use PPCA as an alternative to PCA, we need to compute the posterior mean $\mathbb{E}[z|x]$, which is the equivalent of the PCA encoder model. Using the factor analysis results from Section 28.3.1.2, we have

$$p(z|x) = \mathcal{N}(z|\sigma^{-2}\Sigma W^T(x - \mu), \Sigma) \tag{28.72}$$

where

$$\Sigma^{-1} = I + \sigma^{-2} W^T W = \frac{1}{\sigma^2}\underbrace{(\sigma^2 I + W^T W)}_{M} \tag{28.73}$$

Hence

$$p(z|x) = \mathcal{N}(z|M^{-1}W^T(x - \mu), \sigma^2 M^{-1}) \tag{28.74}$$

In the $\sigma^2 = 0$ limit, we have $M = W^T W$ and so

$$\mathbb{E}[z|x] = (W^T W)^{-1} W^T (x - \bar{x}) \tag{28.75}$$

This is the orthogonal projection of the data into the latent space, as in standard PCA.
28.3.2.4 Model fitting using EM


In Section 28.3.2.2, we showed how to fit the PCA model using an eigenvector method. We can also use EM, by leveraging the probabilistic formulation of PPCA in the zero noise limit, $\sigma^2 = 0$, as shown by [Row97].

In particular, let $\tilde{Z} = Z^T$ be an $L \times N$ matrix storing the posterior means (low-dimensional representations) along its columns. Similarly, let $\tilde{x}_n = x_n - \hat{\mu}$ be the centered examples stored along the columns of $\tilde{X}$. From Equation (28.75), when $\sigma^2 = 0$, we have

$$\tilde{Z} = (W^T W)^{-1} W^T \tilde{X} \tag{28.76}$$

This constitutes the E step. Notice that this is just an orthogonal projection of the data.
From Equation (28.57), the M step is given by

$$\hat{W} = \left[\sum_n \tilde{x}_n \mathbb{E}[z_n|\tilde{x}_n]^T\right]\left[\sum_n \mathbb{E}[z_n|\tilde{x}_n]\,\mathbb{E}[z_n|\tilde{x}_n]^T\right]^{-1} \tag{28.77}$$

where we exploited the fact that $\Sigma = \text{Cov}[z|\tilde{x}] = 0I$ when $\sigma^2 = 0$.
In summary, here is the entire algorithm:

$$\tilde{Z} = (W^T W)^{-1} W^T \tilde{X} \quad \text{(E step)} \tag{28.78}$$
$$W = \tilde{X}\tilde{Z}^T(\tilde{Z}\tilde{Z}^T)^{-1} \quad \text{(M step)} \tag{28.79}$$


It is worth comparing this expression to the MLE for multi-output linear regression, which has the form $W = (\sum_n y_n x_n^T)(\sum_n x_n x_n^T)^{-1}$. Thus we see that the M step is like linear regression where we replace the observed inputs by the expected values of the latent variables.

[TB99] showed that the only stable fixed point of the EM algorithm is the globally optimal solution. That is, the EM algorithm converges to a solution where $W$ spans the same linear subspace as that defined by the first $L$ eigenvectors of $S$. However, if we want $W$ to be orthogonal, and to contain the eigenvectors in descending order of eigenvalue, we have to orthogonalize the resulting matrix (which can be done quite cheaply). Alternatively, we can modify EM to give the principal basis directly [AO03].
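The two-line algorithm of Eqs. (28.78) and (28.79) can be sketched in NumPy as follows (hypothetical names; synthetic data with a clear eigenvalue gap). Since EM only recovers the principal subspace, the check compares projection matrices rather than $W$ itself.

```python
import numpy as np

def em_pca(X_tilde, L, n_iter=200, seed=0):
    """Zero-noise EM for PCA; X_tilde is D x N with centered examples as columns."""
    D, N = X_tilde.shape
    W = np.random.default_rng(seed).normal(size=(D, L))
    for _ in range(n_iter):
        Z = np.linalg.solve(W.T @ W, W.T @ X_tilde)        # E step, Eq. (28.78)
        W = np.linalg.solve(Z @ Z.T, Z @ X_tilde.T).T      # M step, Eq. (28.79)
    return W

rng = np.random.default_rng(0)
D, L = 5, 2
Q, _ = np.linalg.qr(rng.normal(size=(D, D)))               # random orthogonal basis
X = (rng.normal(size=(1000, D)) * np.sqrt([10.0, 5.0, 1.0, 0.5, 0.1])) @ Q.T
X_tilde = (X - X.mean(axis=0)).T                           # D x N, examples as columns
W = em_pca(X_tilde, L)
S = X_tilde @ X_tilde.T / X_tilde.shape[1]
_, U = np.linalg.eigh(S)
U_L = U[:, -L:]                                            # top-L eigenvectors (eigh sorts ascending)
P_em = W @ np.linalg.solve(W.T @ W, W.T)                   # projector onto span(W)
P_pca = U_L @ U_L.T
```

The columns of the returned $W$ are generally neither orthogonal nor ordered, which is why the text mentions a final orthogonalization step.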


28.3.3 Mixture of factor analysers 


The factor analysis model (Section 28.3.1) assumes the observed data can be modeled as arising from a linear mapping from a low-dimensional set of Gaussian factors. One way to relax this assumption is to assume the model is only locally linear, so the overall model becomes a (weighted) combination of FA models; this is called a mixture of factor analysers (MFA) [GH96b]. The overall model for the data is a mixture of linear manifolds, which can be used to approximate an overall curved manifold. Another way to think of this model is a mixture of Gaussians, where each mixture component has a covariance matrix which is diagonal plus low-rank.


28.3.3.1 Model definition


The generative story is as follows. First we sample a discrete latent indicator $m_n \in \{1, \ldots, K\}$ from discrete distribution $\text{Cat}(\cdot|\pi)$ to specify which subspace (cluster) we should use to generate the data.

Figure 28.7: Mixture of factor analyzers as a PGM.

If $m_n = k$, we sample $z_n$ from a Gaussian prior and pass it through the $W_k$ matrix, where $W_k$ maps from the $L$-dimensional subspace to the $D$-dimensional visible space.² Finally we add Gaussian observation noise sampled from $\mathcal{N}(\mu_k, \Psi)$. Thus the model is as follows:


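$$p(x_n|z_n, m_n = k, \theta) = \mathcal{N}(x_n|\mu_k + W_k z_n, \Psi) \tag{28.80}$$
$$p(z_n|\theta) = \mathcal{N}(z_n|0, I) \tag{28.81}$$
$$p(m_n|\theta) = \text{Cat}(m_n|\pi) \tag{28.82}$$

The corresponding distribution in the visible space is given by

$$p(x|\theta) = \sum_k p(m = k) \int p(z) p(x|z, m = k) dz \tag{28.83}$$
$$= \sum_k \pi_k \int \mathcal{N}(z|0, I) \mathcal{N}(x|W_k z + \mu_k, \Psi) dz \tag{28.84}$$
$$= \sum_k \pi_k \mathcal{N}(x|\mu_k, \Psi + W_k W_k^T) \tag{28.85}$$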


In the special case that $\Psi = \sigma^2 I$, we get a mixture of PPCA models. See Figure 28.8 for an example of the method applied to some 2d data.

We can think of this as a low-rank version of a mixture of Gaussians. In particular, this model needs $O(KLD)$ parameters instead of the $O(KD^2)$ parameters needed for a mixture of full covariance Gaussians. This can reduce overfitting.
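The generative story of Eqs. (28.80) to (28.82) can be simulated in a few lines. This NumPy sketch uses hypothetical parameters and a shared diagonal $\Psi$; by Eq. (28.85) the overall mean of $x$ is $\sum_k \pi_k \mu_k$, which the samples should reproduce.

```python
import numpy as np

def sample_mfa(rng, pi, mus, Ws, psi, n):
    """m_n ~ Cat(pi), z_n ~ N(0, I), x_n ~ N(mu_k + W_k z_n, diag(psi)) with k = m_n."""
    K, D, L = Ws.shape
    m = rng.choice(K, size=n, p=pi)                               # which subspace (cluster)
    z = rng.normal(size=(n, L))                                   # low-dimensional factors
    x = mus[m] + np.einsum('ndl,nl->nd', Ws[m], z) + rng.normal(size=(n, D)) * np.sqrt(psi)
    return x, m

rng = np.random.default_rng(0)
K, D, L = 3, 4, 1
pi = np.array([0.2, 0.5, 0.3])
mus = np.array([[0.0, 0.0, 0.0, 0.0], [5.0, 5.0, 0.0, 0.0], [0.0, -5.0, 5.0, 0.0]])
Ws = rng.normal(size=(K, D, L))
psi = np.full(D, 0.05)
X, m = sample_mfa(rng, pi, mus, Ws, psi, 100_000)
```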


28.3.3.2 Model fitting using EM 


We can fit this model using EM, extending the results of Section 28.3.1.4 (see [GH96b] for the derivation, and [ZY08] for a faster ECM version). In the E step, we compute the posterior responsibility of cluster $j$ for data point $i$ using

$$r_{ij} \triangleq p(m_i = j|x_i, \theta) \propto \pi_j \mathcal{N}(x_i|\mu_j, W_j W_j^T + \Psi) \tag{28.86}$$
2. If we allow $z_n$ to depend on $m_n$, we can let each subspace have a different dimensionality, as suggested in [KS15].
Figure 28.8: Mixture of PPCA models fit to a 2d dataset, using $L = 1$ latent dimension. (a) $K = 1$ mixture component. (b) $K = 10$ mixture components. Generated by mix_ppca_demo.ipynb.


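We also compute the following expected sufficient statistics, where we define $w_j = \mathbb{I}(m = j)$ and $B_j = W_j^T(\Psi + W_j W_j^T)^{-1}$:

$$\mathbb{E}[w_j z|x_i] = \mathbb{E}[w_j|x_i]\,\mathbb{E}[z|w_j, x_i] = r_{ij} B_j (x_i - \mu_j) \tag{28.87}$$
$$\mathbb{E}[w_j zz^T|x_i] = \mathbb{E}[w_j|x_i]\,\mathbb{E}[zz^T|w_j, x_i] = r_{ij}\left(I - B_j W_j + B_j (x_i - \mu_j)(x_i - \mu_j)^T B_j^T\right) \tag{28.88}$$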


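In the M step, we compute the following parameter update for the augmented factor loading matrix:

$$[W_j^{\text{new}}\ \mu_j^{\text{new}}] \triangleq \tilde{W}_j^{\text{new}} = \left(\sum_i r_{ij} x_i \mathbb{E}[\tilde{z}|x_i, w_j]^T\right)\left(\sum_i r_{ij} \mathbb{E}[\tilde{z}\tilde{z}^T|x_i, w_j]\right)^{-1} \tag{28.89}$$

where $\tilde{z} = [z; 1]$,

$$\mathbb{E}[\tilde{z}|x_i, w_j] = \begin{pmatrix} \mathbb{E}[z|x_i, w_j] \\ 1 \end{pmatrix} \tag{28.90}$$
$$\mathbb{E}[\tilde{z}\tilde{z}^T|x_i, w_j] = \begin{pmatrix} \mathbb{E}[zz^T|x_i, w_j] & \mathbb{E}[z|x_i, w_j] \\ \mathbb{E}[z|x_i, w_j]^T & 1 \end{pmatrix} \tag{28.91}$$

The new covariance matrix is given by

$$\Psi^{\text{new}} = \frac{1}{N}\text{diag}\left(\sum_{ij} r_{ij}\left(x_i - \tilde{W}_j^{\text{new}}\,\mathbb{E}[\tilde{z}|x_i, w_j]\right)x_i^T\right) \tag{28.92}$$

And the new mixing weights are given by

$$\pi_j^{\text{new}} = \frac{1}{N}\sum_{i=1}^N r_{ij} \tag{28.93}$$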
Figure 28.9: Illustration of estimating the effective dimensionalities in a mixture of factor analysers using variational Bayes EM with an ARD prior. Black are negative values, white are positive, gray is 0. The blank columns have been forced to 0 via the ARD mechanism, reducing the effective dimensionality. The data was generated from 6 clusters with intrinsic dimensionalities of 7, 4, 3, 2, 2, 1, which the method has successfully estimated. From Figure 4.4 of [Bea03]. Used with kind permission of Matt Beal.


28.3.3.3 Model fitting using SGD 


We can also fit mixture models using SGD, as shown in [RW18]. This idea can be combined with an inference network (see Section 10.3.6) to efficiently approximate the posterior over the latent variables. [Zon+18] use this approach to jointly learn a GMM applied to a deep autoencoder to provide a nonlinear extension of MFA; they show good results on anomaly detection.


28.3.3.4 Model selection 


To choose the number of mixture components $K$, and the number of latent dimensions $L$, we can use discrete search combined with objectives such as the marginal likelihood or validation likelihood. However, we can also use numerical optimization methods to optimize $L$, which can be faster. We initially assume that $K$ is known. To estimate $L$, we set the model to its maximal size, and then use a technique called automatic relevance determination or ARD to automatically prune out irrelevant weights (see Section 15.2.7). This can be implemented using variational Bayes EM (Section 10.2.5); for details, see [Bis99; GB00].

Figure 28.9 illustrates this approach applied to a mixture of FA models fit to a small synthetic dataset. The figures visualize the weight matrices for each cluster, using Hinton diagrams, where the size of the square is proportional to the magnitude of the entry in the matrix. We see that many of them are sparse. Figure 28.10 shows that the degree of sparsity depends on the amount of training data, in accord with the Bayesian Occam’s razor. In particular, when the sample size is small, the method automatically prefers simpler models, but as the sample size gets sufficiently large, the method converges on the “correct” solution, which is one with 6 subspaces of dimensionality 1, 2, 2, 3, 4 and 7.

Figure 28.10: We show the estimated number of clusters, and their estimated dimensionalities, as a function of sample size. The ARD algorithm found two different solutions when N = 8. Note that more clusters, with larger effective dimensionalities, are discovered as the sample size increases. From Table 4.1 of [Bea03]. Used with kind permission of Matt Beal.

Although the ARD method can estimate the number of latent dimensions $L$, it still needs to perform discrete search over the number of mixture components $K$. This is done using “birth” and “death” moves [GB00]. An alternative approach is to perform stochastic sampling in the space of models. Traditional approaches, such as [LW04], are based on reversible jump MCMC, and also use birth and death moves. However, this can be slow and difficult to implement. More recent approaches use non-parametric priors, combined with Gibbs sampling; see e.g., [PC09].

28.3.3.5 MixFA for image generation


In this section, we use the MFA model as a generative model for images, following [RW18]. This is equivalent to using a mixture of Gaussians, where each mixture component has a low-rank covariance matrix. Surprisingly, the results are competitive with deep generative models such as those in Part IV, despite the fact that no neural networks are used in the model.

In [RW18], they fit the MFA model to the CelebA dataset, which is a dataset of faces of celebrities (movie stars). They use $K = 300$ components, each of latent dimension $L = 10$; the observed data has dimension $D = 64 \times 64 \times 3 = 12{,}288$. They fit the model using SGD, using the methods from Section 28.3.1.3 to efficiently compute the log likelihood, despite the high dimensionality. The $\mu_k$ parameters are initialized using K-means clustering, and the $W_k$ parameters are initialized using factor analysis for each component separately. Then the model is fine-tuned end-to-end.

Figure 28.11 shows some images generated from the fitted model. The results are surprisingly good for such a simple locally linear model. The reason the method works is similar to the discussion in Section 28.2.4.1: essentially the $W_k$ matrix learns a set of $L$-dimensional basis functions for the subset of face images that get mapped to cluster $k$. See Figure 28.12 for an illustration.

Draft of “Probabilistic Machine Learning: Advanced Topics”. August 15, 2022

Figure 28.11: Random samples from the MixFA model fit to CelebA. Generated by mix_ppca_celebA.ipynb. Adapted from Figure 4 of [RW18]. Used with kind permission of Yair Weiss.

Figure 28.12: (a) Visualization of the parameters learned by the MFA model. The top row shows the mean $\mu_k$ and noise variance $\Psi_k$, reshaped from 12,288-dimensional vectors to $64 \times 64 \times 3$ images, for two mixture components $k$. The next 5 rows show the first 5 (of 10) basis functions (columns of $W_k$) as images. On row $i$, left column, we show $\mu_k - W_k[:, i]$; in the middle, we show $0.5 + W_k[:, i]$, and on the right we show $\mu_k + W_k[:, i]$. (b) Images generated by computing $\mu_k + z_1 W_k[:, i] + z_2 W_k[:, j]$, for some component $k$ and dimensions $i, j$, where $(z_1, z_2)$ are drawn from the grid $[-1:1, -1:1]$, so the central image is just $\mu_k$. From Figure 6 of [RW18]. Used with kind permission of Yair Weiss.

There are several advantages to this model compared to VAEs and GANs. First, [RW18] showed that this MixFA model captures more of the modes of the data distribution than more sophisticated generative models, such as VAEs (Section 21.2) and GANs (Chapter 26). Second, we can compute the exact likelihood $p(x)$, so we can detect outliers or unusual images. This is illustrated in Figure 28.13.

Figure 28.13: Samples from the 100 CelebA images with lowest likelihood under the MFA model. Generated by mix_ppca_celebA.ipynb. Adapted from Figure 7a of [RW18]. Used with kind permission of Yair Weiss.

Figure 28.14: Illustration of image imputation using an MFA. Left column shows 4 original images. Subsequent pairs of columns show an occluded input, and a predicted output. Generated by mix_ppca_celebA.ipynb. Adapted from Figure 7b of [RW18]. Used with kind permission of Yair Weiss.

Figure 28.15: Gaussian latent factor models for paired data. (a) Supervised PCA. (b) Partial least squares.

Third, we can perform image imputation from partially observed images given arbitrary missingness patterns. To see this, let us partition $x = (x_1, x_2)$, where $x_1$ (of size $D_1$) is observed and $x_2$ (of size $D_2 = D - D_1$) is missing. We can compute the most probable cluster using

$$k^* = \arg\max_k p(c = k) p(x_1|c = k) \tag{28.94}$$

where

$$\log p(x_1|\mu_k, C_k) = -\frac{1}{2}\left[D_1\log(2\pi) + \log\det(C_{k,11}) + \tilde{x}_1^T C_{k,11}^{-1}\tilde{x}_1\right] \tag{28.95}$$

where $C_{k,11}$ is the top left $D_1 \times D_1$ block of $W_k W_k^T + \Psi_k$, and $\tilde{x}_1 = x_1 - \mu_k[1:D_1]$. Once we know which discrete mixture component to use, we can compute the Gaussian posterior $p(z|x_1, k^*)$ using Equation (28.59). Let $\hat{z} = \mathbb{E}[z|x_1, k^*]$. Given this, we can compute the predicted output for the full image:

$$\hat{x} = W_{k^*}\hat{z} + \mu_{k^*} \tag{28.96}$$

We then use the estimate $x' = [x_1, \hat{x}_2]$, so the observed pixels are not changed. This is an example of image imputation, and is illustrated in Figure 28.14. Note that we can condition on an arbitrary subset of pixels, and fill in the rest, whereas some other models (e.g., autoregressive models) can only predict the bottom right given the top left (since they assume a generative model which works in raster-scan order).
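The imputation recipe of Eqs. (28.94) to (28.96) can be sketched in NumPy as below. This is an illustration under stated assumptions, not the code used for Figure 28.14: the observed pixels are taken to be the first $D_1$ entries (any index set would work by selecting rows instead), each component has its own diagonal $\Psi_k$, and the toy parameters and function names are hypothetical.

```python
import numpy as np

def gauss_logpdf(x, mu, C):
    _, logdet = np.linalg.slogdet(C)
    r = x - mu
    return -0.5 * (x.shape[0] * np.log(2 * np.pi) + logdet + r @ np.linalg.solve(C, r))

def mfa_impute(x1, pi, mus, Ws, psis):
    """Fill in x2 given the first D1 = len(x1) entries, following Eqs. (28.94)-(28.96)."""
    D1 = x1.shape[0]
    scores = [np.log(pi[k]) + gauss_logpdf(x1, mus[k][:D1],
              (Ws[k] @ Ws[k].T + np.diag(psis[k]))[:D1, :D1])     # C_{k,11}, Eq. (28.95)
              for k in range(len(pi))]
    k = int(np.argmax(scores))                                    # k*, Eq. (28.94)
    W1, psi1 = Ws[k][:D1], psis[k][:D1]
    prec = np.eye(W1.shape[1]) + (W1.T / psi1) @ W1               # Eq. (28.60)
    z_hat = np.linalg.solve(prec, (W1.T / psi1) @ (x1 - mus[k][:D1]))   # Eq. (28.61)
    x_hat = Ws[k] @ z_hat + mus[k]                                # Eq. (28.96)
    return np.concatenate([x1, x_hat[D1:]]), k                    # observed pixels kept as-is

rng = np.random.default_rng(0)
K, D, L, D1 = 2, 6, 2, 4
pi = np.array([0.5, 0.5])
mus = np.array([np.zeros(D), np.full(D, 10.0)])
Ws = rng.normal(size=(K, D, L))
psis = np.full((K, D), 1e-6)
x_true = mus[1] + Ws[1] @ np.array([0.5, -1.0])                   # a noise-free "image" from component 1
x_filled, k_star = mfa_impute(x_true[:D1], pi, mus, Ws, psis)
```

With almost no observation noise and at least $L$ observed entries, the posterior mean recovers the latent code, so the missing entries are filled in almost exactly.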


28.3.4 Factor analysis models for paired data 


In this section, we discuss linear-Gaussian factor analysis models when we have two kinds of observed variables, $x \in \mathbb{R}^{D_x}$ and $y \in \mathbb{R}^{D_y}$, which are paired. These often correspond to different sensors or modalities (e.g., images and sound). We follow the presentation of [Vir10].


28.3.4.1 Supervised PCA 


If we have two observed signals, we can model the joint p(a,y) using a shared low-dimensional 
representation using the following linear Gaussian model: 


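$$p(z_n) = \mathcal{N}(z_n | 0, I_L) \tag{28.97}$$
$$p(x_n | z_n, \theta) = \mathcal{N}(x_n | W_x z_n, \sigma_x^2 I_{D_x}) \tag{28.98}$$
$$p(y_n | z_n, \theta) = \mathcal{N}(y_n | W_y z_n, \sigma_y^2 I_{D_y}) \tag{28.99}$$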


This is illustrated as a graphical model in Figure 28.15a. The intuition is that $z_n$ is a shared latent subspace that captures features that $x_n$ and $y_n$ have in common. The variance terms $\sigma_x$ and $\sigma_y$ control how much emphasis the model puts on the two different signals.

The above model is called supervised PCA [Yu+06]. If we put a prior on the parameters $\theta = (W_x, W_y, \sigma_x, \sigma_y)$, it is called Bayesian factor regression [Wes03].
We can marginalize out $z_n$ to get $p(y_n | x_n, \theta)$. If $y_n$ is a scalar, this becomes


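$$p(y_n | x_n, \theta) = \mathcal{N}(y_n | x_n^\top v, \; w_y^\top C w_y + \sigma_y^2) \tag{28.100}$$
$$C = (I + \sigma_x^{-2} W_x^\top W_x)^{-1} \tag{28.101}$$
$$v = \sigma_x^{-2} W_x C w_y \tag{28.102}$$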
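As a sanity check on Equations (28.100)-(28.102), the sketch below computes $v$ and the predictive variance for scalar $y$, with $w_y$ the single row of $W_y$; the helper name `spca_predictive` is hypothetical. The result should agree with conditioning the joint Gaussian of $(x_n, y_n)$ directly, since the two forms are related by the matrix inversion lemma.

```python
import numpy as np

def spca_predictive(Wx, wy, sigma_x, sigma_y):
    """Return (v, var) such that p(y | x) = N(y | x^T v, var), per Eqs. (28.100)-(28.102).

    Wx: (Dx, L) loadings for x; wy: (L,) loadings for the scalar y.
    """
    L = Wx.shape[1]
    C = np.linalg.inv(np.eye(L) + Wx.T @ Wx / sigma_x**2)  # posterior covariance of z given x
    v = Wx @ C @ wy / sigma_x**2                            # regression weights on x
    var = wy @ C @ wy + sigma_y**2                          # predictive variance
    return v, var
```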


To apply this to the classification setting, we can replace the Gaussian p(y|z) with a logistic 
regression model: 


$$p(y_n | z_n, \theta) = \mathrm{Ber}(y_n | \sigma(w_y^\top z_n)) \tag{28.103}$$


In this case, we can no longer compute the marginal posterior predictive $p(y_n | x_n, \theta)$ in closed form, but we can use techniques similar to exponential family PCA (see [Guo09] for details).

The above model is completely symmetric in $x$ and $y$. If our goal is to predict $y$ from $x$ via the latent bottleneck $z$, then we might want to upweight the likelihood term for $y$, as proposed in [Ris+08]. This gives


$$p(X, Y, Z | \theta) = p(Y | Z, W_y)\, p(X | Z, W_x)^{\alpha}\, p(Z) \tag{28.104}$$

where $\alpha < 1$ controls the relative importance of modeling the two sources. The value of $\alpha$ can be chosen by cross-validation.

28.3.4.2 Partial least squares 


We now consider an asymmetric or more “discriminative” form of supervised PCA. The key idea is to allow some of the (co)variance in the input features to be explained by its own subspace, $z_i^x$, and to let the rest of the subspace, $z_i^s$, be shared between input and output. The model has the form


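$$p(z_i) = \mathcal{N}(z_i^s | 0, I_{L_s})\, \mathcal{N}(z_i^x | 0, I_{L_x}) \tag{28.105}$$
$$p(y_i | z_i) = \mathcal{N}(y_i | W_y z_i^s + \mu_y, \sigma^2 I_{D_y}) \tag{28.106}$$
$$p(x_i | z_i) = \mathcal{N}(x_i | W_x z_i^s + B_x z_i^x + \mu_x, \sigma^2 I_{D_x}) \tag{28.107}$$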

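See Figure 28.15b. The corresponding induced distribution on the visible variables has the form

$$p(v_i | \theta) = \int \mathcal{N}(v_i | W z_i + \mu, \sigma^2 I)\, \mathcal{N}(z_i | 0, I)\, dz_i = \mathcal{N}(v_i | \mu, W W^\top + \sigma^2 I) \tag{28.108}$$

where $v_i = (y_i; x_i)$, $\mu = (\mu_y; \mu_x)$ and

$$W = \begin{pmatrix} W_y & 0 \\ W_x & B_x \end{pmatrix} \tag{28.109}$$

$$W W^\top = \begin{pmatrix} W_y W_y^\top & W_y W_x^\top \\ W_x W_y^\top & W_x W_x^\top + B_x B_x^\top \end{pmatrix} \tag{28.110}$$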


We should choose $L_x$ large enough so that the shared subspace does not capture covariate-specific variation.
MLE in this model is equivalent to the technique of partial least squares (PLS) [Gus01; Nou+02; Sun+09]. This model can also be generalized to discrete data using the exponential family [Vir10].
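To make the block structure of Equations (28.109)-(28.110) concrete, here is a small sketch (hypothetical helper name `pls_marginal_cov`) that assembles $W$ for the PLS model and returns the implied covariance $W W^\top + \sigma^2 I$ of $v_i = (y_i; x_i)$. Because $y_i$ loads only on the shared factors, the cross-covariance block is $W_y W_x^\top$, while the private loadings $B_x$ enter only the $x$ block.

```python
import numpy as np

def pls_marginal_cov(Wy, Wx, Bx, sigma):
    """Covariance of v = (y; x) under the PLS model, Eqs. (28.108)-(28.110).

    Wy: (Dy, Ls) shared loadings for y; Wx: (Dx, Ls) shared loadings for x;
    Bx: (Dx, Lx) private loadings for x.
    """
    Dy = Wy.shape[0]
    Lx = Bx.shape[1]
    # y depends only on the shared factors, so its private block is zero
    W = np.block([[Wy, np.zeros((Dy, Lx))],
                  [Wx, Bx]])
    return W @ W.T + sigma**2 * np.eye(W.shape[0])
```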
28.3. FACTOR ANALYSIS


Figure 28.16: Canonical correlation analysis as a PGM. 


28.3.4.3 Canonical correlation analysis 


We now consider a symmetric unsupervised version of PLS, in which we allow each view to have its own “private” subspace, but there is also a shared subspace. If we have two observed variables, $x_i$ and $y_i$, then we have three latent variables: $z_i^s \in \mathbb{R}^{L_s}$, which is shared, and $z_i^x \in \mathbb{R}^{L_x}$ and $z_i^y \in \mathbb{R}^{L_y}$, which are private. We can write the model as follows [BJ05]:


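$$p(z_i) = \mathcal{N}(z_i^s | 0, I_{L_s})\, \mathcal{N}(z_i^x | 0, I_{L_x})\, \mathcal{N}(z_i^y | 0, I_{L_y}) \tag{28.111}$$
$$p(x_i | z_i) = \mathcal{N}(x_i | B_x z_i^x + W_x z_i^s + \mu_x, \sigma^2 I_{D_x}) \tag{28.112}$$
$$p(y_i | z_i) = \mathcal{N}(y_i | B_y z_i^y + W_y z_i^s + \mu_y, \sigma^2 I_{D_y}) \tag{28.113}$$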


See Figure 28.16. The corresponding observed joint distribution has the form


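$$p(v_i | \theta) = \int \mathcal{N}(v_i | W z_i + \mu, \sigma^2 I)\, \mathcal{N}(z_i | 0, I)\, dz_i = \mathcal{N}(v_i | \mu, W W^\top + \sigma^2 I_D) \tag{28.114}$$

where

$$W = \begin{pmatrix} W_x & B_x & 0 \\ W_y & 0 & B_y \end{pmatrix} \tag{28.115}$$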


[BJ05] showed that MLE for this model is equivalent to a classical statistical method known as canonical correlation analysis or CCA [Hot36]. However, the PGM perspective allows us to easily generalize to multiple kinds of observations (this is known as generalized CCA [Hor61]), to nonlinear models (this is known as deep CCA [WLL16; SNM16]), or to exponential family likelihoods (exponential family CCA [KVK10]). See [Uur+17] for further discussion of CCA and its extensions, and Section 32.2.2.2 for more details.
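For reference, the classical CCA solution [Hot36] can be computed by whitening each view and taking an SVD of the whitened cross-covariance. The sketch below is that classical procedure, not an MLE fit of the PGM above (which [BJ05] relate to it); the function name `cca` and the small ridge term `reg` are our own assumptions.

```python
import numpy as np

def cca(X, Y, reg=1e-8):
    """Classical CCA via SVD of the whitened cross-covariance.

    X: (N, Dx), Y: (N, Dy). Returns (rho, A, B): canonical correlations sorted in
    decreasing order, and canonical directions as the columns of A and B.
    """
    Xc = X - X.mean(axis=0)
    Yc = Y - Y.mean(axis=0)
    N = X.shape[0]
    Sxx = Xc.T @ Xc / N + reg * np.eye(X.shape[1])
    Syy = Yc.T @ Yc / N + reg * np.eye(Y.shape[1])
    Sxy = Xc.T @ Yc / N
    # Whitening transforms Sxx^{-1/2} and Syy^{-1/2} via eigendecompositions
    ex, Vx = np.linalg.eigh(Sxx)
    ey, Vy = np.linalg.eigh(Syy)
    Kx = Vx @ np.diag(ex ** -0.5) @ Vx.T
    Ky = Vy @ np.diag(ey ** -0.5) @ Vy.T
    U, rho, Vt = np.linalg.svd(Kx @ Sxy @ Ky)
    return rho, Kx @ U, Ky @ Vt.T
```

If one view is an invertible linear transform of the other, every canonical correlation is (numerically) 1; for independent views they are close to 0.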


28.3.5 Factor analysis with exponential family likelihoods 


So far we have assumed the observed data is real-valued, so $x_n \in \mathbb{R}^D$. If we want to model other kinds of data (e.g., binary or categorical), we can simply replace the Gaussian output distribution


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 




Figure 28.17: Exponential family PCA model as a DPGM. 


with a suitable member of the exponential family, where the natural parameters are given by a linear function of $z_n$. That is, we use


$$p(x_n | z_n) = \exp\left(\mathcal{T}(x_n)^\top \theta_n + h(x_n) - g(\theta_n)\right) \tag{28.117}$$


where the $N \times D$ matrix of natural parameters is assumed to be given by the low-rank decomposition $\Theta = Z W$, where $Z$ is $N \times L$ and $W$ is $L \times D$. The resulting model is called exponential family factor analysis.

Unlike the linear-Gaussian FA, we cannot compute the exact posterior $p(z_n | x_n, W)$ due to the lack of conjugacy between the expfam likelihood and the Gaussian prior. Furthermore, we cannot compute the exact marginal likelihood either, which prevents us from finding the optimal MLE.


[CDS02] proposed a coordinate ascent method for a deterministic variant of this model, known as exponential family PCA. This alternates between computing a point estimate of $z_n$ and $W$. This can be regarded as a degenerate version of variational EM, where the E step uses a delta function posterior for $z_n$. [GS08] present an improved algorithm that finds the global optimum, and [Ude+16] presents an extension called generalized low rank models, which covers many different kinds of loss functions.
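Below is a minimal sketch of the deterministic exponential family PCA idea for binary data: we alternate point-estimate updates of $Z$ and $W$ under a Bernoulli likelihood with natural parameters $\Theta = Z W$. For simplicity we swap in plain gradient ascent steps on each block rather than the optimizer of [CDS02]; the step size, iteration count, and function names are illustrative assumptions.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def bernoulli_nll(X, Z, W):
    """Negative log-likelihood of binary X under natural parameters Theta = Z W."""
    Theta = Z @ W
    # Bernoulli in natural form: log p(x | theta) = x * theta - log(1 + e^theta)
    return -np.sum(X * Theta - np.logaddexp(0.0, Theta))

def expfam_pca_bernoulli(X, L, steps=100, lr=0.05, seed=0):
    """Point-estimate Z (N x L) and W (L x D) by alternating gradient ascent steps."""
    rng = np.random.default_rng(seed)
    N, D = X.shape
    Z = 0.1 * rng.normal(size=(N, L))
    W = 0.1 * rng.normal(size=(L, D))
    history = [bernoulli_nll(X, Z, W)]
    for _ in range(steps):
        R = X - sigmoid(Z @ W)   # gradient of the log-likelihood wrt Theta
        Z = Z + lr * R @ W.T     # ascent step on Z with W fixed
        R = X - sigmoid(Z @ W)
        W = W + lr * Z.T @ R     # ascent step on W with Z fixed
        history.append(bernoulli_nll(X, Z, W))
    return Z, W, history
```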


However, it is often preferable to use a probabilistic version of the model, rather than computing 
point estimates of the latent factors. In this case, we must represent the posterior using a non- 
degenerate distribution to avoid overfitting, since the number of latent variables is proportional to 
the number of data cases [WCS08]. Fortunately, we can use a non-degenerate posterior, such as a 
Gaussian, by optimizing the variational lower bound. We give some examples of this below. 


28.3.5.1 Example: binary PCA 


Consider a factored Bernoulli likelihood:


$$p(x | z) = \prod_d \mathrm{Ber}(x_d | \sigma(w_d^\top z)) \tag{28.118}$$


Suppose we observe $N = 150$ bit vectors of length $D = 16$. Each example is generated by choosing one of three binary prototype vectors, and then by flipping bits at random. See Figure 28.18(a)


[Figure 28.18 panels: (a) Noisy Binary Data; (b) Latent Embedding; (c) Posterior Predictive; (d) Reconstruction]


Figure 28.18: (a) 150 synthetic 16 dimensional bit vectors. (b) The 2d embedding learned by binary PCA, fit 
using variational EM. We have color coded points by the identity of the true “prototype” that generated them. 
(c) Predicted probability of being on. (d) Thresholded predictions. Generated by binary_fa_demo.ipynb.


for the data. We can fit this using the variational EM algorithm (see [Tip98] for details). We use $L = 2$ latent dimensions to allow us to visualize the latent space. In Figure 28.18(b), we plot $\mathbb{E}[z_n | x_n, \hat{W}]$. We see that the projected points group into three distinct clusters, as is to be expected.


In Figure 28.18(c), we plot the reconstructed version of the data, which is computed as follows: 


$$p(\hat{x}_{nd} = 1 | x_n) = \int p(z_n | x_n)\, p(\hat{x}_{nd} = 1 | z_n)\, dz_n \tag{28.119}$$


If we threshold these probabilities at 0.5 (corresponding to a MAP estimate), we get the “denoised” 
version of the data in Figure 28.18(d). 
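Equation (28.119) can be approximated by Monte Carlo once we have a Gaussian approximation $q(z_n) = \mathcal{N}(m_n, S_n)$ to the posterior, such as the one produced by variational EM. The sketch below assumes such $m_n$ and $S_n$ are given (the fitting step is not shown), uses a hypothetical helper name, and may differ from how [Tip98] approximates this integral.

```python
import numpy as np

def binary_pca_reconstruct(m, S, W, num_samples=2000, seed=0):
    """Monte Carlo estimate of Eq. (28.119) for one data vector.

    m: (L,) mean and S: (L, L) covariance of an (assumed) Gaussian approximation
    to p(z_n | x_n); W: (D, L) with rows w_d. Returns (probs, x_denoised).
    """
    rng = np.random.default_rng(seed)
    zs = rng.multivariate_normal(m, S, size=num_samples)    # samples from q(z_n)
    probs = (1.0 / (1.0 + np.exp(-zs @ W.T))).mean(axis=0)  # average of sigmoid(w_d^T z)
    return probs, (probs > 0.5).astype(int)                 # threshold at 0.5
```

When the posterior is nearly a point mass, the probabilities reduce to $\sigma(w_d^\top m_n)$; a wider posterior pulls them toward 0.5.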


28.3.5.2 Example: categorical PCA 


We can generalize the model in Section 28.3.5.1 to handle categorical data by using the following 
likelihood: 


$$p(x | z) = \prod_d \mathrm{Cat}(x_d | \mathrm{softmax}(W_d z)) \tag{28.120}$$


We call this categorical PCA (CatPCA). A variational EM algorithm for fitting this is described 
in [Kha+10]. 
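The CatPCA likelihood of Equation (28.120) is easy to evaluate for a single data vector; the sketch below does so with a numerically stable log-softmax. Storing $W$ as a $D \times V \times L$ tensor and using 0-based category indices are assumptions of this illustration.

```python
import numpy as np

def catpca_loglik(x, z, W):
    """log p(x | z) = sum_d log Cat(x_d | softmax(W_d z)), Eq. (28.120).

    x: (D,) integer categories in {0, ..., V-1} (0-based here);
    z: (L,) latent vector; W: (D, V, L) so that W[d] @ z are the logits for x_d.
    """
    logits = W @ z                                       # (D, V)
    logits = logits - logits.max(axis=1, keepdims=True)  # stabilize the softmax
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return log_probs[np.arange(len(x)), x].sum()
```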
28.3.6 Factor analysis with DNN likelihoods (VAEs)


The FA model assumes the observed data can be modeled as arising from a linear mapping from a low-dimensional set of Gaussian factors. One way to relax this assumption is to let the mapping from $z$ to $x$ be a nonlinear model, such as a neural network. That is, the likelihood becomes


$$p(x | z) = \mathcal{N}(x | f(z; \theta), \sigma^2 I) \tag{28.121}$$


We call this “nonlinear factor analysis”. (We can of course replace the Gaussian likelihood with 
other distributions, such as categorical, in which case we get nonlinear exponential family factor 
analysis.) Unfortunately we can no longer compute the posterior or the MLE exactly, so we need to 
use approximate methods. In Chapter 21, we discuss variational autoencoders, which fit this model using amortized variational inference. However, it is also possible to fit the same model using other
inference methods, such as MCMC (see e.g., [Hof17]). 


28.3.7 Factor analysis with GP likelihoods (GP-LVM) 


In this section we discuss a nonlinear version of factor analysis in which we replace the linear decoder $f(z) = W z$ used in the likelihood $p(y | z) = \mathcal{N}(y | f(z), \sigma^2 I)$ with a nonlinear function, represented by
a Gaussian process (Chapter 18), one per output dimension. This is known as a GP-LVM, which 
stands for “Gaussian process latent variable model” [Law05]. (Note that we switch notation a bit 
from standard FA and define the observed output variable by y, to be consistent with standard 
supervised GP notation; the inputs to the GP will be latent variables z.) 

To explain the method in more detail, we start with PPCA (Section 28.3.2). Recall that the PPCA 
model is as follows: 


$$p(z_i) = \mathcal{N}(z_i | 0, I) \tag{28.122}$$
$$p(y_i | z_i, \theta) = \mathcal{N}(y_i | W z_i, \sigma^2 I) \tag{28.123}$$


We can fit this model by maximum likelihood, by integrating out the $z_i$ and maximizing with respect to $W$ (and $\sigma^2$). The objective is given by


$$p(Y | W, \sigma^2) = (2\pi)^{-DN/2} |C|^{-N/2} \exp\left(-\frac{1}{2} \mathrm{tr}(C^{-1} Y^\top Y)\right) \tag{28.124}$$


where $C = W W^\top + \sigma^2 I$. As we showed in Section 28.3.2, the MLE for $W$ can be computed in terms of the eigenvectors of $Y^\top Y$.


Now we consider the dual problem, whereby we maximize with respect to $Z$ and integrate out $W$. We will use a prior of the form $p(W) = \prod_j \mathcal{N}(w_j | 0, I)$. The corresponding likelihood becomes


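$$p(Y | Z, \sigma^2) = \prod_{d=1}^{D} \mathcal{N}(Y_{:,d} | 0, Z Z^\top + \sigma^2 I) \tag{28.125}$$
$$= (2\pi)^{-DN/2} |K_z|^{-D/2} \exp\left(-\frac{1}{2} \mathrm{tr}(K_z^{-1} Y Y^\top)\right) \tag{28.126}$$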


where $K_z = K + \sigma^2 I$ and $K = Z Z^\top$. The MLE for $Z$ can be computed in terms of the eigenvectors of $Y Y^\top$, and gives the same results as PPCA (see [Law05] for the details).
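Equation (28.126) can be evaluated directly. It equals the sum of $D$ independent zero-mean Gaussian log-densities, one per column of $Y$, all sharing the covariance $K_z$. The sketch below (hypothetical function name, NumPy) computes it in the matrix form; $Y$ is assumed centered.

```python
import numpy as np

def dual_ppca_loglik(Z, Y, sigma2):
    """log p(Y | Z, sigma^2) from Eq. (28.126): D independent GPs with a linear kernel.

    Z: (N, L) latent inputs; Y: (N, D) observations (assumed centered).
    """
    N, D = Y.shape
    Kz = Z @ Z.T + sigma2 * np.eye(N)
    _, logdet = np.linalg.slogdet(Kz)
    trace_term = np.trace(np.linalg.solve(Kz, Y @ Y.T))  # tr(Kz^{-1} Y Y^T)
    return -0.5 * (D * N * np.log(2 * np.pi) + D * logdet + trace_term)
```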
Figure 28.19: Illustration of a 2d embedding of human motion-capture data using a GP-LVM. We show two poses and their corresponding embeddings. Generated by gplvm_mocap.ipynb. Used with kind permission of Aditya Ravuri.


To understand what this process is doing, consider modeling the prior on $f: \mathcal{Z} \to \mathcal{Y}$ with a GP with a linear kernel:


$$\mathcal{K}(z_i, z_j) = z_i^\top z_j + \sigma^2 \delta_{ij} \tag{28.127}$$


The corresponding covariance matrix has the form $K = Z Z^\top + \sigma^2 I$. Thus Equation (28.126) is
equivalent to the likelihood of a product of independent GPs. Just as factor analysis is like linear 
regression with unknown inputs, so GP-LVM is like GP regression with unknown inputs. The goal 
is then to compute a point estimate of these unknown inputs, i.e., Z. (We can also use Bayesian 
inference.) 

The advantage of the dual formulation is that we can use a more general kernel for K instead of 
the linear kernel. That is, we can set $K_{ij} = \mathcal{K}(z_i, z_j)$ for any Mercer kernel. The MLE for $Z$ is no longer available via eigenvalue methods, but can be computed using gradient-based optimization.
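With a nonlinear kernel, the objective keeps the form of Equation (28.126) with $K_z$ replaced by the kernel matrix plus noise. Below is a minimal sketch assuming an isotropic RBF kernel with fixed hyperparameters (Figure 28.19 uses an ARD kernel; we simplify), returning the log likelihood and its analytic gradient with respect to $Z$, which could then be passed to any gradient-based optimizer. Function names and default hyperparameters are illustrative.

```python
import numpy as np

def rbf_kernel(Z, lengthscale=1.0, variance=1.0):
    """Isotropic RBF kernel matrix for the rows of Z."""
    sq = np.sum(Z**2, 1)[:, None] + np.sum(Z**2, 1)[None, :] - 2 * Z @ Z.T
    return variance * np.exp(-0.5 * np.maximum(sq, 0.0) / lengthscale**2)

def gplvm_loglik_and_grad(Z, Y, sigma2=0.1, lengthscale=1.0):
    """log p(Y | Z) with K_ij = k_rbf(z_i, z_j) + sigma2 * delta_ij, and its gradient wrt Z."""
    N, D = Y.shape
    Kf = rbf_kernel(Z, lengthscale)
    K = Kf + sigma2 * np.eye(N)
    Kinv = np.linalg.inv(K)
    _, logdet = np.linalg.slogdet(K)
    # tr(K^{-1} Y Y^T) written as an elementwise sum (both matrices are symmetric)
    ll = -0.5 * (D * N * np.log(2 * np.pi) + D * logdet + np.sum(Kinv * (Y @ Y.T)))
    # d ll / d K = 0.5 * (K^{-1} Y Y^T K^{-1} - D K^{-1})
    G = 0.5 * (Kinv @ Y @ Y.T @ Kinv - D * Kinv)
    # Chain rule through the kernel: d K_nj / d z_n = -Kf_nj (z_n - z_j) / lengthscale^2
    A = G * Kf
    grad = -(2.0 / lengthscale**2) * (A.sum(axis=1)[:, None] * Z - A @ Z)
    return ll, grad
```

A finite-difference comparison is a cheap way to validate such hand-derived gradients before using them in an optimizer.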

In Figure 28.19, we illustrate the model (with an ARD kernel) applied to some motion capture 
data, from the CMU mocap database at http://mocap.cs.cmu.edu/. Each person has 41 markers,
whose motion in 3d is tracked using 12 infrared cameras. Each data point corresponds to a different 
body pose. When projected to 2d, we see that similar poses are clustered nearby. 


28.4 LFMs with non-Gaussian priors 


In this section, we discuss (linear) latent factor models with non-Gaussian priors. See Table 28.1 for 
a summary of the models we will discuss. 


28.4.1 Non-negative matrix factorization (NMF) 


Suppose that we use a gamma distribution for the latents: $p(z) = \prod_k \mathrm{Ga}(z_k | \alpha_k, \beta_k)$. This results in a sparse, non-negative hidden representation, which can help interpretability. This is particularly
Figure 28.20: (a) Gamma-Poisson (GaP) model as a DPGM. Here $z_{n,k} \in \mathbb{R}^+$ and $x_{n,d} \in \mathbb{Z}_{\geq 0}$. (b) Simplex FA model as a DPGM. Here $z_n \in \mathbb{S}_K$ and $x_{n,d} \in \{1, \ldots, V\}$.


useful when the data is also sparse and non-negative, such as word counts. In this case, it makes sense to use a Poisson likelihood: $p(x | z) = \prod_d \mathrm{Poi}(x_d | w_d^\top z)$. The overall model has the form


$$p(z, x) = p(z)\, p(x | z) = \prod_k \mathrm{Ga}(z_k | \alpha_k, \beta_k) \prod_{d=1}^{D} \mathrm{Poi}(x_d | w_d^\top z) \tag{28.128}$$


The resulting model is called the GaP (Gamma-Poisson) model [Can04]. See Figure 28.20a for the 
graphical model. 

The parameters $\alpha_k$ and $\beta_k$ control the sparsity of the latent representation $z_n$. If we set $\alpha_k = \beta_k = 0$, and compute the MLE for $W$, we recover non-negative matrix factorization (NMF) [PT94; LS99; LS01], as shown in [BJ06].


Figure 28.21 illustrates the result of applying NMF to a dataset of image patches of faces, where 
the data correspond to non-negative pixel intensities. We see that the learned basis functions are 
small localized parts of faces. Also, the coefficient vector z is sparse and positive. For PCA, the 
coefficient vector has negative values, and the resulting basis functions are global, not local. For 
vector quantization (i.e., GMM model), z is a one-hot vector, with a single mixture component 
turned on; the resulting weight vectors correspond to entire image prototypes. The reconstruction 
quality is similar in each case, but the nature of the learned latent representation is quite different. 
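One standard way to compute this MLE is with the Lee-Seung multiplicative updates for the generalized KL divergence, which corresponds (up to constants) to the Poisson likelihood above. The sketch below is that swapped-in algorithm, not necessarily the procedure used to produce Figure 28.21; data vectors are stored as columns of $X$, $H$ plays the role of the latents $z_n$, and the initialization and iteration count are arbitrary choices.

```python
import numpy as np

def nmf_kl(X, K, iters=200, seed=0, eps=1e-10):
    """X ~ W @ H with W, H >= 0, via multiplicative updates for the KL divergence.

    X: (D, N) non-negative data (columns are data vectors); W: (D, K); H: (K, N).
    """
    rng = np.random.default_rng(seed)
    D, N = X.shape
    W = rng.uniform(0.1, 1.0, size=(D, K))
    H = rng.uniform(0.1, 1.0, size=(K, N))
    for _ in range(iters):
        # Each update keeps entries non-negative and does not increase the KL objective
        H *= (W.T @ (X / (W @ H + eps))) / (W.sum(axis=0)[:, None] + eps)
        W *= ((X / (W @ H + eps)) @ H.T) / (H.sum(axis=1)[None, :] + eps)
    return W, H

def poisson_nll(X, W, H, eps=1e-10):
    """Negative Poisson log-likelihood of X given rates W @ H, dropping log(x!)."""
    Lam = W @ H + eps
    return np.sum(Lam - X * np.log(Lam))
```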


28.4.2 Multinomial PCA


Suppose we use a Dirichlet prior for the latents, $p(z) = \mathrm{Dir}(z | \alpha)$, so $z \in \mathbb{S}_K$, which is the $K$-dimensional probability simplex. As in Section 28.4.1, the vector $z$ will be sparse and non-negative, but in addition it will satisfy the constraint $\sum_k z_k = 1$, so the components are not independent. Now suppose our data is categorical, $x_d \in \{1, \ldots, V\}$, so our likelihood has the form $p(x | z) = \prod_d \mathrm{Cat}(x_d | W_d z)$. The overall model is therefore

$$p(z, x) = \mathrm{Dir}(z | \alpha) \prod_{d=1}^{D} \mathrm{Cat}(x_d | W_d z) \tag{28.129}$$

Figure 28.21: Illustrating the difference between Non-negative Matrix Factorization (NMF), Vector Quantization (VQ), and Principal Components Analysis (PCA). Left column: Filters (columns of W) learned from a set of 2429 face images, each of size 19 × 19. There are 49 basis functions in total, shown in a 7 × 7 montage; each filter is reshaped to a 19 × 19 image for display purposes. (For PCA, negative weights are red, positive weights are black.) Middle column: The 49 latent factors z when the model is applied to the original face image shown at the top. Right column: reconstructed face image. From Figure 1 of [LS99].


See Figure 28.20b for the DPGM. This model (or small variants of it) has multiple names: user 
rating profile model [Mar03], admixture model [PSD00], mixed membership model [EFL04], 
multinomial PCA (mPCA) [BJ06], or simplex factor analysis (sFA) [BD11]. 
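The sketch below draws one sample from Equation (28.129). It makes the structure of the mean parameters explicit: the category probabilities for each $x_d$ are the convex combination $W_d z$ of the per-component distributions, so they automatically sum to one. The tensor layout of $W$ and the function name are assumptions of this illustration.

```python
import numpy as np

def sample_mpca(alpha, W, rng):
    """Draw (z, x) from Eq. (28.129).

    alpha: (K,) Dirichlet parameters; W: (D, V, K) where W[d, :, k] is a
    distribution over the V categories (each such column sums to 1).
    """
    z = rng.dirichlet(alpha)       # a point in the K-dimensional simplex
    probs = W @ z                  # (D, V): convex combination of the columns of W_d
    # Categories are 0-based here: x_d in {0, ..., V-1}
    x = np.array([rng.choice(W.shape[1], p=p) for p in probs])
    return z, x
```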

28.4.2.1 Example: roll call data 

Let us consider the example from [BJ06], who applied this model to analyze some roll call data from the US Senate in 2003. Specifically, the data has the form $x_{n,d} \in \{+1, -1, 0\}$ for $n = 1:100$ and $d = 1:459$, where $x_{nd}$ is the vote of the $n$'th senator on the $d$'th bill, where +1 means in favor, -1 means against, and 0 means not voting. In addition, we have the overall outcome, which we denote by $x_{101,d} \in \{+1, -1\}$, where +1 means the bill was passed, and -1 means it was rejected.

[Figure 28.22 panels: (a) Democrats; (b) Republicans]

Figure 28.22: The simplex factor analysis model applied to some roll call data from the US Senate collected in 2003. The senators have been sorted from left to right using the binary PCA method of [Lec06]. See text for details. From Figures 8-9 of [BJ06]. Used with kind permission of Wray Buntine.


We fit the mPCA model with 5 latent factors to this data using variational EM. Figure 28.22 plots $\mathbb{E}[z_{nk} | x_n] \in [0, 1]$, which is the degree to which senator $n$ belongs to latent component or “bloc” $k$. We see that component 5 is the Democratic majority, and component 2 is the Republican majority. See [BJ06] for further details.


28.4.2.2 Advantage of Dirichlet prior over Gaussian prior 


The main advantage of using a Dirichlet prior compared to a Gaussian prior is that the latent factors are more interpretable. To see this, note that the mean parameters for the $d$'th output distribution have the form $\mu_{nd} = W_d z_n$, and hence

$$p(x_{nd} = v | z_n) = \sum_k W_{dvk} z_{nk} \tag{28.130}$$

Thus the latent variables can be additively combined to compute the mean parameters, aiding interpretability. By contrast, the CatPCA model in Section 28.3.5.2 uses a Gaussian prior, so $W_d z_n$
28.5. TOPIC MODELS


can be negative; consequently it must pass this vector through a softmax to convert from natural parameters to mean parameters; this makes $z_n$ harder to interpret.


28.4.2.3 Connection to mixture models 


If $z_n$ were a one-hot vector, rather than any point in the probability simplex, then the mPCA model would be equivalent to selecting a single column from $W_d$ corresponding to the discrete hidden state. This is equivalent to a finite mixture of categorical distributions (cf. Section 28.2.2), and corresponds to the assumption that $x_n$ is generated by a single cluster. However, the mPCA model does not require that $z_n$ be one-hot, and instead allows $x_n$ to partially belong to multiple clusters. For this reason, this model is also known as an admixture model or mixed membership model [EFL04].


28.5 Topic models 


In this section, we show how to modify the multinomial PCA model of Section 28.4.2 to create latent 
variable models for sequences of discrete tokens, such as words in text documents, or genes in a DNA 
sequence. The basic idea is to assume that the words are conditionally independent given a latent 
topic vector z. Rather than being a single discrete cluster label, z is a probability distribution over 
clusters, and each word is sampled from its own “local” cluster. In the NLP community, this kind of 
model is called a topic model (see e.g., [BGHM17]). 


28.5.1 Latent Dirichlet Allocation (LDA) 


In this section, we discuss the most common kind of topic model, known as latent Dirichlet allocation or LDA [BNJ03a; Ble12]. (This usage of the term “LDA” is not to be confused with linear discriminant analysis.) In the genetics community, this model is known as an admixture model [PSD00].


28.5.1.1 Model definition 


We can define the LDA model as follows. Let $x_{nl} \in \{1, \ldots, V\}$ be the identity of the $l$'th word in document $n$, where $l$ can now range from 1 to $L_n$, the length of the document, and $V$ is the size of the vocabulary. The probability of word $v$ at location $l$ is given by

$$p(x_{nl} = v | z_n) = \sum_k z_{nk} W_{kv} \tag{28.131}$$

where $0 \leq z_{nk} \leq 1$ is the proportion of “topic” $k$ in document $n$, and $z_n \sim \mathrm{Dir}(\alpha)$.
We can rewrite this model by associating a discrete latent variable $m_{nl} \in \{1, \ldots, N_z\}$ with each word in each document, with distribution $p(m_{nl} | z_n) = \mathrm{Cat}(m_{nl} | z_n)$. Thus $m_{nl}$ specifies the topic to use for word $l$ in document $n$. The full joint model becomes


$$p(x_n, z_n, m_n) = \mathrm{Dir}(z_n | \alpha) \prod_{l=1}^{L_n} \mathrm{Cat}(m_{nl} | z_n)\, \mathrm{Cat}(x_{nl} | W[m_{nl}, :]) \tag{28.132}$$


where $W[k, :] = w_k$ is the distribution over words for the $k$'th topic. See Figure 28.23 for the
corresponding DPGM. 
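The LDA generative process of Equation (28.132) can be simulated directly. This sketch assumes symmetric Dirichlet priors $\mathrm{Dir}(\beta 1_V)$ on the topics and $\mathrm{Dir}(\alpha 1_{N_z})$ on the topic proportions (as discussed next), and, for simplicity, the same length for every document; all names are illustrative.

```python
import numpy as np

def sample_lda_corpus(num_docs, doc_len, alpha, beta, num_topics, vocab_size, seed=0):
    """Sample a toy corpus from the LDA generative process of Eq. (28.132).

    Returns (W, Z, M, X): topics W (num_topics, vocab_size), topic proportions
    Z (num_docs, num_topics), and topic assignments M and words X, both
    (num_docs, doc_len) integer arrays (0-based ids).
    """
    rng = np.random.default_rng(seed)
    W = rng.dirichlet(beta * np.ones(vocab_size), size=num_topics)   # topics w_k
    Z = rng.dirichlet(alpha * np.ones(num_topics), size=num_docs)    # proportions z_n
    M = np.empty((num_docs, doc_len), dtype=int)
    X = np.empty((num_docs, doc_len), dtype=int)
    for n in range(num_docs):
        M[n] = rng.choice(num_topics, size=doc_len, p=Z[n])         # m_nl ~ Cat(z_n)
        for l in range(doc_len):
            X[n, l] = rng.choice(vocab_size, p=W[M[n, l]])           # x_nl ~ Cat(W[m_nl, :])
    return W, Z, M, X
```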


Figure 28.23: Latent Dirichlet Allocation (LDA) as a DPGM. (a) Unrolled form. (b) Plate form. 


We typically use a Dirichlet prior on the topic parameters, $p(w_k) = \mathrm{Dir}(w_k | \beta 1_V)$; by setting $\beta$ small enough, we can encourage these topics to be sparse, so that each topic only predicts a subset of the words. In addition, we use a Dirichlet prior on the latent factors, $p(z_n) = \mathrm{Dir}(z_n | \alpha 1_{N_z})$. If we set $\alpha$ small enough, we can encourage the topic distribution for each document to be sparse, so that each document only contains a subset of the topics. See Figure 28.24 for an illustration.

Note that an earlier version of LDA, known as probabilistic LSA (pLSA), was proposed in [Hof99]. (LSA stands for “latent semantic analysis”, and refers to the application of PCA to text data; see [Mur22, Sec 20.5.1.2] for details.) The likelihood function, $p(x | z)$, is the same as in LDA, but pLSA does not specify a prior for $z$, since it is designed for posterior analysis of a fixed corpus (similar to LSA), rather than being a true generative model.


28.5.1.2 Polysemy


Each topic is a distribution over words that co-occur, and which are therefore semantically related. For example, Figure 28.25 shows 3 topics that were learned from an LDA model fit to the TASA corpus. These seem to correspond to 3 different senses of the word “play”: playing an instrument, a theatrical play, and playing a sports game.
We can use the inferred document-level topic distribution to overcome polysemy, i.e., to disambiguate the meaning of a particular word. This is illustrated in Figure 28.26, where a subset of the words are annotated with the topic to which they were assigned (i.e., we show $\arg\max_k p(m_{nl} = k | x_n)$). In the first document, the word “music” makes it clear that the musical topic (number 77) is present in the document, which in turn makes it more likely that $m_{nl} = 77$, where $l$ is the index corresponding to the word “play”.


3. The TASA corpus is an untagged collection of educational materials consisting of 37,651 documents and 12,190,931 word tokens. Words appearing in fewer than 5 documents were replaced with an asterisk, but punctuation was included. The combined vocabulary was of size 37,202 unique words.
[Figure 28.24 panels: Topics; Documents (the article “Seeking Life’s Bare (Genetic) Necessities”); Topic proportions and assignments]


Figure 28.24: Illustration of latent Dirichlet allocation (LDA). We have color coded certain words by the topic 
they have been assigned to: yellow represents the genetics cluster, pink represents the evolution cluster, blue represents the data analysis cluster, and green represents the neuroscience cluster. Each topic is in turn defined as a sparse distribution over words. This article is not related to neuroscience, so no words are assigned to the green topic. The overall distribution over topic assignments for this document is shown on the right as a
sparse histogram. Adapted from Figure 1 of [Ble12]. Used with kind permission of David Blei. 


Topic 77              Topic 82                 Topic 166
word        prob.     word           prob.     word        prob.
MUSIC       .090      LITERATURE     .031      PLAY        .136
DANCE       .034      POEM           .028      BALL        .129
SONG        .033      POETRY         .027      GAME        .065
PLAY        .030      POET           .020      PLAYING     .042
SING        .026      PLAYS          .019      HIT         .032
SINGING     .026      POEMS          .019      PLAYED      .031
BAND        .026      PLAY           .015      BASEBALL    .027
PLAYED      .023      LITERARY       .013      GAMES       .025
SANG        .022      WRITERS        .013      BAT         .019
SONGS       .021      DRAMA          .012      RUN         .019
DANCING     .020      WROTE          .012      THROW       .016
PIANO       .017      POETS          .011      BALLS       .015
PLAYING     .016      WRITER         .011      TENNIS      .011
RHYTHM      .015      SHAKESPEARE    .010      HOME        .010
ALBERT      .013      WRITTEN        .009      CATCH       .010
MUSICAL     .013      STAGE          .009      FIELD       .010


Figure 28.25: Three topics related to the word play. From Figure 9 of [SG07]. Used with kind permission of 
Tom Griffiths. 
[Figure 28.26 panels: Document #29795; Document #1883; Document #21359 (words annotated with their assigned topic numbers)]
Figure 28.26: Three documents from the TASA corpus containing different senses of the word play. Grayed out words were ignored by the model, because they correspond to uninteresting stop words (such as “and”, “the”, etc.) or very low frequency words. From Figure 10 of [SG07]. Used with kind permission of Tom Griffiths.


28.5.1.3 Posterior inference 


Many algorithms have been proposed to perform approximate posterior inference in the LDA model. In the original LDA paper [BNJ03a], the authors use variational mean field inference (see Section 10.2). In [HBB10], they use stochastic VI (see Supplementary Section 28.1.2). In [GS04], they use collapsed Gibbs sampling, which marginalizes out the continuous latents (the topic proportions $z_n$ and the topics $W$) and samples only the discrete topic assignments $m_{nl}$ (see Supplementary Section 28.1.1). In [MB16; SS17b], they discuss how to learn amortized inference networks to perform VI for a collapsed version of the model.
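For concreteness, here is a compact sketch of collapsed Gibbs sampling in the style of [GS04]. Each assignment $m_{nl}$ is resampled from $p(m_{nl} = k | \cdot) \propto (n_{nk}^{-} + \alpha)(n_{kv}^{-} + \beta)/(n_k^{-} + V\beta)$, where the counts exclude the current word and symmetric hyperparameters are assumed. It is illustrative only (pure Python loops, no convergence diagnostics), and the function name is hypothetical.

```python
import numpy as np

def lda_collapsed_gibbs(docs, num_topics, vocab_size, alpha=0.1, beta=0.01,
                        sweeps=50, seed=0):
    """Collapsed Gibbs sampling for LDA with symmetric Dirichlet hyperparameters.

    docs: list of 1-d integer arrays of word ids in {0, ..., vocab_size - 1}.
    Returns (assign, n_dk, n_kv): per-word topic assignments and count tables.
    """
    rng = np.random.default_rng(seed)
    n_dk = np.zeros((len(docs), num_topics))   # words in doc d assigned to topic k
    n_kv = np.zeros((num_topics, vocab_size))  # times word v is assigned to topic k
    n_k = np.zeros(num_topics)                 # total words assigned to topic k
    assign = [rng.integers(num_topics, size=len(doc)) for doc in docs]
    for d, doc in enumerate(docs):
        for v, k in zip(doc, assign[d]):
            n_dk[d, k] += 1
            n_kv[k, v] += 1
            n_k[k] += 1
    for _ in range(sweeps):
        for d, doc in enumerate(docs):
            for l, v in enumerate(doc):
                k = assign[d][l]
                # Remove the current word from all counts.
                n_dk[d, k] -= 1
                n_kv[k, v] -= 1
                n_k[k] -= 1
                # Full conditional of this word's topic given all other assignments.
                p = (n_dk[d] + alpha) * (n_kv[:, v] + beta) / (n_k + vocab_size * beta)
                k = rng.choice(num_topics, p=p / p.sum())
                # Add the word back under its newly sampled topic.
                assign[d][l] = k
                n_dk[d, k] += 1
                n_kv[k, v] += 1
                n_k[k] += 1
    return assign, n_dk, n_kv
```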

Recently, there has been considerable interest in spectral methods for fitting LDA-like models, which are fast and come with provable guarantees about the quality of the solution they obtain (unlike MCMC and variational methods, where the solution is just an approximation of unknown quality). These methods make certain (reasonable) assumptions beyond the basic model, such as the existence of some anchor words, which uniquely identify a topic. See [Aro+13] for details.


28.5.1.4 Determining the number of topics


Choosing $N_z$, the number of topics, is a standard model selection problem. Here are some approaches that have been taken:

• Use annealed importance sampling (Section 11.5.4) to approximate the evidence [Wal+09].
• Cross validation, using the log likelihood on a test set.
• Use the variational lower bound as a proxy for $\log p(\mathcal{D} | N_z)$.
• Use non-parametric Bayesian methods [Teh+06a].
28.5.2 Correlated topic model


One weakness of LDA is that it cannot capture correlation between topics. For example, if a document has the “business” topic, it is reasonable to expect the “finance” topic to co-occur. The source of the problem is the use of a Dirichlet prior for $z_n$. The problem with the Dirichlet is that it is characterized by just a mean vector $\alpha$, but its covariance is fixed ($\Sigma_{ij} \propto -\alpha_i \alpha_j$ for $i \neq j$), rather than being a free parameter.

One way around this is to replace the Dirichlet prior with the logistic normal distribution, which 
is defined as follows: 


$$p(z) = \int \mathrm{Cat}(z | \mathrm{softmax}(\epsilon))\, \mathcal{N}(\epsilon | \mu, \Sigma)\, d\epsilon \tag{28.133}$$


This is known as the correlated topic model [BL07b].
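The sketch below samples topic proportions $\mathrm{softmax}(\epsilon)$ with $\epsilon \sim \mathcal{N}(\mu, \Sigma)$, the logistic normal construction behind Equation (28.133). With a covariance that couples two topics, their proportions become positively correlated, something a Dirichlet cannot express since its off-diagonal covariances are always negative. The function name is illustrative.

```python
import numpy as np

def sample_logistic_normal(mu, Sigma, num_samples, seed=0):
    """Draw topic proportions z = softmax(eps), eps ~ N(mu, Sigma)."""
    rng = np.random.default_rng(seed)
    eps = rng.multivariate_normal(mu, Sigma, size=num_samples)
    eps = eps - eps.max(axis=1, keepdims=True)   # numerically stable softmax
    z = np.exp(eps)
    return z / z.sum(axis=1, keepdims=True)
```

For instance, with $\mu = 0$ and $\Sigma$ coupling topics 0 and 1 with correlation 0.9, the sample correlation of `z[:, 0]` and `z[:, 1]` should come out positive, whereas with $\Sigma = I$ it is negative.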

The difference from categorical PCA discussed in Section 28.3.5.2 is that CTM uses a logistic normal to model the mean parameters, so $z_n$ is sparse and non-negative, whereas CatPCA uses a normal to model the natural parameters, so $z_n$ is dense and can be negative. More precisely, the CTM defines $x_{nd} \sim \mathrm{Cat}(W \mathrm{softmax}(\epsilon_n))$, but CatPCA defines $x_{nd} \sim \mathrm{Cat}(\mathrm{softmax}(W_d z_n))$.

Fitting the CTM model is tricky, since the prior for $\epsilon_n$ is no longer conjugate to the multinomial likelihood for $m_{nl}$. However, we can derive a variational mean field approximation, as described in [BL07b].

Having fit the model, one can then convert $\hat{\Sigma}$ to a sparse precision matrix $\hat{\Sigma}^{-1}$ by pruning low-strength edges, to get a sparse Gaussian graphical model. This allows you to visualize the
correlation between topics. Figure 28.27 shows the result of applying this procedure to articles from 
Science magazine, from 1990-1999. 


28.5.3 Dynamic topic model 


In LDA, the topics (distributions over words) are assumed to be static. In some cases, it makes sense 
to allow these distributions to evolve smoothly over time. For example, an article might use the topic 
“neuroscience”, but if it was written in the 1900s, it is more likely to use words like “nerve”, whereas if 
it was written in the 2000s, it is more likely to use words like “calcium receptor” (this reflects the 
general trend of neuroscience towards molecular biology). 

One way to model this is to assume the topic distributions evolve according to a Gaussian random 
walk, as in a state space model (see Section 29.1). We can map these Gaussian vectors to probabilities 
via the softmax function, resulting in the following model: 


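$$w_k^t | w_k^{t-1} \sim \mathcal{N}(w_k^{t-1}, \sigma^2 I_V) \tag{28.134}$$
$$z_n^t \sim \mathrm{Dir}(\alpha 1_{N_z}) \tag{28.135}$$
$$m_{nl}^t | z_n^t \sim \mathrm{Cat}(z_n^t) \tag{28.136}$$
$$x_{nl}^t | m_{nl}^t = k, W^t \sim \mathrm{Cat}(\mathrm{softmax}(w_k^t)) \tag{28.137}$$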


This is known as a dynamic topic model [BL06]. See Figure 28.28 for the DPGM. 
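A minimal sketch of the topic-evolution part of this model, Equations (28.134) and (28.137): each topic's natural parameters follow a Gaussian random walk, and the softmax maps them to word distributions at each time step. The $\mathcal{N}(0, I)$ initialization of $w_k^0$ and all names are assumptions of this sketch; document-level sampling and the structured mean field inference are not shown.

```python
import numpy as np

def sample_dtm_topics(num_topics, vocab_size, num_steps, sigma=0.1, seed=0):
    """Sample time-varying topics: w_k^t ~ N(w_k^{t-1}, sigma^2 I), then softmax.

    Returns word probabilities of shape (num_steps, num_topics, vocab_size).
    """
    rng = np.random.default_rng(seed)
    w = rng.normal(size=(num_topics, vocab_size))        # assumed initial w_k^0
    probs = np.empty((num_steps, num_topics, vocab_size))
    for t in range(num_steps):
        if t > 0:
            w = w + sigma * rng.normal(size=w.shape)     # Gaussian random walk, Eq. (28.134)
        e = np.exp(w - w.max(axis=1, keepdims=True))
        probs[t] = e / e.sum(axis=1, keepdims=True)      # softmax, as in Eq. (28.137)
    return probs
```

Smaller $\sigma$ gives topics that drift more slowly between adjacent time steps.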

One can perform approximate inference in this model using a structured mean field method 
(Section 10.4.1), that exploits the Kalman smoothing algorithm (Section 8.3.2) to perform exact 
inference on the linear-Gaussian chain between the $w_k^t$ nodes (see [BL06] for details). Figure 28.29
illustrates a typical output of the system when applied to 100 years of articles from Science. 
Figure 28.27: Output of the correlated topic model (with K = 50 topics) when applied to articles from Science. Nodes represent topics, with the 5 most probable phrases from each topic shown inside. Font size reflects overall prevalence of the topic. See http://www.cs.cmu.edu/~Lemur/science/ for an interactive version of this model with 100 topics. From Figure 2 of [BL07b]. Used with kind permission of David Blei.


It is also possible to use amortized inference, and to learn embeddings for each word, which works 
much better with rare words. This is called the dynamic embedded topic model [DRB19]. 


28.5.4 LDA-HMM


The Latent Dirichlet Allocation (LDA) model of Section 28.5.1 assumes words are exchangeable,
and thus ignores word order. A simple way to model sequential dependence between words is to
use an HMM. The trouble with HMMs is that they can only model short-range dependencies, so
they cannot capture the overall gist of a document. Hence they can generate syntactically correct
sentences, but not semantically plausible ones.


It is possible to combine LDA with HMM to create a model called LDA-HMM [Gri+04]. This
model uses the HMM states to model function or syntactic words, such as “and” or “however”,
and uses the LDA to model content or semantic words, which are harder to predict. There is a distinguished
HMM state which specifies when the LDA model should be used to generate the word; the rest of
the time, the HMM generates the word.

More formally, for each document $n$, the model defines an HMM with states $h_{nl} \in \{0, \ldots, H\}$. In
addition, each document has an LDA model associated with it. If $h_{nl} = 0$, we generate word $x_{nl}$
from the semantic LDA model, with topic specified by $m_{nl}$; otherwise we generate word $x_{nl}$ from the
Figure 28.28: The dynamic topic model as a DPGM.


[Figure 28.29 image: top words of the “Neuroscience” topic for each decade from 1881 to 2000; probabilities of selected words over time; example article titles containing the topic, from “Mental Science” (1887) to “GABA-Activated Chloride Channels in Secretory Nerve Endings” (1993).]


Figure 28.29: Part of the output of the dynamic topic model when applied to articles from Science. At the top, 
we show the top 10 words for the neuroscience topic over time. On the bottom left, we show the probability of 
three words within this topic over time. On the bottom right, we list paper titles from different years that 
contained this topic. From Figure 4 of [BL06]. Used with kind permission of David Blei. 


syntactic HMM model. The DPGM is shown in Figure 28.30. The CPDs are as follows: 


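$$p(\mathbf{z}_n) = \mathrm{Dir}(\mathbf{z}_n \mid \alpha \mathbf{1}_{N_z}) \qquad (28.138)$$

$$p(m_{nl} = k \mid \mathbf{z}_n) = z_{nk} \qquad (28.139)$$

$$p(h_{nl} = j \mid h_{n,l-1} = i) = A_{ij} \qquad (28.140)$$

$$p(x_{nl} = d \mid m_{nl} = k, h_{nl} = j) = \begin{cases} W_{kd} & \text{if } j = 0 \\ B_{jd} & \text{if } j > 0 \end{cases} \qquad (28.141)$$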


where W is the usual topic-word matrix, B is the state-word HMM emission matrix and A is the 
state-state HMM transition matrix. 
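As a sketch of how these CPDs fit together, the following hypothetical NumPy function samples a single document by ancestral sampling. The initial state distribution `pi0` and the toy sizes are assumptions (the CPDs above do not specify $p(h_{n1})$). The topic $m_{nl}$ is drawn only when $h_{nl} = 0$, which gives the same distribution over words as drawing it at every position.

```python
import numpy as np

def sample_lda_hmm_doc(L, alpha, A, W, B, pi0, rng):
    """Sample one document of length L from the LDA-HMM, Eqs. (28.138)-(28.141).

    A:   (H+1, H+1) HMM transition matrix; state 0 means "emit from LDA".
    W:   (K, V) topic-word matrix; B: (H+1, V) state-word matrix (row 0 unused).
    pi0: (H+1,) initial HMM state distribution (hypothetical input).
    """
    K, V = W.shape
    z = rng.dirichlet(alpha * np.ones(K))        # document topic proportions z_n
    words, states = [], []
    h = rng.choice(A.shape[0], p=pi0)
    for l in range(L):
        if l > 0:
            h = rng.choice(A.shape[0], p=A[h])   # h_nl given h_{n,l-1}
        if h == 0:
            m = rng.choice(K, p=z)               # topic m_nl, only needed when h_nl = 0
            x = rng.choice(V, p=W[m])            # semantic (content) word
        else:
            x = rng.choice(V, p=B[h])            # syntactic (function) word
        words.append(x)
        states.append(h)
    return np.array(words), np.array(states)

# Toy sizes, chosen for illustration only
rng = np.random.default_rng(0)
H, K, V = 3, 2, 12
A = rng.dirichlet(np.ones(H + 1), size=H + 1)
W = rng.dirichlet(np.ones(V), size=K)
B = rng.dirichlet(np.ones(V), size=H + 1)
words, states = sample_lda_hmm_doc(50, 0.5, A, W, B, np.ones(H + 1) / (H + 1), rng)
```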

Inference in this model can be done with collapsed Gibbs sampling, analytically integrating out all 
the continuous quantities. See [Gri+04] for the details. 
Figure 28.30: LDA-HMM model as a DPGM.
Table 28.2: Upper row: Topics extracted by the LDA model when trained on the combined Brown and TASA
corpora. Middle row: topics extracted by the LDA part of the LDA-HMM model. Bottom row: topics extracted by
the HMM part of the LDA-HMM model. Each column represents a single topic/class, and words appear in order of
probability in that topic/class. Since some classes give almost all probability to only a few words, a list is
terminated when the words account for 90% of the probability mass. From Figure 2 of [Gri+04]. Used with
kind permission of Tom Griffiths.


Figure 28.31: Function and content words in the NIPS corpus, as distinguished by the LDA-HMM model.
Gray level indicates posterior probability of assignment to the LDA component, with black being highest. The boxed
word appears as a function word in one sentence, and as a content word in another sentence. Asterisked
words had low frequency, and were treated as a single word type by the model. From Figure 4 of [Gri+04].
Used with kind permission of Tom Griffiths.


The results of applying this model (with $N_z = 200$ LDA topics and $H = 20$ HMM states) to the
combined Brown and TASA corpora⁴ are shown in Table 28.2. We see that the HMM is generally
responsible for syntactic words, and the LDA for semantic words. If we did not have the HMM, the
LDA topics would get “polluted” by function words (see the top row of the table), which is why such words are
normally removed during preprocessing.

The model can also help disambiguate when the same word is being used syntactically or
semantically. Figure 28.31 shows some examples when the model was applied to the NIPS corpus.⁵ We see
that the roles of words are distinguished, e.g., “we require the algorithm to return a matrix” (verb) vs
“the maximal expected return” (noun). In principle, a part-of-speech tagger could disambiguate these
two uses, but note that (1) the LDA-HMM method is fully unsupervised (no POS tags were used),
and (2) sometimes a word can have the same POS tag, but different senses, e.g., “the left graph” (a
syntactic role) vs “the graph G” (a semantic role).

More recently, [Die+17] proposed topic-RNN, which is similar to LDA-HMM, but replaces the 
HMM model with an RNN, which is a much more powerful model. 


4. The Brown corpus consists of 500 documents and 1,137,466 word tokens, with part-of-speech tags for each token. 
The TASA corpus is an untagged collection of educational materials consisting of 37,651 documents and 12,190,931 
word tokens. Words appearing in fewer than 5 documents were replaced with an asterisk, but punctuation was included. 
The combined vocabulary was of size 37,202 unique words. 

5. NIPS stands for “Neural Information Processing Systems”. It is one of the top machine learning conferences. The
NIPS corpus (volumes 1-12) contains 1713 documents.
Figure 28.32: Illustration of ICA applied to 500 iid samples of a 4d source signal. The estimated signals match the true
sources, up to permutation of the dimension indices. Generated by ica_demo.ipynb.


28.6 Independent components analysis (ICA)


Consider the following situation. You are in a crowded room and many people are speaking. Your
ears essentially act as two microphones, which are listening to a linear combination of the different
speech signals in the room. Your goal is to separate the mixed signals into their constituent
parts. This is known as the cocktail party problem, or the blind source separation (BSS)
problem, where “blind” means we know “nothing” about the source of the signals. Besides the obvious
applications to acoustic signal processing, this problem also arises when analysing EEG and MEG
signals, financial data, and any other dataset (not necessarily temporal) where latent sources or
factors get mixed together in a linear way. See Figure 28.32 for an example.


28.6.1 Noiseless ICA model


We can formalize the problem as follows. Let $\mathbf{x}_n \in \mathbb{R}^D$ be the vector of observed responses at “time”
$n$, where $D$ is the number of sensors / microphones. Let $\mathbf{z}_n \in \mathbb{R}^D$ be the hidden vector of source
signals at time $n$, of the same dimensionality as the observed signal. We assume that

$$\mathbf{x}_n = \mathbf{A}\mathbf{z}_n \qquad (28.142)$$

where $\mathbf{A}$ is an invertible $D \times D$ matrix known as the mixing matrix or the generative weights.
The prior has the form $p(\mathbf{z}_n) = \prod_{j=1}^{D} p_j(z_{nj})$. Typically we assume this is a sparse prior, so only a
28.6. INDEPENDENT COMPONENTS ANALYSIS (ICA)


[Figure 28.33 panels: (a) Uniform data; (b) Uniform data after linear mixing; (c) PCA estimate; (d) ICA estimate.]

Figure 28.33: Illustration of ICA and PCA applied to 100 iid samples of a 2d source signal with a uniform
distribution. Generated by ica_demo_uniform.ipynb.


subset of the signals are active at any one time (see Section 28.6.2 for further discussion of priors for
this model). This model is called independent components analysis or ICA, since we assume
that each observation $\mathbf{x}_n$ is a linear combination of independent components represented by sources
$\mathbf{z}_n$, i.e.,

$$x_{ni} = \sum_{j=1}^{D} A_{ij} z_{nj} \qquad (28.143)$$

Our goal is to infer the source signals, $p(\mathbf{z}_n \mid \mathbf{x}_n, \mathbf{A})$. Since the model is noiseless, we have

$$p(\mathbf{z}_n \mid \mathbf{x}_n, \mathbf{A}) = \delta(\mathbf{z}_n - \mathbf{B}\mathbf{x}_n) \qquad (28.144)$$

where $\mathbf{B} = \mathbf{A}^{-1}$ are the recognition weights. (We discuss how to estimate these weights in
Section 28.6.3.)


28.6.2 The need for non-Gaussian priors 


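Since $\mathbf{x} = \mathbf{A}\mathbf{z}$, we have $\mathbb{E}[\mathbf{x}] = \mathbf{A}\,\mathbb{E}[\mathbf{z}]$ and $\mathrm{Cov}[\mathbf{x}] = \mathrm{Cov}[\mathbf{A}\mathbf{z}] = \mathbf{A}\,\mathrm{Cov}[\mathbf{z}]\,\mathbf{A}^\mathsf{T}$. Without loss of
generality, we can assume $\mathbb{E}[\mathbf{z}] = \mathbf{0}$, since we can always center the data. Similarly, we can assume
$\mathrm{Cov}[\mathbf{z}] = \mathbf{I}$, since $\mathbf{A}\mathbf{A}^\mathsf{T}$ can capture any correlation in $\mathbf{x}$. Thus $\mathbf{z}$ is a set of $D$ unit-variance,
uncorrelated variables, as in factor analysis (Section 28.3.1).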
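However, this is not sufficient to uniquely identify $\mathbf{A}$ and hence $\mathbf{z}$, as we explained in Section 28.3.1.6: for any orthogonal matrix $\mathbf{R}$, the sources $\mathbf{R}\mathbf{z}$ and the mixing matrix $\mathbf{A}\mathbf{R}^\mathsf{T}$ produce the same $\mathbf{x}$, and still satisfy $\mathrm{Cov}[\mathbf{R}\mathbf{z}] = \mathbf{I}$.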
So we need to go beyond an uncorrelated prior and enforce an independent, and non-Gaussian, prior. 

To illustrate this, suppose we have two independent sources with uniform distributions, as shown
in Figure 28.33(a). Now suppose we have the following mixing matrix

$$\mathbf{A} = 0.3 \begin{pmatrix} 2 & 3 \\ 2 & 1 \end{pmatrix} \qquad (28.145)$$


Then we observe the data shown in Figure 28.33(b) (assuming no noise). The full-rank PCA model 
(where K = D) is equivalent to ICA, except it uses a factored Gaussian prior for z. The result of 
using PCA is shown in Figure 28.33(c). This corresponds to a whitening or sphering of the data, 
in which $\mathrm{Cov}[\mathbf{z}] = \mathbf{I}$. To uniquely recover the sources, we need to perform an additional rotation.
The trouble is, there is no information in the symmetric Gaussian posterior to tell us which angle to 
rotate by. In a sense, PCA solves “half” of the problem, since it identifies the linear subspace; all 
that ICA has to do is then to identify the appropriate rotation. To do this, ICA uses an independent, 
but non-Gaussian, prior. The result is shown in Figure 28.33(d). This shows that ICA can recover 
the source variables, up to a permutation of the indices and possible sign change. 
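The following NumPy sketch reproduces this argument numerically. It assumes unit-variance uniform sources on $[-\sqrt{3}, \sqrt{3}]$ and uses the mixing matrix of Equation (28.145). After PCA whitening, every rotation leaves the covariance equal to $\mathbf{I}$, whereas the excess kurtosis of the first component (one possible non-Gaussianity measure, see Section 28.6.4.1) changes with the angle and is most negative when that component aligns with a true source.

```python
import numpy as np

def rotation(theta):
    """2x2 rotation matrix by angle theta (radians)."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s], [s, c]])

def excess_kurtosis(u):
    """Sample excess kurtosis: 0 for a Gaussian, -1.2 for a uniform."""
    u = (u - u.mean()) / u.std()
    return np.mean(u ** 4) - 3.0

rng = np.random.default_rng(0)
N = 20000
# Two independent unit-variance uniform sources on [-sqrt(3), sqrt(3)]
Z = rng.uniform(-np.sqrt(3), np.sqrt(3), size=(N, 2))
A = 0.3 * np.array([[2.0, 3.0], [2.0, 1.0]])  # mixing matrix of Eq. (28.145)
X = Z @ A.T                                    # row n is x_n = A z_n

# PCA whitening (sphering): the whitened data has identity sample covariance
Xc = X - X.mean(axis=0)
evals, evecs = np.linalg.eigh(np.cov(Xc, rowvar=False))
Xw = Xc @ evecs @ np.diag(evals ** -0.5) @ evecs.T

# Every rotation of white data is still white, so the covariance cannot
# tell us which angle to use; the kurtosis of the first component can.
angles = np.deg2rad(np.arange(0, 180, 1))
covs = [np.cov(Xw @ rotation(a).T, rowvar=False) for a in angles]
kurts = np.array([excess_kurtosis((Xw @ rotation(a).T)[:, 0]) for a in angles])
best = angles[np.argmin(kurts)]  # most uniform-like (least Gaussian) direction
```

For a single uniform source the excess kurtosis is -1.2, while an equal-weight mixture of two independent uniforms has -0.6, so scanning the angle singles out an unmixing rotation, up to the permutation and sign ambiguities noted above.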

We typically use a prior which is a super-Gaussian distribution, meaning it has heavy tails; this 
helps with identifiability. One option is to use a Laplace prior. For mean zero and variance 1, this 
has a log pdf given by 


$$\log p(z) = -\sqrt{2}\,|z| - \log(\sqrt{2}) \qquad (28.146)$$


However, since the Laplace prior is not differentiable at the origin, in ICA it is more common to use 
the logistic distribution, discussed in Section 15.4.1. The corresponding log pdf, for the case where 
the mean is zero and the variance is 1, is given by the following: 


$$\log p(z) = -2 \log \cosh\left(\frac{\pi}{2\sqrt{3}}\, z\right) - \log \frac{4\sqrt{3}}{\pi} \qquad (28.147)$$
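As a quick numerical check of Equations (28.146) and (28.147), the sketch below integrates both densities on a fine grid and confirms that each has unit mass and unit variance. The grid range and spacing are arbitrary choices.

```python
import numpy as np

def log_laplace(z):
    """Log pdf of a zero-mean, unit-variance Laplace distribution, Eq. (28.146)."""
    return -np.sqrt(2) * np.abs(z) - np.log(np.sqrt(2))

def log_logistic(z):
    """Log pdf of a zero-mean, unit-variance logistic distribution, Eq. (28.147)."""
    return -2 * np.log(np.cosh(np.pi * z / (2 * np.sqrt(3)))) - np.log(4 * np.sqrt(3) / np.pi)

z = np.linspace(-30, 30, 600001)
dz = z[1] - z[0]
results = {}
for logp in (log_laplace, log_logistic):
    p = np.exp(logp(z))
    # (total probability mass, variance), both expected to be close to 1
    results[logp.__name__] = (np.sum(p) * dz, np.sum(z ** 2 * p) * dz)
```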


28.6.3 Maximum likelihood estimation 


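Since $\mathbf{x} = \mathbf{A}\mathbf{z}$, from the change of variables formula, the density of the observed data is given by

$$p_x(\mathbf{x}) = p_z(\mathbf{z})\,|\det(\mathbf{A}^{-1})| = p_z(\mathbf{B}\mathbf{x})\,|\det(\mathbf{B})| \qquad (28.148)$$

where $\mathbf{B} = \mathbf{A}^{-1}$. To maximize this, we can simplify the problem as follows. Let $\boldsymbol{\Sigma} = \mathrm{Cov}[\mathbf{x}]$, and
define some nonsingular matrix $\mathbf{V}$ such that $\mathbf{B} = \mathbf{A}^{-1} = \mathbf{V}\boldsymbol{\Sigma}^{-1/2}$. Then

$$\mathbf{z} = \mathbf{B}\mathbf{x} = \mathbf{V}\boldsymbol{\Sigma}^{-1/2}\mathbf{x} \qquad (28.149)$$

Since we assumed $\mathrm{Cov}[\mathbf{z}] = \mathbf{I}$, we have

$$\mathrm{Cov}[\mathbf{z}] = \mathbf{V}\boldsymbol{\Sigma}^{-1/2}\boldsymbol{\Sigma}\boldsymbol{\Sigma}^{-1/2}\mathbf{V}^\mathsf{T} = \mathbf{V}\mathbf{V}^\mathsf{T} = \mathbf{I} \qquad (28.150)$$

Hence $\mathbf{V}$ is orthogonal. So if we sphere the data, by computing $\tilde{\mathbf{x}} = \boldsymbol{\Sigma}^{-1/2}\mathbf{x}$, we will have $\mathbf{z} = \mathbf{V}\tilde{\mathbf{x}}$, i.e., $\mathbf{V}$ plays the role of $\mathbf{B}$ for the whitened data.
So now our goal simplifies to estimating an orthogonal $\mathbf{V}$ from the whitened data.⁶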
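A minimal NumPy sketch of the sphering step, assuming the symmetric inverse square root $\boldsymbol{\Sigma}^{-1/2}$ computed from an eigendecomposition (any other valid choice differs only by an orthogonal factor, which can be absorbed into $\mathbf{V}$). It also checks Equation (28.150) at the population level, where $\boldsymbol{\Sigma} = \mathbf{A}\mathbf{A}^\mathsf{T}$ when $\mathrm{Cov}[\mathbf{z}] = \mathbf{I}$; the mixing matrix and uniform sources are illustrative assumptions.

```python
import numpy as np

def inv_sqrtm_psd(S):
    """Symmetric inverse square root S^{-1/2} of a positive definite matrix."""
    evals, evecs = np.linalg.eigh(S)
    return evecs @ np.diag(evals ** -0.5) @ evecs.T

def whiten(X):
    """Center the rows of X (N x D) and map each row x to Sigma^{-1/2} x."""
    Xc = X - X.mean(axis=0)
    S_inv_half = inv_sqrtm_psd(np.cov(Xc, rowvar=False))
    return Xc @ S_inv_half  # S_inv_half is symmetric, so this whitens every row

# Population check: Sigma = A A^T, and V = A^{-1} Sigma^{1/2} (so B = V Sigma^{-1/2})
A = 0.3 * np.array([[2.0, 3.0], [2.0, 1.0]])
Sigma = A @ A.T
Sigma_half = np.linalg.inv(inv_sqrtm_psd(Sigma))
V = np.linalg.inv(A) @ Sigma_half

# Sample check: whitened data has identity covariance
rng = np.random.default_rng(0)
X = rng.uniform(-np.sqrt(3), np.sqrt(3), size=(5000, 2)) @ A.T
X_tilde = whiten(X)
```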


6. Traditionally in the ICA literature the stated goal is to estimate the orthogonal matrix W, but this notation
conflicts with our use of W as generative weights in the factor analysis model of Section 28.3.1. So we use the letter V
instead.


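With this new notation (where $\mathbf{x}$ now denotes the whitened data), we can write the likelihood as

$$p_x(\mathbf{x}) = p_z(\mathbf{V}\mathbf{x})\,|\det(\mathbf{V})| \qquad (28.151)$$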


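Thus the average negative log likelihood is given by

$$\mathrm{NLL}(\mathbf{V}) = -\frac{1}{N}\log p(\mathbf{X} \mid \mathbf{V}) = -\log|\det(\mathbf{V})| - \frac{1}{N}\sum_{j=1}^{D}\sum_{n=1}^{N}\log p_j(\mathbf{v}_j^\mathsf{T}\mathbf{x}_n) \qquad (28.152)$$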


where $\mathbf{v}_j^\mathsf{T}$ is the $j$'th row of $\mathbf{V}$, and the prior is factored, so $p(\mathbf{z}) = \prod_j p_j(z_j)$. Since we are constraining
$\mathbf{V}$ to be orthogonal, the $\log|\det(\mathbf{V})|$ term is a constant, so we can drop it. We can also replace the
sum over $n$ with an expectation wrt the empirical distribution to get the following objective

$$\mathrm{NLL}(\mathbf{V}) = \sum_j \mathbb{E}[G_j(z_j)] \qquad (28.153)$$

where $z_j = \mathbf{v}_j^\mathsf{T}\mathbf{x}$ and $G_j(z_j) = -\log p_j(z_j)$. We want to minimize this (nonconvex) objective subject
to the constraint that $\mathbf{V}$ is an orthogonal matrix.

It is straightforward to derive a (projected) gradient descent algorithm to fit this model. (For some 
JAX code, see https://github.com/tuananhle7/ica). One can also derive a faster algorithm that
follows the natural gradient; see e.g., [Mac03, ch 34] for details. However, the most popular method 
is to use an approximate Newton method, known as fast ICA [HO00]. This was used to produce 
Figure 28.32. 
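For concreteness, here is a minimal NumPy sketch of the symmetric FastICA fixed-point iteration with the contrast $G(u) = \log\cosh(u)$, one common choice. It is not the code used for Figure 28.32; the iteration count, the tolerance, and the demo data (the uniform sources and mixing matrix of Equation (28.145)) are assumptions. Each step applies the approximate Newton update to every row of $\mathbf{V}$ and then re-orthogonalizes $\mathbf{V}$.

```python
import numpy as np

def sym_decorrelate(V):
    """Closest orthogonal matrix (V V^T)^{-1/2} V, computed via the SVD."""
    U, _, Wt = np.linalg.svd(V)
    return U @ Wt

def fast_ica(Xw, n_iter=200, tol=1e-6, seed=0):
    """Symmetric FastICA with G(u) = log cosh(u) on whitened data Xw (N x D).

    Returns an orthogonal V; the columns of Xw @ V.T are the estimated sources.
    """
    N, D = Xw.shape
    rng = np.random.default_rng(seed)
    V = sym_decorrelate(rng.normal(size=(D, D)))
    for _ in range(n_iter):
        Y = Xw @ V.T                     # row n holds V x_n
        g = np.tanh(Y)                   # G'(u) = tanh(u)
        g_prime = 1.0 - g ** 2           # G''(u)
        # Fixed-point update for every row v_j: E[x g(v_j^T x)] - E[g'(v_j^T x)] v_j
        V_new = sym_decorrelate((g.T @ Xw) / N - g_prime.mean(axis=0)[:, None] * V)
        # Converged when every row keeps its direction (up to a sign flip)
        delta = np.max(np.abs(np.abs(np.sum(V_new * V, axis=1)) - 1.0))
        V = V_new
        if delta < tol:
            break
    return V

# Demo on uniform sources mixed by the matrix of Eq. (28.145)
rng = np.random.default_rng(1)
Z = rng.uniform(-np.sqrt(3), np.sqrt(3), size=(5000, 2))
A = 0.3 * np.array([[2.0, 3.0], [2.0, 1.0]])
X = Z @ A.T
Xc = X - X.mean(axis=0)
evals, evecs = np.linalg.eigh(np.cov(Xc, rowvar=False))
Xw = Xc @ evecs @ np.diag(evals ** -0.5) @ evecs.T   # whitened data
V = fast_ica(Xw)
Z_hat = Xw @ V.T
# |correlation| between each true source (rows) and each estimate (columns)
C = np.abs(np.corrcoef(Z.T, Z_hat.T)[:2, 2:])
```

The recovered sources match the true ones only up to permutation and sign, which is why the check looks at absolute correlations.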


28.6.4 Alternatives to MLE 


In this section, we discuss various alternative estimators for ICA that have been proposed over the
years. We will show that they are equivalent to MLE. However, they bring interesting perspectives 
to the problem. 


28.6.4.1 Maximizing non-Gaussianity 


An early approach to ICA was to find a matrix $\mathbf{V}$ such that the distribution of $\mathbf{z} = \mathbf{V}\mathbf{x}$ is as far from
Gaussian as possible. (There is a related approach in statistics called projection pursuit [FT74].)
One measure of non-Gaussianity is kurtosis, but this can be sensitive to outliers. Another measure is
the negentropy, defined as

$$\mathrm{negentropy}(z) = \mathbb{H}(\mathcal{N}(\mu, \sigma^2)) - \mathbb{H}(z) \qquad (28.154)$$

where $\mu = \mathbb{E}[z]$ and $\sigma^2 = \mathbb{V}[z]$. Since the Gaussian is the maximum entropy distribution (for a fixed
variance), this measure is always non-negative and becomes large for distributions that are highly
non-Gaussian.
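As a worked example, using standard closed-form differential entropies (in nats), the negentropy of the two zero-mean, unit-variance non-Gaussian densities used in this section can be computed directly:

```python
import math

# Closed-form differential entropies (nats) of zero-mean, unit-variance densities
h_gauss = 0.5 * math.log(2 * math.pi * math.e)   # N(0, 1)
h_uniform = math.log(2 * math.sqrt(3))           # Uniform[-sqrt(3), sqrt(3)]: log(width)
h_laplace = 1 + math.log(2 / math.sqrt(2))       # Laplace with scale b = 1/sqrt(2): 1 + log(2b)

negentropy = {
    "uniform": h_gauss - h_uniform,
    "laplace": h_gauss - h_laplace,
}
```

Both values are positive, as the maximum entropy argument requires: about 0.176 nats for the uniform and about 0.072 nats for the Laplace.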

We can define our objective as maximizing

$$J(\mathbf{V}) = \sum_j \mathrm{negentropy}(z_j) = \sum_j \left[\mathbb{H}(\mathcal{N}(\mu_j, \sigma_j^2)) - \mathbb{H}(z_j)\right] \qquad (28.155)$$
where $\mathbf{z} = \mathbf{V}\mathbf{x}$. Since we assume $\mathbb{E}[\mathbf{z}] = \mathbf{0}$ and $\mathrm{Cov}[\mathbf{z}] = \mathbf{I}$, the first term is a constant. Hence

$$J(\mathbf{V}) = \sum_j -\mathbb{H}(z_j) + \text{const} = \sum_j \mathbb{E}[\log p(z_j)] + \text{const} \qquad (28.156)$$

which we see is equal (up to a sign change and irrelevant constants) to the negative log likelihood objective in
Equation (28.153).


28.6.4.2 Minimizing total correlation 


In Section 5.3.5.1, we show that the total correlation of z is given by 


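$$\mathrm{TC}(\mathbf{z}) = \sum_j \mathbb{H}(z_j) - \mathbb{H}(\mathbf{z}) = D_{\mathrm{KL}}\Big(p(\mathbf{z}) \,\Big\|\, \prod_j p_j(z_j)\Big) \qquad (28.157)$$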


This is zero iff the components of z are all mutually independent. In Section 21.3.1.1, we show that 
minimizing this results in a representation that is disentangled. 
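Now since $\mathbf{z} = \mathbf{V}\mathbf{x}$, we have

$$\mathrm{TC}(\mathbf{z}) = \sum_j \mathbb{H}(z_j) - \mathbb{H}(\mathbf{V}\mathbf{x}) \qquad (28.158)$$

Since we constrain $\mathbf{V}$ to be orthogonal, we can drop the last term: by the change of variables formula,
$\mathbb{H}(\mathbf{V}\mathbf{x}) = \mathbb{H}(\mathbf{x}) + \log|\det(\mathbf{V})| = \mathbb{H}(\mathbf{x})$, which is a constant, since an orthogonal transformation does not change volumes.
Hence we have $\mathrm{TC}(\mathbf{z}) = \sum_j \mathbb{H}(z_j) + \text{const}$.
Minimizing this is equivalent to maximizing the negentropy, which is equivalent to maximum
likelihood.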
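The sketch below illustrates both facts under a Gaussian assumption, where entropies have closed forms: an orthogonal $\mathbf{V}$ leaves the joint entropy unchanged, and white Gaussian data has zero total correlation after any rotation, so minimizing TC cannot identify $\mathbf{V}$ in the Gaussian case (the identifiability problem of Section 28.6.2). The covariance matrix is an arbitrary example.

```python
import numpy as np

def gaussian_entropy(S):
    """Differential entropy (nats) of a Gaussian with covariance S."""
    D = S.shape[0]
    return 0.5 * (D * np.log(2 * np.pi * np.e) + np.linalg.slogdet(S)[1])

def gaussian_total_correlation(S):
    """TC(z) = sum_j H(z_j) - H(z) for z ~ N(mu, S), Eq. (28.157)."""
    marginals = sum(gaussian_entropy(S[j:j + 1, j:j + 1]) for j in range(S.shape[0]))
    return marginals - gaussian_entropy(S)

rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))   # a random orthogonal matrix V
S = np.array([[2.0, 0.8, 0.0],
              [0.8, 1.0, 0.3],
              [0.0, 0.3, 0.5]])                # arbitrary covariance of x
S_rot = Q @ S @ Q.T                            # covariance of V x
```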


28.6.4.3 Maximizing mutual information (InfoMax)


Let $z_j = \phi(\mathbf{v}_j^\mathsf{T}\mathbf{x}) + \epsilon$ be the noisy output of an encoder, where $\phi$ is some nonlinear scalar function,
and $\epsilon \sim \mathcal{N}(0, 1)$. It seems reasonable to try to maximize the information flow through this system, a
principle known as infomax [Lin88b; BS95a]. That is, we want to maximize the mutual information
between $\mathbf{z}$ (the internal neural representation) and $\mathbf{x}$ (the observed input signal). We have $\mathbb{I}(\mathbf{x}; \mathbf{z}) =
\mathbb{H}(\mathbf{z}) - \mathbb{H}(\mathbf{z} \mid \mathbf{x})$, where the latter term is constant if we assume the noise has constant variance. One
can show that we can approximate the former term as follows


$$\mathbb{H}(\mathbf{z}) \approx \sum_j \mathbb{E}[\log \phi'(\mathbf{v}_j^\mathsf{T}\mathbf{x})] + \log|\det(\mathbf{V})| \qquad (28.159)$$


where, as usual, we can drop the last term if $\mathbf{V}$ is orthogonal. If we define $\phi(z)$ to be a cdf, then
$\phi'(z)$ is its pdf, and the above expression is equivalent to the log likelihood. In particular, if we
use a logistic nonlinearity, $\phi(z) = \sigma(z)$, then the corresponding pdf is the logistic distribution, and
$\log \phi'(z) = -2\log\cosh(z/2) - \log 4$, which matches Equation (28.147) up to a rescaling of $z$ and irrelevant constants. Thus we
see that infomax is equivalent to maximum likelihood.
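The identity used above can be checked numerically. The snippet below compares $\log \sigma'(z)$, computed stably as $\log\sigma(z) + \log(1 - \sigma(z))$, with $-2\log\cosh(z/2) - \log 4$ on a grid of values.

```python
import numpy as np

def log_sigmoid_derivative(z):
    """log sigma'(z) = log sigma(z) + log(1 - sigma(z)), computed stably."""
    # log sigma(z) = -log(1 + e^{-z});  log(1 - sigma(z)) = -log(1 + e^{z})
    return -np.logaddexp(0.0, -z) - np.logaddexp(0.0, z)

z = np.linspace(-10, 10, 201)
lhs = log_sigmoid_derivative(z)
rhs = -2 * np.log(np.cosh(z / 2)) - np.log(4.0)
```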


28.6.5 Sparse coding 


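In this section, we consider an extension of ICA to the case where we allow for observation noise
(using a Gaussian likelihood), and we allow for a non-square mixing matrix $\mathbf{W}$. We also use a Laplace
prior for $\mathbf{z}$. The resulting model is as follows:

$$p(\mathbf{z}, \mathbf{x}) = p(\mathbf{z})\,p(\mathbf{x} \mid \mathbf{z}) = \Big[\prod_k \mathrm{Laplace}(z_k \mid 0, 1/\lambda)\Big]\,\mathcal{N}(\mathbf{x} \mid \mathbf{W}\mathbf{z}, \sigma^2\mathbf{I}) \qquad (28.160)$$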


Thus each observation x is approximated by a sparse combination of columns of W, known as basis 
functions; the sparse vector of weights is given by z. (This can be thought of as a form of sparse 
factor analysis, except the sparsity is in the latent code z, not the weight matrix W.) 

Not all basis functions will be active for any given observation, due to the sparsity penalty.
Hence we can allow for more latent factors $K$ than observed dimensions $D$. This is called an overcomplete
representation.

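If we have a batch of $N$ examples, stored in the rows of $\mathbf{X}$ (with the corresponding codes in the rows of $\mathbf{Z}$), the negative log joint becomes

$$-\log p(\mathbf{X}, \mathbf{Z} \mid \mathbf{W}) = \sum_{n=1}^{N}\left[\frac{1}{2\sigma^2}\|\mathbf{x}_n - \mathbf{W}\mathbf{z}_n\|_2^2 + \lambda\|\mathbf{z}_n\|_1\right] + \text{const} \qquad (28.161)$$

$$= \frac{1}{2\sigma^2}\|\mathbf{X} - \mathbf{Z}\mathbf{W}^\mathsf{T}\|_F^2 + \lambda\|\mathbf{Z}\|_{1,1} + \text{const} \qquad (28.162)$$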


The MAP inference problem consists of estimating Z for a fixed W; this is known as sparse coding, 
and can be solved using standard algorithms for sparse linear regression (see Section 15.2.5).⁷

The learning problem consists of estimating W, marginalizing out Z. This is called dictionary 
learning. Since this is computationally difficult, it is common to jointly optimize W and Z (thus 
“maxing out” wrt Z instead of marginalizing it out). We can do this by applying alternating 
optimization to Equation (28.162): estimating Z given W is a sparse linear regression problem, and 
estimating W given Z is a simple least squares problem. (For faster algorithms, see [Mai+10].) 
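A minimal NumPy sketch of this alternating scheme, assuming ISTA (proximal gradient descent with step size $1/L$) for the sparse coding step and an exact least squares solve for the dictionary step. It is not the algorithm of [Mai+10], and `lam`, `sigma2`, the iteration counts and the random data are arbitrary. Because ISTA with this step size never increases the objective and the $\mathbf{W}$-step minimizes it exactly, the recorded values of Equation (28.162) are non-increasing.

```python
import numpy as np

def soft_threshold(a, t):
    """Elementwise soft thresholding: the proximal operator of t * ||.||_1."""
    return np.sign(a) * np.maximum(np.abs(a) - t, 0.0)

def objective(X, W, Z, lam, sigma2):
    """Eq. (28.162) without the constant; rows of X and Z are examples and codes."""
    return 0.5 / sigma2 * np.sum((X - Z @ W.T) ** 2) + lam * np.abs(Z).sum()

def sparse_code(X, W, Z, lam, sigma2, n_iter=100):
    """MAP estimate of Z for fixed W by ISTA (proximal gradient, step 1/L)."""
    L = max(np.linalg.norm(W, 2) ** 2 / sigma2, 1e-12)  # Lipschitz constant of the smooth term
    for _ in range(n_iter):
        grad = (Z @ W.T - X) @ W / sigma2
        Z = soft_threshold(Z - grad / L, lam / L)
    return Z

def dictionary_learning(X, K, lam=0.5, sigma2=1.0, n_outer=20, seed=0):
    """Alternate a sparse coding step for Z and a least squares step for W."""
    rng = np.random.default_rng(seed)
    N, D = X.shape
    W = rng.normal(size=(D, K))
    W /= np.linalg.norm(W, axis=0)                     # start from unit-norm atoms
    Z = np.zeros((N, K))
    history = []
    for _ in range(n_outer):
        Z = sparse_code(X, W, Z, lam, sigma2)          # warm-started from the previous Z
        W = np.linalg.lstsq(Z, X, rcond=None)[0].T     # argmin_W ||X - Z W^T||_F^2
        history.append(objective(X, W, Z, lam, sigma2))
    return W, Z, history

X = np.random.default_rng(1).normal(size=(200, 8))    # stand-in for image patches
W, Z, history = dictionary_learning(X, K=16)          # overcomplete: K > D
```

In practice, the atoms (columns of $\mathbf{W}$) are usually renormalized after each update to remove the scale ambiguity between $\mathbf{W}$ and $\mathbf{Z}$; that step is omitted here so that each update provably does not increase the objective.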

Figure 28.34(a) illustrates the results of dictionary learning when applied to a dataset of natural 
image patches. (Each patch is first centered and normalized to unit norm.) We see that the method 
has learned bar and edge detectors that are similar to the simple cells in the primary visual cortex 
of the mammalian brain [OF96]. By contrast, PCA results in sinusoidal gratings, as shown in 
Figure 28.34(b).⁸


28.6.6 Nonlinear ICA 


There are various ways to extend ICA to the nonlinear case. The resulting methods are similar to 
variational autoencoders (Chapter 21). For details, see e.g., [KKH20]. 


7. Solving an $\ell_1$ optimization problem for each data example can be slow. However, it is possible to train a neural
network to approximate the outcome of this process; this is known as predictive sparse decomposition [KRL08; 
GL10]. 

8. The reason PCA discovers sinusoidal grating patterns is that it is trying to model the covariance of the data, which,
in the case of image patches, is translation invariant. This means $\mathrm{Cov}[I(x, y), I(x', y')] = f\left[(x - x')^2 + (y - y')^2\right]$ for
some function $f$, where $I(x, y)$ is the image intensity at location $(x, y)$. One can show (see e.g., [HHH09, p125]) that
the eigenvectors of a matrix of this kind are always sinusoids of different phases, i.e., PCA discovers a Fourier basis.
Figure 28.34: Illustration of the filters learned by various methods when applied to natural image patches. (a)
Sparse coding. (b) PCA. Generated by sparse_dict_demo.ipynb.
29 State-space models


29.1 Introduction 


A state-space model (SSM) is a partially observed Markov model, in which the hidden state,
$\mathbf{z}_t$, evolves over time according to a Markov process (Section 2.6), and each hidden state generates
some observations $\mathbf{y}_t$ at each time step. (We focus on discrete time systems.) The main goal is to
infer the hidden states given the observations. However, we may also be interested in using the model
to predict future observations (e.g., for time-series forecasting).

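An SSM can be represented as a stochastic discrete time nonlinear dynamical system of the form

$$\mathbf{z}_t = f(\mathbf{z}_{t-1}, \mathbf{u}_t, \mathbf{q}_t) \qquad (29.1)$$

$$\mathbf{y}_t = h(\mathbf{z}_t, \mathbf{u}_t, \mathbf{y}_{1:t-1}, \mathbf{r}_t) \qquad (29.2)$$

where $\mathbf{z}_t \in \mathbb{R}^{N_z}$ are the hidden states, $\mathbf{u}_t \in \mathbb{R}^{N_u}$ are optional observed inputs, $\mathbf{y}_t \in \mathbb{R}^{N_y}$ are observed
outputs, $f$ is the transition function, $\mathbf{q}_t$ is the process noise, $h$ is the observation function,
and $\mathbf{r}_t$ is the observation noise.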

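Rather than writing this as a deterministic function of random noise, we can represent it as a
probabilistic model as follows:

$$p(\mathbf{z}_t \mid \mathbf{z}_{t-1}, \mathbf{u}_t) = p(\mathbf{z}_t \mid f(\mathbf{z}_{t-1}, \mathbf{u}_t)) \qquad (29.3)$$

$$p(\mathbf{y}_t \mid \mathbf{z}_t, \mathbf{u}_t, \mathbf{y}_{1:t-1}) = p(\mathbf{y}_t \mid h(\mathbf{z}_t, \mathbf{u}_t, \mathbf{y}_{1:t-1})) \qquad (29.4)$$

where $p(\mathbf{z}_t \mid \mathbf{z}_{t-1}, \mathbf{u}_t)$ is the transition model, and $p(\mathbf{y}_t \mid \mathbf{z}_t, \mathbf{u}_t, \mathbf{y}_{1:t-1})$ is the observation model.
Unrolling over time, we get the following joint distribution:

$$p(\mathbf{y}_{1:T}, \mathbf{z}_{1:T} \mid \mathbf{u}_{1:T}) = \Big[p(\mathbf{z}_1 \mid \mathbf{u}_1)\prod_{t=2}^{T} p(\mathbf{z}_t \mid \mathbf{z}_{t-1}, \mathbf{u}_t)\Big]\Big[\prod_{t=1}^{T} p(\mathbf{y}_t \mid \mathbf{z}_t, \mathbf{u}_t, \mathbf{y}_{1:t-1})\Big] \qquad (29.5)$$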

If we assume the current observation $\mathbf{y}_t$ only depends on the current hidden state, $\mathbf{z}_t$, and the
previous observation, $\mathbf{y}_{t-1}$, we get the graphical model in Figure 29.1(a). (This is called an auto-regressive
state-space model.) However, by using a sufficiently expressive hidden state $\mathbf{z}_t$, we can
implicitly represent all the past observations, $\mathbf{y}_{1:t-1}$. Thus it is more common to assume that the
observations are conditionally independent of each other (rather than having Markovian dependencies)
given the hidden state. In this case the joint simplifies to

$$p(\mathbf{y}_{1:T}, \mathbf{z}_{1:T} \mid \mathbf{u}_{1:T}) = \Big[p(\mathbf{z}_1 \mid \mathbf{u}_1)\prod_{t=2}^{T} p(\mathbf{z}_t \mid \mathbf{z}_{t-1}, \mathbf{u}_t)\Big]\Big[\prod_{t=1}^{T} p(\mathbf{y}_t \mid \mathbf{z}_t, \mathbf{u}_t)\Big] \qquad (29.6)$$
Figure 29.1: State-space model represented as a graphical model. (a) Generic form, with inputs $\mathbf{u}_t$, hidden
state $\mathbf{z}_t$, and observations $\mathbf{y}_t$. We assume the observation likelihood is first-order auto-regressive. (b) Simplified
form, with no inputs, and observations that are conditionally independent given the hidden states.


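Sometimes there are no external inputs, so the model further simplifies to the following unconditional
generative model:

$$p(\mathbf{y}_{1:T}, \mathbf{z}_{1:T}) = \Big[p(\mathbf{z}_1)\prod_{t=2}^{T} p(\mathbf{z}_t \mid \mathbf{z}_{t-1})\Big]\Big[\prod_{t=1}^{T} p(\mathbf{y}_t \mid \mathbf{z}_t)\Big] \qquad (29.7)$$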


See Figure 29.1(b) for the simplified graphical model. 
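Ancestral sampling from Equation (29.7) only needs the three conditional distributions. The following NumPy sketch takes them as callables; the 1d nonlinear model in the usage example is a hypothetical illustration, not one from the text.

```python
import numpy as np

def sample_ssm(T, sample_z1, sample_transition, sample_observation, rng):
    """Ancestral sampling from the unconditional SSM of Eq. (29.7).

    sample_z1(rng)              draws z_1 ~ p(z_1)
    sample_transition(z, rng)   draws z_t ~ p(z_t | z_{t-1} = z)
    sample_observation(z, rng)  draws y_t ~ p(y_t | z_t = z)
    """
    zs, ys = [], []
    z = sample_z1(rng)
    for t in range(T):
        if t > 0:
            z = sample_transition(z, rng)
        zs.append(z)
        ys.append(sample_observation(z, rng))
    return np.array(zs), np.array(ys)

# Hypothetical model: z_t = 0.9 z_{t-1} + 0.5 sin(z_{t-1}) + q_t,  y_t = z_t^2 / 5 + r_t
rng = np.random.default_rng(0)
zs, ys = sample_ssm(
    T=100,
    sample_z1=lambda rng: rng.normal(),
    sample_transition=lambda z, rng: 0.9 * z + 0.5 * np.sin(z) + 0.1 * rng.normal(),
    sample_observation=lambda z, rng: z ** 2 / 5 + 0.2 * rng.normal(),
    rng=rng,
)
```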


29.2 Hidden Markov models (HMMs)


In this section, we discuss the hidden Markov model or HMM, which is an SSM in which the
hidden states are discrete, so $z_t \in \{1, \ldots, K\}$. The observations may be discrete, $y_t \in \{1, \ldots, N_y\}$,
or continuous, $\mathbf{y}_t \in \mathbb{R}^{N_y}$, or some combination, as we illustrate below. More details on HMMs can be
found in Supplementary Chapter 29, as well as other references, such as [Rab89; Fra08; CMR05]. For
an interactive introduction, see https://nipunbatra.github.io/hmm/.


29.2.1 Conditional independence properties 


The HMM graphical model is shown in Figure 29.1(b). This encodes the assumption that the hidden
states are Markovian, and the observations are conditionally independent given the hidden states. All that remains
is to specify the form of the conditional probability distributions of each node.


29.2.2 State transition model 


The initial state distribution is denoted by

$$p(z_1 = j) = \pi_j \qquad (29.8)$$

where $\boldsymbol{\pi}$ is a discrete distribution over the $K$ states.
29.2. HIDDEN MARKOV MODELS (HMMS)


[Figure 29.2 image: observables (vertical axis) versus time (horizontal axis).]

Figure 29.2: Some samples from an HMM with 10 Bernoulli observables. Generated by
bernoulli_hmm_example.ipynb.


The transition model is denoted by

$$p(z_t = j \mid z_{t-1} = i) = A_{ij} \qquad (29.9)$$

Here the $i$'th row of $\mathbf{A}$ corresponds to the outgoing distribution from state $i$. This is a row stochastic
matrix, meaning each row sums to one. We can visualize the non-zero entries in the transition
matrix by creating a state transition diagram, as shown in Figure 2.15.


29.2.3 Discrete likelihoods 


The observation model $p(\mathbf{y}_t \mid z_t = j)$ can take multiple forms, depending on the type of data. For
discrete observations we can use

$$p(y_t = k \mid z_t = j) = \gamma_{jk} \qquad (29.10)$$

For example, see the casino HMM example in Section 8.2.1.
If we have $D$ discrete observations per time step, we can use a factorial model of the form

$$p(\mathbf{y}_t \mid z_t = j) = \prod_{d=1}^{D} \mathrm{Cat}(y_{td} \mid \boldsymbol{\gamma}_{d,j,:}) \qquad (29.11)$$

In the special case of binary observations, this becomes

$$p(\mathbf{y}_t \mid z_t = j) = \prod_{d=1}^{D} \mathrm{Ber}(y_{td} \mid \gamma_{d,j}) \qquad (29.12)$$


In Figure 29.2, we give an example of an HMM with 5 hidden states and 10 Bernoulli observables. 
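A minimal NumPy sketch of sampling from such a model, in the spirit of Figure 29.2 (5 states, 10 Bernoulli observables). The sticky transition matrix with self-transition probability 0.95, the uniform initial distribution, and the random emission probabilities are assumptions, not the settings used in bernoulli_hmm_example.ipynb.

```python
import numpy as np

def sample_bernoulli_hmm(T, pi, A, probs, rng):
    """Sample z_{1:T} and y_{1:T} from an HMM with D Bernoulli observables.

    pi:    (K,) initial distribution, Eq. (29.8)
    A:     (K, K) row-stochastic transition matrix, Eq. (29.9)
    probs: (K, D) with probs[j, d] = p(y_td = 1 | z_t = j), Eq. (29.12)
    """
    K, D = probs.shape
    z = np.zeros(T, dtype=int)
    z[0] = rng.choice(K, p=pi)
    for t in range(1, T):
        z[t] = rng.choice(K, p=A[z[t - 1]])   # row z_{t-1} of A is p(z_t | z_{t-1})
    # Given the states, the D observables are independent Bernoulli draws
    y = (rng.uniform(size=(T, D)) < probs[z]).astype(int)
    return z, y

K, D, T = 5, 10, 400
rng = np.random.default_rng(0)
A = np.full((K, K), 0.05 / (K - 1))
np.fill_diagonal(A, 0.95)                 # hypothetical "sticky" dynamics
pi = np.full(K, 1.0 / K)
probs = rng.uniform(size=(K, D))          # hypothetical emission probabilities
z, y = sample_bernoulli_hmm(T, pi, A, probs, rng)
```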
[Figure 29.3 panels: (a) Generating HMM; (b) Simulated data from an HMM.]

Figure 29.3: (a) Some 2d data sampled from a 5 state HMM. Each state emits from a 2d Gaussian. (b) The
hidden state sequence is shown by the colors. We superimpose the observed 2d timeseries (note that we have
shifted the vertical scale so the values don’t overlap). Generated by gaussian_hmm_2d.ipynb.


29.2.4 Gaussian likelihoods 


If y is continuous, it is common to use a Gaussian observation model: 


$$p(\mathbf{y}_t \mid z_t = j) = \mathcal{N}(\mathbf{y}_t \mid \boldsymbol{\mu}_j, \boldsymbol{\Sigma}_j) \qquad (29.13)$$


As a simple example, suppose we have an HMM with 3 hidden states, each of which generates a 2d 
Gaussian. We can represent these Gaussian distributions as 2d ellipses, as shown in Figure 29.3(a). 
We call these “lily pads”, because of their shape. We can imagine a frog hopping from one lily 
pad to another. (This analogy is due to the late Sam Roweis.) It will stay on a pad for a while 
(corresponding to remaining in the same discrete state $z_t$), and then jump to a new pad (corresponding
to a transition to a new state). See Figure 29.3(b). The data we see are just the 2d points (e.g., 
water droplets) coming from near the pad that the frog is currently on. Thus this model is like a 
Gaussian mixture model (Section 28.2.1), in that it generates clusters of observations, except now 
there is temporal correlation between the data points. 

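We can also use more flexible observation models. For example, if we use an $M$-component GMM,
then we have

$$p(\mathbf{y}_t \mid z_t = j) = \sum_{k=1}^{M} w_{jk}\,\mathcal{N}(\mathbf{y}_t \mid \boldsymbol{\mu}_{jk}, \boldsymbol{\Sigma}_{jk}) \qquad (29.14)$$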


This is called a GMM-HMM. 
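To evaluate Equation (29.14) without numerical underflow, it is convenient to work in log space and combine the mixture components with log-sum-exp. A minimal NumPy sketch, with hypothetical function names and array layouts:

```python
import numpy as np

def log_gauss(y, mu, Sigma):
    """log N(y | mu, Sigma) for a single vector y."""
    D = y.shape[0]
    diff = y - mu
    _, logdet = np.linalg.slogdet(Sigma)
    maha = diff @ np.linalg.solve(Sigma, diff)
    return -0.5 * (D * np.log(2 * np.pi) + logdet + maha)

def gmm_hmm_log_lik(y, weights, mus, Sigmas):
    """log p(y | z = j) for every state j under the GMM-HMM of Eq. (29.14).

    weights: (K, M) mixture weights w_jk; mus: (K, M, D); Sigmas: (K, M, D, D).
    """
    K, M = weights.shape
    out = np.empty(K)
    for j in range(K):
        comps = [np.log(weights[j, k]) + log_gauss(y, mus[j, k], Sigmas[j, k])
                 for k in range(M)]
        out[j] = np.logaddexp.reduce(comps)   # log-sum-exp over mixture components
    return out

# Two states, two components each, 2d observations (illustrative values)
weights = np.array([[0.5, 0.5], [0.3, 0.7]])
mus = np.zeros((2, 2, 2))
mus[1, 1] = [1.0, -1.0]
Sigmas = np.tile(np.eye(2), (2, 2, 1, 1))
ll = gmm_hmm_log_lik(np.zeros(2), weights, mus, Sigmas)
```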


29.2.5 Autoregressive likelihoods 


The standard HMM assumes the observations are conditionally independent given the hidden state.
In practice this is often not the case. However, it is straightforward to have direct arcs from $\mathbf{y}_{t-1}$ to
$\mathbf{y}_t$ as well as from $z_t$ to $\mathbf{y}_t$, as in Figure 29.1(a). This is known as an auto-regressive HMM.

For continuous data, we can use an observation model of the form

$$p(\mathbf{y}_t \mid \mathbf{y}_{t-1}, z_t = j, \boldsymbol{\theta}) = \mathcal{N}(\mathbf{y}_t \mid \mathbf{E}_j\mathbf{y}_{t-1} + \boldsymbol{\mu}_j, \boldsymbol{\Sigma}_j) \qquad (29.15)$$
Figure 29.4: Illustration of the observation dynamics for each of the 5 hidden states. The attractor point
corresponds to the steady state solution for the corresponding autoregressive process. Generated by hmm_ar.ipynb.


This is a linear regression model, where the parameters are chosen according to the current hidden 
state. (We could also use a nonlinear model, such as a neural network.) Such models are widely 
used in econometrics, where they are called regime-switching Markov models [Ham90]. Similar
models can be defined for discrete observations (see e.g. [SJ99]). 

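We can also consider higher-order extensions, where we condition on the last $L$ observations:

$$p(\mathbf{y}_t \mid \mathbf{y}_{t-L:t-1}, z_t = j, \boldsymbol{\theta}) = \mathcal{N}\Big(\mathbf{y}_t \,\Big|\, \sum_{\ell=1}^{L}\mathbf{W}_{j,\ell}\,\mathbf{y}_{t-\ell} + \boldsymbol{\mu}_j, \boldsymbol{\Sigma}_j\Big) \qquad (29.16)$$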


The AR-HMM essentially combines two Markov chains, one on the hidden variables, to capture long 
range dependencies, and one on the observed variables, to capture short range dependencies [Ber99]. 
Since all the visible nodes are observed, adding connections between them just changes the likelihood, 
but does not complicate the task of posterior inference (see Section 8.2.4). 

Let us now consider a 2d example of this, due to Scott Linderman. We use a left-to-right transition
matrix with 5 states. In addition, the final state returns to the first state, so we just cycle through the
states. Let $\mathbf{y}_t \in \mathbb{R}^2$, and suppose we set $\mathbf{E}_j$ to a rotation matrix with a small angle of 7 degrees, and
we set each $\boldsymbol{\mu}_j$ to 72-degree separated points on a circle about the origin, so each state rotates 1/5
of the way around the circle. If the model stays in the same state $j$ for a long time, the observed
dynamics will converge (provided $\mathbf{E}_j$ is stable, i.e., all its eigenvalues lie strictly inside the unit circle) to the steady state $\mathbf{y}_{*,j}$, which satisfies $\mathbf{y}_{*,j} = \mathbf{E}_j\mathbf{y}_{*,j} + \boldsymbol{\mu}_j$; we can solve for
the steady state vector using $\mathbf{y}_{*,j} = (\mathbf{I} - \mathbf{E}_j)^{-1}\boldsymbol{\mu}_j$. We can visualize the induced 2d flow for each of
the 5 states as shown in Figure 29.4.
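The steady state formula can be checked with a short NumPy sketch. Two assumptions are made here: the rotation is scaled by a hypothetical contraction factor of 0.95 so that $\mathbf{E}_j$ is stable (a pure rotation has eigenvalues on the unit circle, and the noiseless iteration would then circle the fixed point instead of converging), and the offsets $\boldsymbol{\mu}_j$ lie on the unit circle.

```python
import numpy as np

def rotation(theta):
    """2x2 rotation matrix by angle theta (radians)."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s], [s, c]])

K = 5
E = 0.95 * rotation(np.deg2rad(7.0))   # hypothetical contraction factor 0.95
mus = [np.array([np.cos(np.deg2rad(72.0 * j)), np.sin(np.deg2rad(72.0 * j))])
       for j in range(K)]              # offsets 72 degrees apart on the unit circle
# Steady state of each regime: y_* = (I - E)^{-1} mu_j
y_star = [np.linalg.solve(np.eye(2) - E, mu) for mu in mus]

# Iterating the noiseless dynamics of regime 0 converges to its steady state
y = np.zeros(2)
for _ in range(500):
    y = E @ y + mus[0]
```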

In Figure 29.5(a), we show a trajectory sampled from this model. We see that the two components
of the observation vector undergo different dynamics, depending on the underlying hidden state. In
Figure 29.5(b), we show the same data in a 2d scatter plot. The first observation is the yellow dot
(from state 2) at $(-0.8, 0.5)$. The dynamics converge to the stationary value of $\mathbf{y}_{*,2} = (-2.0, 3.8)$.
Then the system jumps to the green state (state 3), so it adds an offset of $\boldsymbol{\mu}_3$ to the last observation,
and then converges to the stationary value of $\mathbf{y}_{*,3} = (-4.3, -0.8)$. And so on.


29.2.6 Neural network likelihoods 


For higher dimensional data, such as images, it can be useful to use a normalizing flow (Chapter 23), 
one per latent state (see e.g., [HNBK18; Gho+21]), as the class-conditional generative model. However,
it is also possible to use discriminative neural network classifiers, which are much easier to train. In 
Figure 29.5: Samples from the 2d AR-HMM. (a) Time series plot of $y_{t,1}$ and $y_{t,2}$. (The latter are shifted up
vertically by 4.7) The background color is the generating state. The dotted lines represent the stationary value
for that component of the observation. (b) Scatter plot of observations. Colors denote the generating state.
We show the first 12 samples from each state. Generated by hmm_ar.ipynb.


particular, note that the likelihood per state can be rewritten as follows:

$$p(\mathbf{y}_t \mid z_t = j) = \frac{p(z_t = j \mid \mathbf{y}_t)\,p(\mathbf{y}_t)}{p(z_t = j)} \propto \frac{p(z_t = j \mid \mathbf{y}_t)}{p(z_t = j)} \qquad (29.17)$$

where we have dropped the $p(\mathbf{y}_t)$ term since it is independent of the state $z_t$. Here $p(z_t = j \mid \mathbf{y}_t)$ is
the output of a classifier, and $p(z_t = j)$ is the probability of being in state $j$, which can be computed
from the stationary distribution of the Markov chain (or empirically, if the state sequence is known).
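A minimal NumPy sketch of the scaled likelihood computation, assuming the classifier returns log posteriors and that the prior is the stationary distribution of the transition matrix (the left eigenvector of $\mathbf{A}$ with eigenvalue 1). The dropped $\log p(\mathbf{y}_t)$ is the same for every state at a given time step, so it cancels when the HMM posterior is normalized. The function names and toy values are hypothetical.

```python
import numpy as np

def stationary_distribution(A):
    """Left eigenvector of a row-stochastic A with eigenvalue 1, normalized to sum to 1."""
    evals, evecs = np.linalg.eig(A.T)
    v = np.real(evecs[:, np.argmin(np.abs(evals - 1.0))])
    return v / v.sum()

def scaled_log_likelihoods(log_post, log_prior):
    """Scaled likelihood trick, Eq. (29.17).

    log_post:  (T, K) log p(z_t = j | y_t) from a discriminative classifier
    log_prior: (K,)   log p(z_t = j)
    Returns log p(y_t | z_t = j) up to the per-time-step constant log p(y_t).
    """
    return log_post - log_prior[None, :]

A = np.array([[0.9, 0.1], [0.3, 0.7]])
prior = stationary_distribution(A)
logits = np.random.default_rng(0).normal(size=(5, 2))   # stand-in classifier outputs
log_post = logits - np.logaddexp.reduce(logits, axis=1, keepdims=True)
ll = scaled_log_likelihoods(log_post, np.log(prior))
```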


We can thus use discriminative classifiers to define the likelihood function when using gradient-based
training. This is called the scaled likelihood trick [BM93; Ren+94]. [Guo+14] used this to create
a hybrid CNN-HMM model for estimating sequences of digits based on street signs.
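To see why dropping $p(y_t)$ is harmless, the sketch below turns (hypothetical) classifier outputs $p(z_t = j|y_t)$ into scaled log-likelihoods by subtracting $\log p(z_t = j)$, then runs normalized forward filtering. Since $p(y_t)$ only adds the same constant to every state's log-likelihood at time t, it cancels in the normalization. The posteriors, state prior, and transition matrix are made-up numbers; this illustrates the trick, not the model of [Guo+14].

```python
import math


def scaled_log_likelihoods(posteriors, prior):
    """log p(y_t | z_t = j) up to a per-time-step constant:
    log p(z_t = j | y_t) - log p(z_t = j)  (the scaled likelihood trick)."""
    return [[math.log(p) - math.log(q) for p, q in zip(row, prior)] for row in posteriors]


def filter_posteriors(log_liks, init, A):
    """Normalized forward filtering: returns p(z_t | y_{1:t}) for every t."""
    K = len(init)
    alpha, out = None, []
    for t, ll in enumerate(log_liks):
        m = max(ll)
        lik = [math.exp(v - m) for v in ll]  # subtracting max(ll) is another per-t constant
        if t == 0:
            pred = list(init)
        else:
            pred = [sum(alpha[i] * A[i][j] for i in range(K)) for j in range(K)]
        unnorm = [pred[j] * lik[j] for j in range(K)]
        Z = sum(unnorm)
        alpha = [u / Z for u in unnorm]
        out.append(alpha)
    return out


# Hypothetical classifier outputs p(z_t = j | y_t) for T = 3 frames and K = 3 states.
posteriors = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.3, 0.3, 0.4]]
prior = [0.5, 0.3, 0.2]  # p(z_t = j); in practice stationary or empirical state frequencies
A = [[0.8, 0.1, 0.1], [0.1, 0.8, 0.1], [0.1, 0.1, 0.8]]
print(filter_posteriors(scaled_log_likelihoods(posteriors, prior), prior, A))
```

A quick consequence: at t = 1, with the initial distribution equal to the state prior, the filtered posterior is exactly the classifier's output.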


29.3 HMMs: Applications


In this section, we discuss some applications of HMMs.


29.3.1 Time series segmentation


In this section, we give a variant of the casino example from Section 8.2.1, where our goal is to
segment a time series into different regimes, each of which corresponds to a different statistical
distribution. In Figure 29.6a we show the data, corresponding to counts generated from some process
(e.g., visits to a web site, or number of infections). We see that the count rate seems to be roughly
constant for a while, and then changes at certain points. We would like to segment this data stream
into K different regimes or states, each of which is associated with a Poisson observation model with
rate $\lambda_k$:

$$p(y_t|z_t = k) = \text{Poi}(y_t|\lambda_k) \qquad (29.18)$$

Figure 29.6: (a) A sample time series dataset of counts. (b) A segmentation of this data using a 4 state
HMM. Generated by poisson_hmm_changepoint.ipynb.

Figure 29.7: Segmentation of the time series using HMMs with 1-6 states. Generated by
poisson_hmm_changepoint.ipynb.


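We use a uniform prior over the initial states. For the transition matrix, we assume the Markov
chain stays in the same state with probability p = 0.95, and otherwise transitions to one of the other
K - 1 states uniformly at random:

$$z_1 \sim \text{Categorical}\left(\left\{\tfrac{1}{4}, \tfrac{1}{4}, \tfrac{1}{4}, \tfrac{1}{4}\right\}\right) \qquad (29.19)$$

$$z_t|z_{t-1} \sim \text{Categorical}\left(\begin{cases} p & \text{if } z_t = z_{t-1} \\ \frac{1-p}{K-1} & \text{otherwise} \end{cases}\right) \qquad (29.20)$$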
Figure 29.8: Marginal likelihood vs number of states K in the Poisson HMM. Generated by
poisson_hmm_changepoint.ipynb.


We compute a MAP estimate for the parameters $\lambda_{1:K}$ using a log-Normal(5, 5) prior. We optimize
the log of the Poisson rates using gradient descent, initializing the parameters at a random value
centered on the log of the overall count means. We show the results in Figure 29.6b. We see that the
method has successfully partitioned the data into 4 regimes, which is in fact how it was generated.
(The generating rates are $\lambda = (40, 3, 20, 50)$, with the changepoints happening at times (10, 30, 35).)

In general we don’t know the optimal number of states K. To solve this, we can fit many different
models, as shown in Figure 29.7, for K = 1:6. We see that after K > 3, the model fits are very
similar, since multiple states get associated to the same regime. We can pick the “best” K to be the
one with the highest marginal likelihood. Rather than summing over both discrete latent states and
integrating over the unknown parameters $\lambda$, we just maximize over the parameters (empirical Bayes
approximation):


$$p(y_{1:T}|K) \approx \max_{\lambda} \sum_{z_{1:T}} p(y_{1:T}, z_{1:T}|\lambda, K) \qquad (29.21)$$

We show this plot in Figure 29.8. We see the peak is at K = 3 or K = 4; after that it starts to go
down, due to the Bayesian Occam’s razor effect.
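The inner sum over $z_{1:T}$ in Equation (29.21) can be evaluated with the forwards algorithm. Below is a minimal sketch in log space for the Poisson HMM with a uniform initial distribution and the sticky transitions of Equations (29.19)-(29.20). The outer maximization over $\lambda$ (done by gradient descent in the text) is not shown, so `rates` stands for whatever fitted values one plugs in; the counts are made up. A useful sanity check: if all K rates are equal, the hidden state is irrelevant and the score collapses to that of a single Poisson.

```python
import math


def poisson_logpmf(y, lam):
    return y * math.log(lam) - lam - math.lgamma(y + 1)


def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def poisson_hmm_loglik(ys, rates, p=0.95):
    """log sum_{z_{1:T}} p(y_{1:T}, z_{1:T} | lambda, K) via the forwards algorithm in log space,
    with a uniform initial distribution and sticky transitions (stay with probability p)."""
    K = len(rates)
    if K == 1:
        log_A = [[0.0]]
    else:
        stay, move = math.log(p), math.log((1.0 - p) / (K - 1))
        log_A = [[stay if i == j else move for j in range(K)] for i in range(K)]
    # log alpha_1(k) = log p(z_1 = k) + log p(y_1 | z_1 = k)
    log_alpha = [-math.log(K) + poisson_logpmf(ys[0], lam) for lam in rates]
    for y in ys[1:]:
        log_alpha = [logsumexp([log_alpha[i] + log_A[i][j] for i in range(K)])
                     + poisson_logpmf(y, rates[j]) for j in range(K)]
    return logsumexp(log_alpha)


# Hypothetical counts and candidate rates; in the text the rates come from the fitted MAP estimate.
ys = [3, 5, 2, 8, 7, 40, 38, 45]
for rates in ([6.0], [5.0, 40.0]):
    print(len(rates), poisson_hmm_loglik(ys, rates))
```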


29.3.2 Protein sequence alignment 


An important application of HMMs is to the problem of protein sequence alignment [Dur+98].
Here the goal is to determine if a test sequence $y_{1:T}$ belongs to a protein family or not, and if so,
how it aligns with the canonical representation of that family. (Similar methods can be used to align
DNA and RNA sequences.)
Figure 29.9: State transition diagram for a profile HMM. From Figure 5.7 of [Dur+98]. Used with kind
permission of Richard Durbin.

Figure 29.10: Example of multiple sequence alignment. We show the first 90 positions of the acidic ribosomal
protein P0 from several organisms. Colors represent functional properties of the corresponding amino acid.
Dashes represent insertions or deletions. From https://en.wikipedia.org/wiki/Multiple_sequence_alignment.
Used with kind permission of Wikipedia author Miguel Andrade.

To solve the alignment problem, let us initially assume we have a set of aligned sequences from
a protein family, from which we can generate a consensus sequence. This defines a probability
distribution over symbols at each location t in the string; denote each position-specific scoring
matrix (PSSM) by $\theta_t(v) = p(y_t = v)$. These parameters can be estimated by counting.
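A minimal sketch of the counting estimate of $\theta_t(v)$, assuming the sequences are already aligned to a common length. The pseudocount is our own addition (not stated in the text) so that unseen symbols keep nonzero probability, and gap characters such as "-" are simply not counted. The toy DNA alignment is made up.

```python
from collections import Counter


def estimate_pssm(aligned, alphabet, pseudocount=1.0):
    """Estimate theta_t(v) = p(y_t = v) for each consensus column t by counting.
    The pseudocount is an added assumption so that unseen symbols get nonzero probability."""
    L = len(aligned[0])
    if any(len(s) != L for s in aligned):
        raise ValueError("sequences must be aligned to a common length")
    pssm = []
    for t in range(L):
        counts = Counter(s[t] for s in aligned)  # gap symbols such as "-" are never looked up below
        total = sum(counts[v] for v in alphabet) + pseudocount * len(alphabet)
        pssm.append({v: (counts[v] + pseudocount) / total for v in alphabet})
    return pssm


# Toy DNA alignment (made up); protein families would use the 20-letter amino acid alphabet.
for column in estimate_pssm(["ACGT", "ACGA", "ACTT"], "ACGT"):
    print(column)
```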

Now we turn the PSSM into an HMM with 3 hidden states, representing the events that the
location t matches the consensus sequence, $z_t = M$, or inserts its own unique symbol, $z_t = I$, or
deletes (skips) the corresponding consensus symbol, $z_t = D$. We define the observation models for
these 3 events as follows. For matches, we use the PSSM $p(y_t = v|z_t = M) = \theta_t(v)$. For insertions
we use the uniform distribution $p(y_t = v|z_t = I) = 1/V$, where V is the size of the vocabulary. For
deletions, we use $p(y_t = -|z_t = D) = 1$, where “-” is a special deletion symbol used to pad the generated
sequence to the correct length. The corresponding state transition matrix is shown in Figure 29.9:
we see that matches and deletions advance one location along the consensus sequence, but insertions
stay in the same location (represented by the self-transition from I to I). This model is known as a
profile HMM.

Given a profile HMM with consensus parameters θ, we can compute $p(y_{1:T}|\theta)$ in O(T) time using
the forwards algorithm, as described in Section 8.2.2. This can be used to decide if the sequence
belongs to this family or not, by thresholding the log-odds score, $L(y) = \log p(y|\theta)/p(y|M_0)$, where
$M_0$ is a baseline model, such as the uniform distribution. If the string matches, we can compute an
alignment to the consensus using the Viterbi algorithm, as described in Section 8.2.7. See Figure 29.10
for an illustration of such a multiple sequence alignment. If we don’t have an initial set of aligned
sequences from which to compute the consensus sequence θ, we can use the Baum-Welch algorithm
(Section 29.4.1) to compute the MLE for the parameters θ from a set of unaligned sequences. For
details, see e.g., [Dur+98, Ch.6].
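As a simplified illustration of the log-odds test, the sketch below scores a sequence along the all-match path only, i.e., it assumes no insertions or deletions, so $\log p(y|\theta)$ reduces to $\sum_t \log \theta_t(y_t)$. This is our simplification: a real profile HMM would instead sum over all match/insert/delete paths with the forwards algorithm. The baseline $M_0$ is the uniform distribution, and the two-column PSSM is made up.

```python
import math


def ungapped_log_odds(seq, pssm, alphabet_size):
    """L(y) = log p(y | theta) - log p(y | M0) along the all-match path only,
    where p(y_t = v | M0) = 1 / V is the uniform baseline.
    Simplification: a full profile HMM would sum over insert/delete paths with the forwards algorithm."""
    if len(seq) != len(pssm):
        raise ValueError("ungapped scoring needs one symbol per consensus column")
    return sum(math.log(col[v]) - math.log(1.0 / alphabet_size) for v, col in zip(seq, pssm))


# Made-up two-column PSSM whose consensus is "AC".
toy_pssm = [{"A": 0.7, "C": 0.1, "G": 0.1, "T": 0.1},
            {"A": 0.1, "C": 0.7, "G": 0.1, "T": 0.1}]
for s in ("AC", "GT"):
    print(s, ungapped_log_odds(s, toy_pssm, 4))  # positive favours the family, negative the baseline
```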


29.3.3 Spelling correction 


In this section, we illustrate how to use an HMM for spelling correction. The goal is to infer the
sequence of words $z_{1:T}$ that the user meant to type, given observations of what they actually did
type, $y_{1:T}$.


29.3.3.1 Baseline model 


We start by using a simple unigram language model, so $p(z_{1:T}) = \prod_{t=1}^T p(z_t)$, where $p(z_t = k)$ is the
prior probability of word k being used. These probabilities can be estimated by simply normalizing
word frequency counts from a large training corpus. We ignore any Markov structure.

Now we turn to the observation model, $p(y_t = v|z_t = k)$, which is the probability the user types
word v when they meant to type word k. For this, we use a noisy channel model, in which the
“message” $z_t$ gets corrupted by one of four kinds of error: substitution errors, where we swap one
letter for another (e.g., “government” mistyped as “govermment”); transposition errors, where we
swap the order of two adjacent letters (e.g., “government” mistyped as “govermnent”); deletion errors,
where we omit one letter (e.g., “government” mistyped as “goverment”); and insertion errors, where
we add an extra letter (e.g., “government” mistyped as “governmennt”). If y differs from z by d such
errors, we say that y and z have an edit distance of d. Let D(y, d) be the set of words that are edit
distance d away from y. We can then define the following likelihood function:

$$p(y|z) = \begin{cases} p_1 & y = z \\ p_2 & y \in D(z, 1) \\ p_3 & y \in D(z, 2) \\ p_4 & \text{otherwise} \end{cases} \qquad (29.22)$$

where $p_1 > p_2 > p_3 > p_4$.
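A minimal sketch of this baseline corrector, written in the spirit of Norvig's tutorial but not copied from it: `edits1` enumerates the strings one substitution, transposition, deletion, or insertion away from a word (lowercase letters only), `channel_prob` implements the piecewise likelihood of Equation (29.22), and `correct` picks $\arg\max_z p(z)\, p(y|z)$ over a toy vocabulary. The values of $p_1, \ldots, p_4$ and the word counts are hypothetical; note that they act as relative scores rather than a distribution normalized over y.

```python
import math

LETTERS = "abcdefghijklmnopqrstuvwxyz"


def edits1(word):
    """All strings one substitution, transposition, deletion or insertion away from word."""
    splits = [(word[:i], word[i:]) for i in range(len(word) + 1)]
    deletes = {left + right[1:] for left, right in splits if right}
    transposes = {left + right[1] + right[0] + right[2:] for left, right in splits if len(right) > 1}
    substitutions = {left + c + right[1:] for left, right in splits if right for c in LETTERS}
    inserts = {left + c + right for left, right in splits for c in LETTERS}
    return deletes | transposes | substitutions | inserts


def edits2(word):
    """All strings reachable by two such edits (this set gets large for long words)."""
    return {e2 for e1 in edits1(word) for e2 in edits1(e1)}


def channel_prob(y, z, probs=(0.9, 0.08, 0.019, 0.001)):
    """Piecewise p(y | z) of Equation (29.22); probs = (p1, p2, p3, p4) are hypothetical values."""
    p1, p2, p3, p4 = probs
    if y == z:
        return p1
    if y in edits1(z):
        return p2
    if y in edits2(z):
        return p3
    return p4


def correct(y, word_counts):
    """argmax_z log p(z) + log p(y | z), with p(z) from normalized unigram counts."""
    total = sum(word_counts.values())
    return max(word_counts,
               key=lambda z: math.log(word_counts[z] / total) + math.log(channel_prob(y, z)))


# Toy vocabulary with made-up counts.
print(correct("goverment", {"government": 100, "governor": 50}))
print(correct("reciet", {"recite": 10, "receipt": 10}))
```

With equal word counts, the second call prefers “recite”: it is one transposition away from “reciet”, whereas “receipt” is two edits away, which is exactly the kind of failure that a fixed, distance-based error model produces.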


We can combine the likelihood with the prior to get the overall score for each hypothesis (i.e.,
candidate correction). This simple model, which was proposed by Peter Norvig,¹ can work quite
well. However, it also has some flaws. For example, the error model assumes that the smaller the edit
distance, the more likely the word, but this is not always valid. For example, “reciet” gets corrected
to “recite” instead of “receipt”, and “adres” gets corrected to “acres” not “address”. We can fix this
problem by learning the parameters of the noise model based on a labeled corpus of (z, y) pairs
derived from actual spelling errors. One possible way to get such a corpus is to look at web search
behavior: if a user types query $q_1$ and then quickly changes it to $q_2$ followed by a click on a link, it
suggests that $q_2$ is a manual correction for $q_1$, so we can set $(z = q_2, y = q_1)$. This heuristic has been
used in the Etsy search engine.² It is also possible to manually collect such data (see e.g., [Hag+17]),
or to algorithmically create (z, y) pairs, where y is an automatically generated misspelling of z (see
e.g., [ECM18]).

1. See his excellent tutorial at http://norvig.com/spell-correct.html.


29.3.3.2 HMM model 


The baseline model can work well, but has room for improvement. In particular, many errors will
be hard to correct without context. For example, suppose the user typed “advice”: did they mean
“advice” or “advise”? It depends on whether they intended to use a noun or a verb, which is hard
to tell without looking at the sequence of words. To do this, we will “upgrade” our model to an
HMM. We just have to replace our independence prior $p(z_{1:T}) = \prod_t p(z_t)$ by a standard first-order
language model on words, $p(z_{1:T}) = \prod_t p(z_t|z_{t-1})$. The parameters of this model can be estimated
by counting bigrams in a large corpus of “clean” text (see Section 2.6.3.1). The observation model
$p(y_t|z_t)$ can remain unchanged.

Given this model, we can compute the top N most likely hidden sequences in $O(NTK^2)$ time,
where K is the number of hidden states, and T is the length of the sequence, as explained in
Section 8.2.7.5. In a naive implementation, the number of hidden states K is the number of words in
the vocabulary, which would make the method very slow. However, we can exploit sparsity of the
likelihood function (i.e., the fact that $p(y|z)$ is 0 for most values of z) to generate small candidate
lists of hidden states for each location in the sequence. This gives us a sparse belief state vector $\alpha_t$.
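The sketch below shows the top-1 case (N = 1) of this idea: Viterbi decoding where each position only considers a short candidate list, so each step costs (number of previous candidates) × (number of current candidates) rather than $K^2$. The bigram table, the floor of 1e-6 for unseen bigrams, the channel model, and the start token `<s>` are all hypothetical choices for illustration. For brevity it stores whole best paths instead of back-pointers, which is fine for short queries.

```python
import math


def viterbi_correct(observed, candidates, bigram, channel, start="<s>", floor=1e-6):
    """Top-1 Viterbi decoding of z_{1:T} under a bigram LM p(z_t | z_{t-1}) and a channel
    model p(y_t | z_t), restricted to a small candidate list at each position.
    bigram maps (prev, word) -> probability; unseen pairs get the hypothetical `floor`."""
    # best[z] = (log prob of the best path ending in z, that path)
    best = {start: (0.0, [])}
    for y, cands in zip(observed, candidates):
        new_best = {}
        for z in cands:
            score, path = max(
                ((lp + math.log(bigram.get((prev, z), floor)), p) for prev, (lp, p) in best.items()),
                key=lambda item: item[0],
            )
            new_best[z] = (score + math.log(channel(y, z)), path + [z])
        best = new_best
    return max(best.values(), key=lambda item: item[0])[1]


def channel(y, z):
    """Hypothetical channel model: typing the intended word is much more likely than a confusion."""
    return 0.9 if y == z else 0.05


# Hypothetical bigram probabilities: the verb "advise" fits the context "i ... you" better.
bigram = {("<s>", "i"): 0.5, ("i", "advise"): 0.01, ("i", "advice"): 0.0001,
          ("advise", "you"): 0.3, ("advice", "you"): 0.001}
print(viterbi_correct(["i", "advice", "you"], [["i"], ["advice", "advise"], ["you"]], bigram, channel))
# ['i', 'advise', 'you']
```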


29.3.3.3 Extended HMM model 


We can extend the HMM model to handle higher level errors, in addition to misspellings of individual
words. In particular, [LDZ11; LDZ12] proposed modeling the following kinds of errors:

• Two words merged into one, e.g., “home page” → “homepage”.
• One word split into two, e.g., “illinoisstate” → “illinois state”.
• Within-word errors, such as substitution, transposition, insertion and deletion of letters, as we
  discussed in Section 29.3.3.1.


We can model this with an HMM, where we augment the state space with a silent state that
does not emit any symbols. Figure 29.11 illustrates how this model can “denoise” the observed query
“goverment home page of illinoisstate” into the correctly formulated query “government homepage of
illinois state”.

An alternative to using HMMs is to use supervised learning to fit a sequence-to-sequence translation 
model, using RNNs or transformers. This can work very well, but often needs much more training 
data, which can be problematic for low-resource languages [ECM18]. 


2. See this blogpost by Mohit Nayyar for details: https://codeascraft.com/2017/05/01/modeling-spelling-correction-for-search-at-etsy/.


Figure 29.11: Illustration of an HMM applied to spelling correction. The top row, labeled “query”, represents
the search query $y_{1:T}$ typed by the user, namely “goverment home page of illinoisstate”. The bottom row,
labeled “state path”, represents the most probable assignment to the hidden states, $z_{1:T}$, namely “government
homepage of illinois state”. (The NULL state is a silent state, that is needed to handle the generation of two
tokens from a single hidden state.) The middle row, labeled “emission”, represents the words emitted by each
state, which match the observed data. From Figure 1 of [LDZ11].


29.4 HMMs: parameter learning 


In this section, we discuss how to compute a point estimate or the full posterior over the model
parameters of an HMM given a set of partially observed sequences.


29.4.1 The Baum-Welch (EM) algorithm 


In this section, we discuss how to compute an approximate MLE for the parameters of an HMM
using the EM algorithm, which is an iterative bound optimization algorithm (see Section 6.6.3 for
details). When applied to HMMs, the resulting method is known as the Baum-Welch algorithm
[Bau+70].


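29.4.1.1 Log likelihood


The joint probability of a single sequence is given by

$$p(y_{1:T}, z_{1:T}|\theta) = \left[p(z_1|\pi) \prod_{t=2}^T p(z_t|z_{t-1}, A)\right] \left[\prod_{t=1}^T p(y_t|z_t, B)\right] \qquad (29.23)$$

$$= \left[\prod_{k=1}^K \pi_k^{\mathbb{I}(z_1 = k)}\right] \left[\prod_{t=2}^T \prod_{j=1}^K \prod_{k=1}^K A_{jk}^{\mathbb{I}(z_{t-1} = j, z_t = k)}\right] \left[\prod_{t=1}^T \prod_{k=1}^K p(y_t|B_k)^{\mathbb{I}(z_t = k)}\right] \qquad (29.24)$$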
Figure 29.12: HMM with plate notation. A are the parameters for the state transition matrix $p(z_t|z_{t-1})$ and
B are the parameters for the discrete observation model $p(y_t|z_t)$. $T_n$ is the length of the n’th sequence.


where θ = (π, A, B). Of course, we cannot compute this objective, since $z_{1:T}$ is hidden. So instead
we will optimize the expected complete data log likelihood, where expectations are taken using the
parameters from the previous iteration of the algorithm:

$$Q(\theta, \theta^{\text{old}}) = \mathbb{E}_{p(z_{1:T}|y_{1:T}, \theta^{\text{old}})}\left[\log p(y_{1:T}, z_{1:T}|\theta)\right] \qquad (29.25)$$

This can be easily summed over N sequences. See Figure 29.12 for the graphical model.

The above objective is a lower bound on the observed data log likelihood, $\log p(y_{1:T}|\theta)$, so the
entire procedure is a bound optimization method that is guaranteed to converge to a local optimum.
(In fact, in the case of HMMs, it can be shown to converge to (close to) one of the global optima
[YBW15].)


29.4.1.2 E step 


Let $A_{jk} = p(z_t = k|z_{t-1} = j)$ be the K × K transition matrix. For the first time slice, let
$\pi_k = p(z_1 = k)$ be the initial state distribution. Let $B_k$ represent the parameters of the observation
model for state k.

To compute the expected sufficient statistics, we first run the forwards-backwards algorithm on
each sequence (see Section 8.2.4). This returns the following node and edge marginals:

$$\gamma_{n,t}(j) = p(z_t = j|y_{n,1:T_n}, \theta^{\text{old}}) \qquad (29.26)$$

$$\xi_{n,t}(j, k) = p(z_{t-1} = j, z_t = k|y_{n,1:T_n}, \theta^{\text{old}}) \qquad (29.27)$$
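A minimal sketch of these marginals for a single sequence, using the scaled (normalized) forwards-backwards recursions so that the numbers do not underflow on long sequences. It takes the per-step likelihoods $p(y_t|z_t = k)$ as a precomputed table, so it works for any observation model; `xi[t]` holds the edge marginal for the pair (t-1, t) using 0-based time indices. The sum of the logs of the per-step normalizers also gives $\log p(y_{1:T}|\theta)$. The toy numbers are made up.

```python
import math


def forward_backward(init, A, lik):
    """Scaled forwards-backwards for one sequence.
    init[k] = pi_k, A[j][k] = p(z_t = k | z_{t-1} = j), lik[t][k] = p(y_t | z_t = k).
    Returns gamma[t][k], xi[t][j][k] for the pair (t-1, t) (xi[0] is None), and log p(y_{1:T})."""
    T, K = len(lik), len(init)
    alpha, scale = [], []
    for t in range(T):
        if t == 0:
            a = [init[k] * lik[0][k] for k in range(K)]
        else:
            a = [sum(alpha[t - 1][j] * A[j][k] for j in range(K)) * lik[t][k] for k in range(K)]
        c = sum(a)  # c_t = p(y_t | y_{1:t-1})
        scale.append(c)
        alpha.append([v / c for v in a])  # p(z_t | y_{1:t})
    beta = [[1.0] * K for _ in range(T)]
    for t in range(T - 2, -1, -1):
        beta[t] = [sum(A[j][k] * lik[t + 1][k] * beta[t + 1][k] for k in range(K)) / scale[t + 1]
                   for j in range(K)]
    gamma = [[alpha[t][k] * beta[t][k] for k in range(K)] for t in range(T)]
    xi = [None] + [[[alpha[t - 1][j] * A[j][k] * lik[t][k] * beta[t][k] / scale[t]
                     for k in range(K)] for j in range(K)] for t in range(1, T)]
    return gamma, xi, sum(math.log(c) for c in scale)


# Toy 2-state example with made-up numbers.
gamma, xi, ll = forward_backward([0.6, 0.4], [[0.7, 0.3], [0.2, 0.8]],
                                 [[0.5, 0.1], [0.4, 0.3], [0.1, 0.6]])
print(gamma, ll)
```

Two useful consistency checks: each row of `gamma` sums to one, and summing `xi[t][j]` over k recovers `gamma[t - 1][j]`.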


where $T_n$ is the length of sequence n. We can then derive the expected counts as follows (note that we
pool the sufficient statistics across time, since the parameters are tied, as well as across sequences):

$$\mathbb{E}[N_k^1] = \sum_{n=1}^N \gamma_{n,1}(k), \quad \mathbb{E}[N_k] = \sum_{n=1}^N \sum_{t=1}^{T_n} \gamma_{n,t}(k), \quad \mathbb{E}[N_{jk}] = \sum_{n=1}^N \sum_{t=2}^{T_n} \xi_{n,t}(j, k) \qquad (29.28)$$
Given the above quantities, we can compute the expected complete data log likelihood as follows:

$$Q(\theta, \theta^{\text{old}}) = \sum_{k=1}^K \mathbb{E}[N_k^1] \log \pi_k + \sum_{j=1}^K \sum_{k=1}^K \mathbb{E}[N_{jk}] \log A_{jk} + \sum_{n=1}^N \sum_{t=1}^{T_n} \sum_{k=1}^K p(z_t = k|y_{n,1:T_n}, \theta^{\text{old}}) \log p(y_{n,t}|B_k) \qquad (29.29)$$


29.4.1.3 M step 


We can estimate the transition matrix and initial state probabilities by maximizing the objective
subject to the sum to one constraint. The result is just a normalized version of the expected counts:

$$\hat{A}_{jk} = \frac{\mathbb{E}[N_{jk}]}{\sum_{k'} \mathbb{E}[N_{jk'}]}, \quad \hat{\pi}_k = \frac{\mathbb{E}[N_k^1]}{N} \qquad (29.30)$$

This result is quite intuitive: we simply add up the expected number of transitions from j to k, and
divide by the expected number of times we transition from j to anything else.
For a categorical observation model, the expected sufficient statistics are

$$\mathbb{E}[M_{kv}] = \sum_{n=1}^N \sum_{t=1}^{T_n} \gamma_{n,t}(k)\, \mathbb{I}(y_{n,t} = v) = \sum_{n=1}^N \sum_{t: y_{n,t} = v} \gamma_{n,t}(k) \qquad (29.31)$$


The M step has the form

$$\hat{B}_{kv} = \frac{\mathbb{E}[M_{kv}]}{\mathbb{E}[N_k]} \qquad (29.32)$$

This result is quite intuitive: we simply add up the expected number of times we are in state k and
we see a symbol v, and divide by the expected number of times we are in state k. See Algorithm 11
for the pseudocode.

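For a Gaussian observation model, the expected sufficient statistics are given by

$$\mathbb{E}[\bar{y}_k] = \sum_{n=1}^N \sum_{t=1}^{T_n} \gamma_{n,t}(k)\, y_{n,t}, \quad \mathbb{E}[(yy^\top)_k] = \sum_{n=1}^N \sum_{t=1}^{T_n} \gamma_{n,t}(k)\, y_{n,t} y_{n,t}^\top \qquad (29.33)$$

The M step becomes

$$\hat{\mu}_k = \frac{\mathbb{E}[\bar{y}_k]}{\mathbb{E}[N_k]} \qquad (29.34)$$

$$\hat{\Sigma}_k = \frac{\mathbb{E}[(yy^\top)_k] - \mathbb{E}[N_k]\, \hat{\mu}_k \hat{\mu}_k^\top}{\mathbb{E}[N_k]} \qquad (29.35)$$

In practice, we often need to add a log prior to these estimates to ensure the resulting $\hat{\Sigma}_k$ estimate is
well-conditioned. See [Mur22, Sec 4.5.2] for details.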
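A minimal sketch of Equations (29.33)-(29.35) for one state k, given observations pooled across sequences and time together with their responsibilities $\gamma_{n,t}(k)$. Adding `reg` times the identity to $\hat{\Sigma}_k$ is our crude, hypothetical stand-in for the log prior mentioned above; it is not the MAP estimator described in [Mur22].

```python
def gaussian_m_step(ys, gammas, reg=1e-6):
    """M step for one state k of a Gaussian HMM.
    ys: observations y_{n,t} pooled over sequences and time (each a list of d floats);
    gammas: the matching responsibilities gamma_{n,t}(k).
    reg * I is a crude, hypothetical stand-in for the log prior that keeps Sigma_k well-conditioned."""
    d = len(ys[0])
    Nk = sum(gammas)                                                      # E[N_k]
    ybar = [sum(g * y[i] for g, y in zip(gammas, ys)) for i in range(d)]  # E[ybar_k]
    yyT = [[sum(g * y[i] * y[j] for g, y in zip(gammas, ys)) for j in range(d)]
           for i in range(d)]                                             # E[(y y^T)_k]
    mu = [v / Nk for v in ybar]                                           # Eq. (29.34)
    Sigma = [[(yyT[i][j] - Nk * mu[i] * mu[j]) / Nk + (reg if i == j else 0.0)
              for j in range(d)] for i in range(d)]                       # Eq. (29.35)
    return mu, Sigma


# Four 2d points with made-up responsibilities for this state.
mu, Sigma = gaussian_m_step([[0.0, 0.0], [2.0, 0.0], [0.0, 2.0], [2.0, 2.0]], [0.9, 0.1, 0.5, 0.5])
print(mu, Sigma)
```

With all responsibilities equal, the result reduces to the ordinary sample mean and (biased) sample covariance, and rescaling every responsibility by the same constant leaves the estimates unchanged.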
Algorithm 42: Baum-Welch algorithm for (discrete observation) HMMs

Initialize parameters θ
for each iteration do
    // E step
    Initialize expected counts: E[N_k^1] = 0, E[N_jk] = 0, E[M_kv] = 0
    for each datacase n do
        Use forwards-backwards algorithm on y_n to compute γ_{n,t} and ξ_{n,t} (Equations 29.26-29.27)
        E[N_k^1] := E[N_k^1] + γ_{n,1}(k)
        E[N_jk] := E[N_jk] + Σ_{t=2}^{T_n} ξ_{n,t}(j, k)
        E[M_kv] := E[M_kv] + Σ_{t: y_{n,t} = v} γ_{n,t}(k)
    // M step
    Compute new parameters θ = (A, B, π) using Equations 29.30 and 29.32
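For concreteness, here is a minimal runnable sketch of Algorithm 42 in plain Python. It uses scaled forwards-backwards for the E step and Equations (29.30) and (29.32) for the M step; the random initialization, the fixed iteration count, and the toy integer-coded sequences are our own assumptions. Since EM is a bound optimization method, the recorded log likelihood should never decrease from one iteration to the next, which makes a handy debugging check.

```python
import math
import random


def forward_backward(pi, A, B, obs):
    """Scaled forwards-backwards for one integer-coded sequence; returns gamma, xi and log p(obs)."""
    T, K = len(obs), len(pi)
    alpha, scale = [], []
    for t, y in enumerate(obs):
        if t == 0:
            a = [pi[k] * B[k][y] for k in range(K)]
        else:
            a = [sum(alpha[-1][j] * A[j][k] for j in range(K)) * B[k][y] for k in range(K)]
        scale.append(sum(a))
        alpha.append([v / scale[-1] for v in a])
    beta = [[1.0] * K for _ in range(T)]
    for t in range(T - 2, -1, -1):
        y = obs[t + 1]
        beta[t] = [sum(A[j][k] * B[k][y] * beta[t + 1][k] for k in range(K)) / scale[t + 1]
                   for j in range(K)]
    gamma = [[alpha[t][k] * beta[t][k] for k in range(K)] for t in range(T)]
    xi = [None] + [[[alpha[t - 1][j] * A[j][k] * B[k][obs[t]] * beta[t][k] / scale[t]
                     for k in range(K)] for j in range(K)] for t in range(1, T)]
    return gamma, xi, sum(math.log(c) for c in scale)


def normalize(v):
    s = sum(v)
    return [x / s for x in v]


def baum_welch(seqs, K, V, n_iter=50, seed=0):
    """EM for a discrete-observation HMM, following Algorithm 42.
    seqs: list of sequences of ints in range(V). Returns (pi, A, B, log-likelihood history)."""
    rng = random.Random(seed)
    # Random initialization; the +0.5 keeps every entry away from 0.
    pi = normalize([rng.random() + 0.5 for _ in range(K)])
    A = [normalize([rng.random() + 0.5 for _ in range(K)]) for _ in range(K)]
    B = [normalize([rng.random() + 0.5 for _ in range(V)]) for _ in range(K)]
    history = []
    for _ in range(n_iter):
        # E step: expected counts E[N_k^1], E[N_jk], E[M_kv]
        N1 = [0.0] * K
        Njk = [[0.0] * K for _ in range(K)]
        Mkv = [[0.0] * V for _ in range(K)]
        total_ll = 0.0
        for obs in seqs:
            gamma, xi, ll = forward_backward(pi, A, B, obs)
            total_ll += ll
            for k in range(K):
                N1[k] += gamma[0][k]
                for t, y in enumerate(obs):
                    Mkv[k][y] += gamma[t][k]
            for t in range(1, len(obs)):
                for j in range(K):
                    for k in range(K):
                        Njk[j][k] += xi[t][j][k]
        history.append(total_ll)  # log likelihood of the parameters used in this E step
        # M step: normalized expected counts, Equations (29.30) and (29.32)
        pi = normalize(N1)
        A = [normalize(row) for row in Njk]
        B = [normalize(row) for row in Mkv]  # each row sums to E[N_k]
    return pi, A, B, history


# Toy integer-coded sequences (made up) over V = 3 symbols.
seqs = [[0, 0, 1, 0, 2, 2, 2, 1, 2, 2], [1, 0, 0, 0, 2, 2, 1, 2, 2, 0]]
pi, A, B, history = baum_welch(seqs, K=2, V=3, n_iter=30)
print(history[0], history[-1])
```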


29.4.1.4 Initialization 


As usual with EM, we must take care to ensure that we initialize the parameters carefully, to minimize 
the chance of getting stuck in poor local optima. There are several ways to do this, such as 


• Use some fully labeled data to initialize the parameters.
• Initially ignore the Markov dependencies, and estimate the observation parameters using the
  standard mixture model estimation methods, such as K-means or EM.
• Randomly initialize the parameters, use multiple restarts, and pick the best solution.

Techniques such as deterministic annealing [UN98; RR01a] can help mitigate the effect of local
minima. Also, just as K-means is often used to initialize EM for GMMs, so it is common to initialize 
EM for HMMs using Viterbi training. The Viterbi algorithm is explained in Section 8.2.7, but 
basically it is an algorithm to compute the single most probable path. As an approximation to 
the E step, we can replace the sum over paths with the statistics computed using this single path. 
Sometimes this can give better results [AG11]. 


29.4.1.5 Example: casino HMM 


In this section, we fit the casino HMM from Section 8.2.1. The true generative model is shown in 
Figure 29.13a. We used this to generate 4 sequences of length 5000, totalling 20,000 observations. 
We initialized the model with random parameters. We ran EM for 200 iterations and got the results 
in Figure 29.13b. We see that the learned parameters are close to the true parameters, modulo label 
switching of the states, due to unidentifiability. 


29.4.2 Parameter estimation using SGD 


Although the EM algorithm is the “traditional” way to fit HMMs, it is inherently a batch algorithm,
so it does not scale well to large datasets (with many sequences). Although it is possible to extend
bound optimization to the online case (see e.g., [Mai15]), this can take a lot of memory.

Figure 29.13: Illustration of the casino HMM. (a) True parameters used to generate the data. (b) Estimated
parameters using EM. (c) Estimated parameters using SGD. Generated by casino_hmm_training.ipynb.

Figure 29.14: Average negative log likelihood per learning step for the casino HMM. (a) EM. (b) SGD with
minibatch size 1. (c) Full batch gradient descent. Generated by casino_hmm_training.ipynb.

A simple alternative is to optimize $\log p(y_{1:T}|\theta)$ using SGD. We can compute this objective using
the forwards algorithm, as shown in Equation (8.9):


$$\log p(y_{1:T}|\theta) = \sum_{t=1}^T \log p(y_t|y_{1:t-1}, \theta) = \sum_{t=1}^T \log Z_t \qquad (29.36)$$

where the normalization constant for each time step is given by

$$Z_t \triangleq p(y_t|y_{1:t-1}) = \sum_{j=1}^K p(z_t = j|y_{1:t-1})\, p(y_t|z_t = j) \qquad (29.37)$$
j=1 


Of course, we need to ensure the transition matrix remains a valid row stochastic matrix, i.e., that
$0 \le A_{ij} \le 1$ and $\sum_j A_{ij} = 1$. Similarly, if we have categorical observations, we need to ensure $B_{jk}$
is a valid row stochastic matrix, and if we have Gaussian observations, we need to ensure $\Sigma_k$ is a
valid psd matrix. These constraints are automatically taken care of in EM. When using SGD, we can
reparameterize to an unconstrained form, as proposed in [BC94].
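One common unconstrained form is a row-wise softmax over real-valued logits; we use it here as an illustration, and it is not necessarily the specific reparameterization of [BC94]. The sketch maps arbitrary logits to valid π, A, B and evaluates Equations (29.36)-(29.37) with the forwards algorithm. To actually run SGD one would differentiate this objective with respect to the logits, e.g., with an automatic differentiation library such as JAX; that step is omitted to keep the example dependency-free.

```python
import math


def softmax(logits):
    m = max(logits)
    e = [math.exp(x - m) for x in logits]
    s = sum(e)
    return [x / s for x in e]


def params_from_unconstrained(pi_logits, A_logits, B_logits):
    """Map unconstrained real logits to a valid initial distribution and row-stochastic A and B."""
    return softmax(pi_logits), [softmax(r) for r in A_logits], [softmax(r) for r in B_logits]


def log_lik(pi, A, B, obs):
    """log p(y_{1:T}) = sum_t log Z_t, with Z_t = sum_j p(z_t = j | y_{1:t-1}) p(y_t | z_t = j)."""
    K = len(pi)
    pred = list(pi)  # p(z_1 = j)
    total = 0.0
    for y in obs:
        joint = [pred[j] * B[j][y] for j in range(K)]
        Z = sum(joint)
        total += math.log(Z)
        filt = [v / Z for v in joint]                                        # p(z_t | y_{1:t})
        pred = [sum(filt[i] * A[i][j] for i in range(K)) for j in range(K)]  # p(z_{t+1} | y_{1:t})
    return total


# Arbitrary real-valued logits always give valid parameters.
pi, A, B = params_from_unconstrained([0.3, -1.2], [[2.0, -0.5], [0.1, 0.7]],
                                     [[1.0, 0.0, -1.0], [-0.3, 0.4, 2.0]])
print(log_lik(pi, A, B, [0, 2, 1]))
```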
29.4.2.1 Example: casino HMM


In this section, we use SGD to fit the casino HMM using the same data as in Section 29.4.1.5. We
show the learning curves in Figure 29.14. We see that SGD converges slightly more
slowly than EM, and is not monotonic in how it decreases the NLL loss, even in the full batch case.
However, the final parameters are similar, as shown in Figure 29.13.


29.4.3 Parameter estimation using spectral methods 


Fitting HMMs using maximum likelihood is difficult, because the log likelihood is not convex. Thus
there are many local optima, and EM and SGD can give poor results. An alternative approach is to
marginalize out the hidden variables, and work instead with predictive distributions in the visible
space. For discrete observation HMMs, with observation matrix $B_{jk} = p(y_t = k|z_t = j)$, such a
distribution has the form

$$[u_t]_k = p(y_t = k|y_{1:t-1}) \qquad (29.38)$$

This is called a predictive state representation [SJR04].

Suppose there are m possible hidden states, and n possible visible symbols, where n > m. One can
show [HKZ12; Joh12] that the PSR vectors lie in a subspace in $\mathbb{R}^n$ with a dimensionality of m < n.
Intuitively this is because the linear operator A defining the hidden state update in Equation (8.13),
combined with the mapping to observables via B, induces low rank structure in the output space.
Furthermore, we can estimate a basis for this low rank subspace using SVD applied to the observable
matrix of co-occurrence counts:

$$[P_2]_{ij} = p(y_t = i, y_{t-1} = j) \qquad (29.39)$$

We also need to estimate the third order statistics

$$[P_3]_{ijk} = p(y_t = i, y_{t-1} = j, y_{t-2} = k) \qquad (29.40)$$
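Both statistics are just normalized counts of adjacent pairs and triples. A minimal sketch, assuming integer-coded symbols and pooling counts over all positions of all sequences (which implicitly assumes the process is stationary); the subsequent SVD and the observable-operator updates are not shown.

```python
from collections import Counter


def cooccurrence_stats(seqs, n_symbols):
    """Empirical [P2]_{ij} = p(y_t = i, y_{t-1} = j) and [P3]_{ijk} = p(y_t = i, y_{t-1} = j, y_{t-2} = k),
    pooled over all positions of all sequences."""
    c2, c3 = Counter(), Counter()
    for y in seqs:
        for t in range(1, len(y)):
            c2[(y[t], y[t - 1])] += 1
        for t in range(2, len(y)):
            c3[(y[t], y[t - 1], y[t - 2])] += 1
    n2, n3 = sum(c2.values()), sum(c3.values())
    symbols = range(n_symbols)
    P2 = [[c2[(i, j)] / n2 for j in symbols] for i in symbols]
    P3 = [[[c3[(i, j, k)] / n3 for k in symbols] for j in symbols] for i in symbols]
    return P2, P3


# Made-up integer-coded sequences over 3 symbols.
P2, P3 = cooccurrence_stats([[0, 1, 2, 1, 0, 0, 2], [2, 2, 1, 0, 1]], 3)
print(P2)
```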


Using these quantities, it is possible to perform recursive updating of our predictions while working
entirely in visible space. This is called spectral estimation, or tensor decomposition [HKZ12;
AHK12; Rod14; Ana+14; RSG17].

We can use spectral methods to get a good initial estimate of the parameters for the latent variable 
model, which can then be refined using EM (see e.g., [Smi+00]). Alternatively, we can use them “as 
is”, without needing EM at all. See [Mat14] for a comparison of these methods. See also Section 29.8.2 
where we discuss spectral methods for fitting linear dynamical systems. 


29.4.4 Bayesian HMMs 


MLE methods can easily overfit, and can suffer from numerical problems, especially when sample
sizes are small. In this section, we briefly discuss some approaches to inferring the posterior over the
parameters, p(θ|D). By adopting a Bayesian approach, we can also allow the number of states to be
unbounded by using a hierarchical Dirichlet process (Section 31.2) to get an HDP-HMM [Fox+08].

There are various algorithms we can use to perform posterior inference, such as variational Bayes
EM [Bea03] or blocked Gibbs sampling (see Section 29.4.4.1), which alternates between sampling
latent sequences $z_{1:T,1:N}$ using the forwards filtering backwards sampling algorithm (Section 8.2.8)
and sampling the parameters from their full conditionals, $p(\theta|y_{1:T,1:N}, z_{1:T,1:N})$. Unfortunately, the high
correlation between z and θ makes this coordinate-wise approach rather slow.

A faster approach is to marginalize out the discrete latents (using the forwards algorithm), and
then to use MCMC [Fot+14] or SVI [Obe+19] to sample from the following log posterior:

$$\log p(\theta, D) = \log p(\theta) + \sum_{n=1}^N \log p(y_{1:T,n}|\theta) \qquad (29.41)$$

This is a form of “collapsed” inference.


29.4.4.1 Blocked Gibbs sampling for HMMs 


This section is written by Xinglong Li. 


In this section, we discuss Bayesian inference for HMMs using blocked Gibbs sampling (cf. [Sco02]).
For the observation model, we consider the first-order auto-regressive HMM model in Section 29.2.5,
so $p(y_t|y_{t-1}, z_t = j, \theta) = \mathcal{N}(y_t|E_j y_{t-1} + \mu_j, \Sigma_j)$. For a model with K hidden states, the unknown
parameters are $\theta = \{\pi, A, E_1, \ldots, E_K, \Sigma_1, \ldots, \Sigma_K\}$, where we assume (for notational simplicity)
that $\mu_j$ of each autoregressive model is known, and that we condition the observations on $y_1$.

We alternate between sampling from $p(z_{1:T}|y_{1:T}, \theta)$ using the forwards filtering backwards sampling
algorithm (Section 8.2.8), and sampling from $p(\theta|z_{1:T}, y_{1:T})$. Sampling from $p(\theta|z_{1:T}, y_{1:T})$ is easy
if we use conjugate priors. Here we use a Dirichlet prior for π and each row $A_{j,:}$ of the transition
matrix, and choose the matrix normal inverse Wishart distribution as the prior for $\{E_j, \Sigma_j\}$ of each
autoregressive model, similar to Bayesian multivariate linear regression (Section 15.2.8). In particular,
the prior distributions of θ are:


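$$\pi \sim \text{Dir}(\boldsymbol{\alpha}_\pi), \quad A_{j,:} \sim \text{Dir}(\boldsymbol{\alpha}_A) \qquad (29.42)$$

$$\Sigma_j \sim \text{IW}(\Psi_j, \nu_j), \quad E_j|\Sigma_j \sim \mathcal{MN}(M_j, \Sigma_j, V_j) \qquad (29.43)$$

where every entry of $\boldsymbol{\alpha}_\pi$ equals $\alpha_\pi/K$ and every entry of $\boldsymbol{\alpha}_A$ equals $\alpha_A/K$. The log prior probability is

$$\log p(\theta) = c + \sum_{k=1}^K \left(\frac{\alpha_\pi}{K} - 1\right) \log \pi_k + \sum_{j=1}^K \sum_{k=1}^K \left(\frac{\alpha_A}{K} - 1\right) \log A_{jk} - \sum_{j=1}^K \left(\frac{\nu_j + N_y + 1}{2} \log|\Sigma_j| + \frac{1}{2} \text{tr}(\Psi_j \Sigma_j^{-1})\right) - \sum_{j=1}^K \left(\frac{N_y}{2} \log|\Sigma_j| + \frac{1}{2} \text{tr}\left((E_j - M_j)^\top \Sigma_j^{-1} (E_j - M_j) V_j^{-1}\right)\right) \qquad (29.44)$$

where $N_y$ is the dimension of $y_t$. Given $y_{1:T}$ and $z_{1:T}$ we denote $N_j = \sum_{t=2}^T \mathbb{I}(z_t = j)$ and $N_{jk} = \sum_{t=1}^{T-1} \mathbb{I}(z_t = j, z_{t+1} = k)$. The
joint likelihood is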


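$$\log p(y_{1:T}, z_{1:T}|\theta) = c + \sum_{k=1}^K \mathbb{I}(z_1 = k) \log \pi_k + \sum_{j=1}^K \sum_{k=1}^K N_{jk} \log A_{jk} - \sum_{t=2}^T \sum_{j=1}^K \mathbb{I}(z_t = j) \left(\frac{1}{2} \log|\Sigma_j| + \frac{1}{2} (y_t - E_j y_{t-1} - \mu_j)^\top \Sigma_j^{-1} (y_t - E_j y_{t-1} - \mu_j)\right) \qquad (29.45)$$

$$= c + \sum_{k=1}^K \mathbb{I}(z_1 = k) \log \pi_k + \sum_{j=1}^K \sum_{k=1}^K N_{jk} \log A_{jk} - \sum_{j=1}^K \left(\frac{N_j}{2} \log|\Sigma_j| + \frac{1}{2} \text{tr}\left((\tilde{Y}_j - E_j \bar{Y}_j)^\top \Sigma_j^{-1} (\tilde{Y}_j - E_j \bar{Y}_j)\right)\right) \qquad (29.46)$$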


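where $\tilde{Y}_j = [y_t - \mu_j]_{z_t = j}$ and $\bar{Y}_j = [y_{t-1}]_{z_t = j}$ (with the selected vectors stacked as columns), and it can be seen that
$\tilde{Y}_j \sim \mathcal{MN}(E_j \bar{Y}_j, \Sigma_j, I_{N_j})$.

It is obvious from $\log p(\theta) + \log p(y_{1:T}, z_{1:T}|\theta)$ that the posteriors of π and $A_{j,:}$ are still Dirichlet
distributions. It can also be shown that the posterior distributions of $\{E_j, \Sigma_j\}$ are still matrix normal
inverse Wishart distributions, whose hyperparameters can be directly obtained by replacing Y, A, X
in Equation (15.100) with $\tilde{Y}_j$, $E_j$ and $\bar{Y}_j$ respectively. To summarize, the posterior distribution
$p(\theta|z_{1:T}, y_{1:T})$ is:

$$\pi|z_{1:T} \sim \text{Dir}(\tilde{\boldsymbol{\alpha}}_\pi), \quad \tilde{\alpha}_{\pi,k} = \alpha_\pi/K + \mathbb{I}(z_1 = k) \qquad (29.47)$$

$$A_{j,:}|z_{1:T} \sim \text{Dir}(\tilde{\boldsymbol{\alpha}}_{A_j}), \quad \tilde{\alpha}_{A_j,k} = \alpha_A/K + N_{jk} \qquad (29.48)$$

$$\Sigma_j|z_{1:T}, y_{1:T} \sim \text{IW}(\tilde{\Psi}_j, \tilde{\nu}_j), \quad E_j|\Sigma_j, z_{1:T}, y_{1:T} \sim \mathcal{MN}(\tilde{M}_j, \Sigma_j, \tilde{V}_j) \qquad (29.49)$$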
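The discrete part of this Gibbs step is easy to code: count transitions in the sampled $z_{1:T}$, add them to the Dirichlet pseudo-counts as in Equations (29.47)-(29.48), and draw from the resulting Dirichlets by normalizing independent Gamma variates. A minimal sketch follows; the matrix normal inverse Wishart update for $\{E_j, \Sigma_j\}$ is not shown, and the hyperparameter values in the usage line are arbitrary.

```python
import random


def sample_dirichlet(alpha, rng):
    """Draw from Dir(alpha) by normalizing independent Gamma(alpha_k, 1) variates."""
    g = [rng.gammavariate(a, 1.0) for a in alpha]
    s = sum(g)
    return [x / s for x in g]


def sample_pi_A(z, K, alpha_pi, alpha_A, rng):
    """One Gibbs update of (pi, A) given a sampled state sequence z (ints in range(K)),
    using the posterior pseudo-counts of Equations (29.47)-(29.48)."""
    post_pi = [alpha_pi / K + (1.0 if z[0] == k else 0.0) for k in range(K)]
    N = [[0] * K for _ in range(K)]
    for prev, nxt in zip(z[:-1], z[1:]):
        N[prev][nxt] += 1  # N_jk = number of j -> k transitions
    post_A = [[alpha_A / K + N[j][k] for k in range(K)] for j in range(K)]
    pi = sample_dirichlet(post_pi, rng)
    A = [sample_dirichlet(row, rng) for row in post_A]
    return pi, A, post_pi, post_A


rng = random.Random(0)
pi, A, post_pi, post_A = sample_pi_A([0, 0, 1, 1, 1, 0], K=2, alpha_pi=2.0, alpha_A=2.0, rng=rng)
print(post_pi, post_A)
```

With very small concentration values, Gamma draws can underflow to zero in floating point, so a production sampler would typically work with log-Gamma variates instead.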


29.5 HMMs: Generalizations 


In this section, we discuss various extensions of the vanilla HMM introduced in Section 29.2. 


29.5.1 Hidden semi-Markov model (HSMM) 
In a standard HMM (Section 29.2), the probability we remain in state i for exactly d steps is

$$p(d_i = d) = (1 - A_{ii}) A_{ii}^d \propto \exp(d \log A_{ii}) \qquad (29.50)$$

where $A_{ii}$ is the self-loop probability. This is called the geometric distribution. However, this
kind of exponentially decaying function of d is sometimes unrealistic.

A simple way to model non-geometric waiting times is to replace each state with n new states,
each with the same emission probabilities as the original state. For example, consider the model in
Figure 29.15(a). Obviously the smallest sequence this can generate is of length n = 4. Any path
of length d through the model has probability $p^{d-n}(1-p)^n$; multiplying by the number of possible
paths we find that the total probability of a path of length d is

$$p(d) = \binom{d-1}{n-1} p^{d-n} (1-p)^n \qquad (29.51)$$
Figure 29.15: (a) A Markov chain with n = 4 repeated states and self loops. (b) The resulting distribution
over sequence lengths, for p = 0.99 and various n. Generated by hmm_self_loop_dist.ipynb.


This is equivalent to the negative binomial distribution. By adjusting n and the self-loop probabilities 
p of each state, we can model a wide range of waiting times: see Figure 29.15(b). 
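We can check Equation (29.51) numerically. Each of the n chained states is occupied for a geometric number of steps, $p^{\ell-1}(1-p)$ for $\ell \ge 1$, so the total duration is the n-fold convolution of these geometric distributions; the sketch compares that convolution with the closed form. The values n = 4 and p = 0.9 in the usage lines are arbitrary.

```python
import math


def chain_duration_pmf(d, n, p):
    """Probability that a chain of n copies of a state, each with self-loop probability p,
    is occupied for exactly d steps: C(d-1, n-1) p^(d-n) (1-p)^n, as in Equation (29.51)."""
    if d < n:
        return 0.0
    return math.comb(d - 1, n - 1) * p ** (d - n) * (1 - p) ** n


def chain_duration_by_convolution(n, p, d_max):
    """The same distribution for d = 0..d_max, built by convolving n geometric laws
    P(l) = p^(l-1) (1-p), l >= 1 (one per copy of the state)."""
    geom = [0.0] + [p ** (l - 1) * (1 - p) for l in range(1, d_max + 1)]
    dist = [1.0] + [0.0] * d_max  # before entering the chain, total duration 0 with probability 1
    for _ in range(n):
        new = [0.0] * (d_max + 1)
        for total, mass in enumerate(dist):
            if mass:
                for l in range(1, d_max + 1 - total):
                    new[total + l] += mass * geom[l]
        dist = new
    return dist


conv = chain_duration_by_convolution(4, 0.9, 60)
print(max(abs(conv[d] - chain_duration_pmf(d, 4, 0.9)) for d in range(61)))
```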

A more general solution is to use a semi-Markov model, in which the next state not only depends
on the previous state, but also on how long we’ve been in that state. When the state-space is
not observed directly, the result is called a hidden semi-Markov model (HSMM), a variable
duration HMM, or an explicit duration HMM [Yu10].

One way to represent an HSMM is to use the graphical model shown in Figure 29.16(a). The
$d_t \in \{1, \ldots, D\}$ node is a state duration counter, where D is the maximum duration of any state.
When we first enter state j, we sample $d_t$ from the duration distribution for that state, $d_t \sim D_j(\cdot)$.
Thereafter, $d_t$ deterministically counts down until $d_t = 1$. More precisely, we define the following
CPD:

$$p(d_t = d'|d_{t-1} = d, z_t = j) = \begin{cases} D_j(d') & \text{if } d = 1 \\ 1 & \text{if } d' = d - 1 \text{ and } d > 1 \\ 0 & \text{otherwise} \end{cases} \qquad (29.52)$$

Note that $D_j(d)$ could be represented as a table (a non-parametric approach) or as some kind
of parametric distribution, such as a Gamma distribution. If $D_j(d)$ is a (truncated) geometric
distribution, this emulates a standard HMM.


While $d_t > 1$, the state $z_t$ is not allowed to change. When $d_t = 1$, we make a stochastic transition
to a new state. (We assume $A_{jj} = 0$.) More precisely, we define the state CPD as follows:

$$p(z_t = k|z_{t-1} = j, d_{t-1} = d) = \begin{cases} 1 & \text{if } d > 1 \text{ and } j = k \\ A_{jk} & \text{if } d = 1 \\ 0 & \text{otherwise} \end{cases} \qquad (29.53)$$

This ensures that the model stays in the same state for the entire duration of the segment. At each
step within this segment, an observation is generated.
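A minimal sketch of ancestral sampling from this explicit-duration model, following Equations (29.52)-(29.53): draw a duration on entering a state, count it down deterministically, and take a (non-self) transition with $A_{jk}$ only when the counter reaches 1. The duration tables, transition matrix, and initial distribution in the usage lines are made-up values.

```python
import random


def sample_hsmm(T, init, A, duration_pmfs, rng):
    """Ancestral sampling of (z_{1:T}, d_{1:T}) from an explicit-duration HMM.
    duration_pmfs[j][d - 1] = D_j(d) for d = 1..D; A should have a zero diagonal."""
    def draw(probs):
        return rng.choices(range(len(probs)), weights=probs)[0]

    z, d = [], []
    state = draw(init)
    counter = draw(duration_pmfs[state]) + 1      # d_1 ~ D_{z_1}(.)
    for t in range(T):
        if t > 0:
            if counter > 1:
                counter -= 1                      # Eq. (29.52): deterministic countdown
            else:
                state = draw(A[state])            # Eq. (29.53): transition only when d_{t-1} = 1
                counter = draw(duration_pmfs[state]) + 1
        z.append(state)
        d.append(counter)
    return z, d


# Made-up 2-state example with maximum duration D = 3.
rng = random.Random(1)
z, d = sample_hsmm(15, [0.5, 0.5], [[0.0, 1.0], [1.0, 0.0]], [[0.2, 0.3, 0.5], [0.6, 0.4, 0.0]], rng)
print(z)
print(d)
```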


HSMMs are useful not only because they can model the duration of each state explicitly, but also
because they can model the distribution of a whole subsequence of observations at once, instead
of assuming all observations are generated independently at each time step. That is, they can use
likelihood models of the form $p(y_{t:t+l-1}|z_t = k, d_t = l)$, which generate l correlated observations if
the duration in state k is for l time steps. This approach, known as a segmental HMM, is useful
for modeling data that is piecewise linear, or shows other local trends [ODK96]. We can also use an
RNN to model each segment, resulting in an RNN-HSMM model [Dai+17].

Figure 29.16: Encoding a hidden semi-Markov model as a DPGM. (a) $d_t$ is a deterministic down counter
(duration variable). Each observation is generated independently. (b) Similar to (a), except now we generate
the observations within each segment as a block. In this figure, we represent the non-Markovian dependencies
between the observations within each segment by using undirected edges. We represent the conditional
independence between the observations across different segments by disconnecting $y_{1:3}$ from $y_{4:5}$; this can be
enforced by “breaking the link” whenever $d_t = 1$ (representing the end of a segment).

More precisely, we can define a segmental HMM as follows: 


T 
ply, z, d) = |p(21)p(dı|z1) | | Pllzt-1, de—1) (del ze, di—1) | p(ylz, d) (29.54) 


t=2 


In a standard HSMM, we assume 


plyl|z, d) = [In (yelzt) (29.55) 


so the duration variables only determine the hidden state dynamics. To define $p(\mathbf{y} \mid \mathbf{z}, \mathbf{d})$ for a segmental HMM, let us use $s_i$ and $e_i$ to denote the start and end times of segment $i$. This sequence can be computed deterministically from $\mathbf{d}$ using $s_1 = 1$, $s_i = s_{i-1} + d_{s_{i-1}}$, and $e_i = s_i + d_{s_i} - 1$. We now define the observation model as follows:

$$p(\mathbf{y} \mid \mathbf{z}, \mathbf{d}) = \prod_{i} p(y_{s_i:e_i} \mid z_{s_i}, d_{s_i}) \tag{29.56}$$


See Figure 29.16(b) for the DPGM. 
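The boundary recursion for $s_i$ and $e_i$ can be checked with a few lines of Python. This is a minimal sketch, assuming the duration variables are stored in a plain list `d` where `d[t - 1]` holds $d_t$, and that each segment's duration is read at its start time (the counter then counts down to 1). The helper name `segment_boundaries` is hypothetical.

```python
def segment_boundaries(d):
    """Return the 1-indexed (s_i, e_i) pairs of all segments.

    d[t - 1] holds the duration variable d_t; only its value at the
    start of each segment is read (it then counts down to 1).
    """
    T = len(d)
    segments = []
    s = 1                              # s_1 = 1
    while s <= T:
        e = s + d[s - 1] - 1           # e_i = s_i + d_{s_i} - 1
        segments.append((s, e))
        s = s + d[s - 1]               # s_{i+1} = s_i + d_{s_i}
    return segments


# The down counters of Figure 29.16(b): d = (3, 2, 1, 2, 1)
print(segment_boundaries([3, 2, 1, 2, 1]))  # [(1, 3), (4, 5)]
```

The two segments recovered here are exactly the blocks $y_{1:3}$ and $y_{4:5}$ that the DPGM in Figure 29.16(b) disconnects.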
If we use an RNN for each segment, we have

$$p(y_{s_i:e_i} \mid z_{s_i}, d_{s_i}) = \prod_{t=s_i}^{e_i} p(y_t \mid y_{s_i:t-1}, z_{s_i}) = \prod_{t=s_i}^{e_i} p(y_t \mid h_t, z_{s_i}) \tag{29.57}$$
Figure 29.17: An example of an HHMM for an ASR system which can recognize 3 words. The top level represents bigram word probabilities. The middle level represents the phonetic spelling of each word. The bottom level represents the subphones of each phone. (It is traditional to represent a phone as a 3 state HMM, representing the beginning, middle and end; these are known as subphones.) Adapted from Figure 7.5 of [JM00].


where $h_t$ is the hidden state of the RNN, which is deterministically updated given the previous observations in this segment.

As shown in [Chi14], it is possible to compute $p(z_t, d_t \mid y_{1:T})$ in $O(TK^2 + TKD)$ time, where $T$ is the sequence length, $K$ is the number of states, and $D$ is the maximum duration of any segment. In [Dai+17], they show how to train an approximate inference algorithm, based on a mean field approximation $q(\mathbf{z}, \mathbf{d} \mid \mathbf{y}) = \prod_t q_z(z_t \mid \mathbf{y})\, q_d(d_t \mid \mathbf{y})$, to compute the posterior in $O(TK + TD)$ time.


29.5.2 Hierarchical HMMs

A hierarchical HMM (HHMM) [FST98] is an extension of the HMM that is designed to model domains with hierarchical structure. Figure 29.17 gives an example of an HHMM used in automatic speech recognition, where words are composed of phones which are composed of subphones. We can always "flatten" an HHMM to a regular HMM, but a factored representation is often easier to interpret, and allows for more efficient inference and model fitting.

HHMMs have been used in many application domains, e.g., speech recognition [Bil01], gene finding [Hu+00], plan recognition [BVW02], monitoring transportation patterns [Lia+07], indoor robot localization [TMK04], etc. HHMMs are less expressive than stochastic context free grammars (SCFGs) since they only allow hierarchies of bounded depth, but they support more efficient inference. In particular, inference in SCFGs (using the inside-outside algorithm [JM08]) takes $O(T^3)$ time, whereas inference in an HHMM takes $O(T)$ time [MP01; WM12].


We can represent an HHMM as a directed graphical model as shown in Figure 29.18. $Z_t^\ell$ represents the state at time $t$ and level $\ell$. A state transition at level $\ell$ is only "allowed" if the chain at the level below has "finished", as determined by the "finished" node $F$ of the level below. (The chain below finishes when it chooses to enter its end state.) This mechanism ensures that higher level chains evolve more slowly than lower level chains, i.e., lower levels are nested within higher levels.
Figure 29.18: An HHMM represented as a DPGM. $Z_t^\ell$ is the state at time $t$, level $\ell$; $F_t^\ell = 1$ if the HMM at level $\ell$ has finished (entered its exit state), otherwise $F_t^\ell = 0$. Shaded nodes are observed; the remaining nodes are hidden. We may optionally clamp $F_T^\ell = 1$, where $T$ is the length of the observation sequence, to ensure all models have finished by the end of the sequence. From Figure 2 of [MP01].


A variable duration HMM can be thought of as a special case of an HHMM, where the top level is a deterministic counter, and the bottom level is a regular HMM, which can only change states once the counter has "timed out". See [MP01] for further details.


29.5.3 Factorial HMMs 


An HMM represents the hidden state using a single discrete random variable $z_t \in \{1, \ldots, K\}$. To represent 10 bits of information would require $K = 2^{10} = 1024$ states. By contrast, consider a distributed representation of the hidden state, where each $z_{tm} \in \{0, 1\}$ represents the $m$'th bit of the $t$'th hidden state. Now we can represent 10 bits using just 10 binary variables. This model is called a factorial HMM [GJ97].

More precisely, the model is defined as follows:

$$p(\mathbf{z}, \mathbf{y}) = \prod_{t=1}^{T} \left[ \prod_{m=1}^{M} p(z_{tm} \mid z_{t-1,m}) \right] p(\mathbf{y}_t \mid \mathbf{z}_t) \tag{29.58}$$

where $p(z_{tm} = k \mid z_{t-1,m} = j) = A_{mjk}$ is an entry in the transition matrix for chain $m$, $p(z_{1m} = k \mid z_{0m}) = p(z_{1m} = k) = \pi_{mk}$ is the initial state distribution for chain $m$, and

$$p(\mathbf{y}_t \mid \mathbf{z}_t) = \mathcal{N}\Big(\mathbf{y}_t \,\Big|\, \sum_{m=1}^{M} \mathbf{W}_m \tilde{\mathbf{z}}_{tm}, \boldsymbol{\Sigma}\Big) \tag{29.59}$$

is the observation model, where $\tilde{\mathbf{z}}_{tm}$ is a 1-of-K encoding of $z_{tm}$ and $\mathbf{W}_m$ is a $D \times K$ matrix (assuming $\mathbf{y}_t \in \mathbb{R}^D$). Figure 29.19a illustrates the model for the case where $M = 3$.

Figure 29.19: (a) A factorial HMM with 3 chains. (b) A coupled HMM with 3 chains.

An interesting application of FHMMs is to the problem of energy disaggregation [KJ12a]. In this problem, we observe the total energy usage of a house at each moment in time, i.e., the observation model has the form

$$p(y_t \mid \mathbf{z}_t) = \mathcal{N}\Big(y_t \,\Big|\, \sum_{m=1}^{M} w_m z_{tm}, \sigma^2\Big) \tag{29.60}$$

where $w_m$ is the amount of energy used by device $m$, and $z_{tm} = 1$ if device $m$ is being used at time $t$ and $z_{tm} = 0$ otherwise. The transition model is assumed to be

$$p(z_{tm} = 1 \mid z_{t-1,m}) = \begin{cases} A_{01} & \text{if } z_{t-1,m} = 0 \\ A_{11} & \text{if } z_{t-1,m} = 1 \end{cases} \tag{29.61}$$

We do not know which devices are turned on at each time step (i.e., the $z_{tm}$ are hidden), but by applying inference in the FHMM over time, we can separate the total energy into its parts, and thereby determine which devices are using the most electricity.


Unfortunately, conditioned on $y_t$, all the hidden variables are correlated (due to explaining away the common observed child $y_t$). This makes exact state estimation intractable. However, we can derive efficient approximate inference algorithms, as we discuss in Supplementary Section 10.3.2.
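For a handful of devices, the posterior over $\mathbf{z}_t$ at a single time step can be computed by brute-force enumeration of all $2^M$ on/off configurations, which also makes the exponential cost of exact inference concrete. The sketch below assumes independent Bernoulli priors per device (standing in for the one-step predictive that the transition model (29.61) would supply) and made-up wattages; `disaggregate_one_step` is a hypothetical helper, not part of any library.

```python
import itertools

import numpy as np


def disaggregate_one_step(y_t, w, sigma2, prior_on):
    """Exact posterior over all 2^M on/off configurations of M devices at one step.

    Assumes y_t ~ N(sum_m w[m] * z[m], sigma2) and independent Bernoulli(prior_on[m])
    priors with 0 < prior_on[m] < 1. Returns (configs, posterior, marginals).
    """
    w = np.asarray(w, dtype=float)
    prior_on = np.asarray(prior_on, dtype=float)
    # All 2^M binary configurations, one per row: shape (2^M, M)
    configs = np.array(list(itertools.product([0, 1], repeat=len(w))))
    total = configs @ w                                    # predicted total load per configuration
    log_lik = -0.5 * (y_t - total) ** 2 / sigma2
    log_prior = configs @ np.log(prior_on) + (1 - configs) @ np.log(1 - prior_on)
    log_post = log_lik + log_prior
    post = np.exp(log_post - log_post.max())               # subtract the max for numerical stability
    post /= post.sum()
    marginals = post @ configs                             # p(z_tm = 1 | y_t) for each device m
    return configs, post, marginals


# Three hypothetical devices drawing 100 W, 1000 W and 2000 W; observed total 1100 W
configs, post, p_on = disaggregate_one_step(1100.0, [100.0, 1000.0, 2000.0], 25.0, [0.5, 0.5, 0.5])
print(np.round(p_on, 3))  # devices 1 and 2 on, device 3 off
```

With two identical 1000 W devices and a reading of 1000 W, the posterior instead splits evenly between "device 1 on" and "device 2 on" while ruling out "both on": this is the explaining-away effect that couples the chains.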


29.5.4 Coupled HMMs

If we have multiple related data streams, we can use a coupled HMM [Bra96]. This is a series of HMMs where the state transitions depend on the states of neighboring chains. That is, we represent the conditional distribution for each time slice as

$$p(\mathbf{z}_t, \mathbf{y}_t \mid \mathbf{z}_{t-1}) = \prod_{m=1}^{M} p(y_{tm} \mid z_{tm})\, p(z_{tm} \mid z_{t-1,m-1:m+1}) \tag{29.62}$$
Figure 29.20: The BATnet DBN. The transient nodes are only shown for the second slice, to minimize clutter. The dotted lines are used to group related variables. Used with kind permission of Daphne Koller.


with boundary conditions defined in the obvious way. See Figure 29.19(b) for an illustration with $M = 3$ chains.

Coupled HMMs have been used for various tasks, such as audio-visual speech recognition [Nef+02], modeling freeway traffic flows [KM00], and modeling conversational interactions between people [Bas+01].

However, there are two drawbacks to this model. First, exact inference takes $O(T(K^M)^2)$ time, as in a factorial HMM; however, in practice this is not usually a problem, since $M$ is often small. Second, the model requires $O(MK^4)$ parameters to specify, if there are $M$ chains with $K$ states per chain, because each state depends on its own past plus the past of its two neighbors. There is a closely related model, known as the influence model [Asa00], which uses fewer parameters, by computing a convex combination of pairwise transition matrices.
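To make the parameter count concrete, the sketch below counts the entries in a coupled HMM's transition CPDs, and builds one chain's next-state distribution as a convex combination of pairwise transition matrices, in the spirit of the influence model. The function names, the weight matrix `alpha`, and the indexing convention `A[j][m]` (the matrix describing the influence of chain `j` on chain `m`) are illustrative assumptions.

```python
import numpy as np


def coupled_hmm_num_cpd_entries(M, K):
    """Total number of entries in the transition CPDs of a coupled HMM with M chains
    of K states, where chain m depends on chains m-1, m, m+1 at the previous step."""
    total = 0
    for m in range(M):
        num_parents = 1 + (m > 0) + (m < M - 1)   # end chains have only one neighbor
        total += K ** num_parents * K             # one K-vector per parent configuration
    return total


def influence_transition(prev_states, A, alpha, m):
    """Next-state distribution of chain m as a convex combination of pairwise
    transition matrices: sum_j alpha[m, j] * A[j][m][prev_states[j], :].

    A[j][m] is a row-stochastic K x K matrix describing the influence of chain j
    on chain m, and each row alpha[m, :] is non-negative and sums to 1.
    """
    return sum(alpha[m, j] * A[j][m][s_j] for j, s_j in enumerate(prev_states))


# Interior chains contribute K^4 entries each, hence O(M K^4) overall
print(coupled_hmm_num_cpd_entries(M=10, K=3))  # 702
```

Because each term in `influence_transition` is a probability vector and the weights form a convex combination, the result is again a valid distribution, while only pairwise $K \times K$ matrices need to be stored.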


29.5.5 Dynamic Bayes nets (DBN) 


A dynamic Bayesian network (DBN) is a way to represent a stochastic process using a directed graphical model [Mur02]. (Note that the network itself is not dynamic, since its structure and parameters are fixed; rather, it is a network representation of a dynamical system.) A DBN can be considered as a natural generalization of an HMM.

An example is shown in Figure 29.20, which is a DBN designed to monitor the state of a simulated autonomous car known as the "Bayesian Automated Taxi", or "BATmobile" [For+95]. To define the model, you just need to specify the structure of the first time-slice, the structure between two time-slices, and the form of the CPDs. For details, see [KF09a].


29.5.6 Changepoint detection 


In this section, we discuss changepoint detection, which is the task of detecting when there are "abrupt" changes in the distribution of the observed values in a time series. We focus on the online case. (For a review of offline methods for this problem, see e.g., [AC17; TOV18]. See also [BW20] for a recent empirical evaluation of various methods, focused on the 1d time series case.)

Figure 29.21: Illustration of Bayesian online changepoint detection (BOCPD). (a) Hypothetical segmentation of a univariate time series divided by changepoints on the mean into three segments of lengths $g_1 = 4$, $g_2 = 6$, and an undetermined length for $g_3$ (since the third segment has not yet ended). From Figure 1 of [AM07]. Used with kind permission of Ryan Adams.

The methods we discuss can (in principle) be used for high-dimensional time series segmentation. Our starting point is the hidden semi-Markov model (HSMM) discussed in Section 29.5.1. This is like an HMM in which we explicitly model the duration spent in each state. This is done by augmenting the latent state $z_t$ with a duration variable $d_t$ which is initialized according to a duration distribution, $d_t \sim \mathcal{D}_{z_t}(\cdot)$, and which then counts down to 1. An alternative approach is to add a variable $r_t \in \{0, 1, \ldots\}$ which encodes the run length for the current state; this starts at 0 whenever a new segment is created, and then counts up by one at each step. The transition dynamics are specified by

$$p(r_t \mid r_{t-1}) = \begin{cases} H(r_{t-1} + 1) & \text{if } r_t = 0 \\ 1 - H(r_{t-1} + 1) & \text{if } r_t = r_{t-1} + 1 \\ 0 & \text{otherwise} \end{cases} \tag{29.63}$$


where $H(r)$ is a hazard function:

$$H(r) = \frac{p_g(r)}{\sum_{t=r}^{\infty} p_g(t)} \tag{29.64}$$

where $p_g(t)$ is the probability of a segment of length $t$. See Figure 29.21 for an illustration. If we set $p_g$ to be a geometric distribution with mean $\lambda$ (i.e., success probability $1/\lambda$), then the hazard function is the constant $H(r) = 1/\lambda$.
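As a numerical check of (29.64), the sketch below evaluates the hazard for a geometric duration distribution, truncating the infinite sum after a large number of terms (an assumption that is harmless here because the geometric tail is negligible). The helper `hazard` and the choice $\lambda = 20$ are hypothetical.

```python
def hazard(p_g, r, horizon=10_000):
    """H(r) = p_g(r) / sum_{t >= r} p_g(t), with the infinite sum truncated
    after `horizon` terms (adequate when p_g has a light tail)."""
    survival = sum(p_g(t) for t in range(r, r + horizon))
    return p_g(r) / survival


lam = 20.0                                      # mean segment length
p = 1.0 / lam                                   # geometric success probability


def geometric(t):
    """p_g(t) = (1 - p)^(t - 1) * p for t = 1, 2, ..."""
    return (1.0 - p) ** (t - 1) * p


print([round(hazard(geometric, r), 6) for r in (1, 5, 50)])  # [0.05, 0.05, 0.05]
```

For a distribution that is not memoryless the hazard varies with $r$; for example, with durations uniform on $\{1, \ldots, 10\}$, `hazard` rises from 0.1 at $r = 1$ to 1 at $r = 10$.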

The advantage of the run-length representation is that we can define the observation model for a segment in a causal way (that only depends on past data):

$$p(y_t \mid y_{1:t-1}, r_t = r, z_t = k) = p(y_t \mid y_{t-r:t-1}, z_t = k) = \int p(y_t \mid \eta)\, p(\eta \mid y_{t-r:t-1}, z_t = k)\, d\eta \tag{29.65}$$

where $\eta$ are the parameters that are "local" to this segment. This is called the underlying predictive model or UPM for the segment. The posterior over the UPM parameters is given by

$$p(\eta \mid y_{t-r:t-1}, z_t = k) \propto p(\eta \mid z_t = k) \prod_{i=t-r}^{t-1} p(y_i \mid \eta) \tag{29.66}$$

where we initialize the prior for $\eta$ using hyper-parameters chosen by state $k$. If the model is conjugate exponential, we can compute this marginal likelihood in closed form, and we have

$$\pi_t^{r,k} = p(y_t \mid y_{t-r:t-1}, z_t = k) = p(y_t \mid \psi_t^{r,k}) \tag{29.67}$$

where $\psi_t^{r,k}$ are the parameters of the posterior predictive distribution at time $t$ based on the last $r$ observations (and using a prior from state $k$).

In the special case in which we have $K = 1$ hidden state, each segment is modeled independently, and we get a product partition model [BH92]:

$$p(\mathbf{y} \mid \mathbf{r}) = p(y_{s_1:e_1}) \cdots p(y_{s_N:e_N}) \tag{29.68}$$

where $s_i$ and $e_i$ are the start and end of segment $i$, which can be computed from the run lengths $\mathbf{r}$. (We initialize with $r_0 = 0$.) Thus there is no information sharing between segments. This can be useful for time series in which there are abrupt changes, and where the new parameters are unrelated to the old ones.

Detecting the locations of these changes is called changepoint detection. An exact online algorithm for solving this task was proposed in [FL07] and independently in [AM07]; in the latter paper, they call the method Bayesian online changepoint detection or BOCPD. We can compute a posterior over the current run length recursively as follows:

$$p(r_t \mid y_{1:t}) \propto p(y_t \mid y_{1:t-1}, r_t)\, p(r_t \mid y_{1:t-1}) \tag{29.69}$$

where we initialize with $p(r_0 = 0) = 1$. The marginal likelihood $p(y_t \mid y_{1:t-1}, r_t)$ is given by Equation (29.65) (with $z_t = 1$ dropped, since there is just one state). The prior predictive is given by

$$p(r_t \mid y_{1:t-1}) = \sum_{r_{t-1}} p(r_t \mid r_{t-1})\, p(r_{t-1} \mid y_{1:t-1}) \tag{29.70}$$

The one step ahead predictive distribution is given by

$$p(y_{t+1} \mid y_{1:t}) = \sum_{r_t} p(y_{t+1} \mid y_{1:t}, r_t)\, p(r_t \mid y_{1:t}) \tag{29.71}$$


29.5.6.1 Example 


We give an example of the method in Figure 29.22 applied to a synthetic 1d dataset generated from a 4 state GMM. The likelihood is a univariate Gaussian, $p(y_t \mid \mu) = \mathcal{N}(y_t \mid \mu, \sigma^2)$, where $\sigma^2 = 1$ is fixed, and $\mu$ is inferred using a Gaussian prior. The duration distribution is geometric with rate $N/T$ (so the hazard is the constant $N/T$), where $N = 4$ is the true number of change points and $T = 200$ is the length of the sequence.
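A minimal sketch of the BOCPD recursion (29.69)-(29.70) for this kind of model, assuming a constant hazard, a Gaussian likelihood with known variance, and a conjugate Gaussian prior on the segment mean, so that the UPM predictive (29.67) is itself Gaussian. It follows the structure of the algorithm in [AM07], but it is not the code used to generate Figure 29.22; the function name, the prior hyper-parameters and the synthetic data are assumptions. For simplicity it works with probabilities rather than log probabilities, which is adequate for short sequences.

```python
import numpy as np


def bocpd_gaussian(y, hazard, mu0=0.0, var0=100.0, obs_var=1.0):
    """BOCPD for a product partition model with constant hazard, UPM y ~ N(mu, obs_var)
    and prior mu ~ N(mu0, var0). Returns R with R[t, r] = p(run length r | first t points)."""
    y = np.asarray(y, dtype=float)
    T = len(y)
    R = np.zeros((T + 1, T + 1))
    R[0, 0] = 1.0                       # p(r_0 = 0) = 1
    means = np.array([mu0])             # posterior mean of mu, one entry per run length
    variances = np.array([var0])        # posterior variance of mu, one entry per run length
    for t in range(T):
        # UPM posterior predictive N(y_t | mean, var + obs_var) for every run length
        pred_var = variances + obs_var
        pred = np.exp(-0.5 * (y[t] - means) ** 2 / pred_var) / np.sqrt(2 * np.pi * pred_var)
        joint = R[t, : t + 1] * pred
        R[t + 1, 1 : t + 2] = joint * (1.0 - hazard)   # run continues: r -> r + 1
        R[t + 1, 0] = np.sum(joint) * hazard           # changepoint: r -> 0
        R[t + 1] /= R[t + 1].sum()                     # normalize the run-length posterior
        # Conjugate update of each run's posterior; a new run starts from the prior
        post_prec = 1.0 / variances + 1.0 / obs_var
        post_means = (means / variances + y[t] / obs_var) / post_prec
        means = np.concatenate([[mu0], post_means])
        variances = np.concatenate([[var0], 1.0 / post_prec])
    return R


# Synthetic data: 50 points around 0, then 50 points around 10
rng = np.random.default_rng(0)
y = np.concatenate([rng.normal(0.0, 0.5, 50), rng.normal(10.0, 0.5, 50)])
R = bocpd_gaussian(y, hazard=1.0 / 100)
map_run_length = R.argmax(axis=1)   # most probable run length after each observation
```

On this data, the run-length posterior after the 60th observation should place most of its mass on run length 10, i.e., the recursion locates the changepoint at the jump; each step costs $O(t)$, which is the quadratic total cost discussed next.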


29.5.6.2 Extensions 


Although the above method is exact, each update step takes $O(t)$ time, so the total cost of the algorithm is $O(T^2)$. We can reduce this by pruning out states with low probability. In particular, we can use particle filtering (Section 13.2) with $N$ particles, together with a stratified optimal resampling method, to reduce the cost to $O(TN)$. See [FL07] for details.

Figure 29.22: Illustration of BOCPD. (a) Synthetic data from a GMM with 4 states. Top row is the data, bottom row is the generating state. (b) Output of algorithm. Top row: Estimated changepoint locations. Bottom row: posterior predicted probability of a changepoint at each step. Generated by changepoint_detection.ipynb.

In addition, the above method relies on a conjugate exponential model in order to compute the 
marginal likelihood, and update the posterior parameters for each r, in O(1) time. For more complex 
models, we need to use approximations. In [TBS13], they use variational Bayes (Section 10.2.3), and 
in [Mav16], they use particle filtering (Section 13.2), which is more general, but much slower. 

It is possible to extend the model in various other ways. In [FL11], they allow for Markov 
dependence between the parameters of neighboring segments. In [STR10], they use a Gaussian 
process (Chapter 18) to represent the UPM, which captures correlations between observations within 
the same segment. In [KJD18], they use generalized Bayesian inference (Section 14.1.3) to create a 
method that is more robust to model misspecification. 

In [Gol+17], they extend the model by modeling the probability of a sequence of observations, 
rather than having to make the decision about whether to insert a changepoint or not based on just 
the likelihood ratio of a single time step. 

In [AE+20], they extend the model by allowing for multiple discrete states, as in an HSMM. In addition, they add both the run length $r_t$ and the duration $d_t$ to the state space. This allows the method to specify not just when the current segment started, but also when it is expected to end. In addition, it allows the UPM to depend on the duration of the segment, and not just on past observations. For example, we can use

$$p(y_t \mid r_t, d_t, \eta) = \mathcal{N}(y_t \mid \phi(r_t/d_t)^\top \eta, \sigma^2) \tag{29.72}$$

where $0 \le r_t/d_t \le 1$, and $\phi(\cdot)$ is a set of learned basis functions. This allows observation sequences for the same hidden state to have a common functional shape, even if the time spent in each state is different.
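To illustrate how (29.72) gives segments of different lengths a common shape, the sketch below uses a fixed Gaussian RBF basis on $[0, 1]$ as a stand-in for the learned basis functions $\phi$; the centers, width and weights $\eta$ are arbitrary assumptions, and `segment_mean` is a hypothetical helper.

```python
import numpy as np


def segment_mean(r, d, eta, centers, width=0.15):
    """Mean phi(r / d)^T eta of the UPM in (29.72), using an assumed Gaussian RBF basis."""
    x = r / d                                            # relative position in the segment
    phi = np.exp(-0.5 * ((x - centers) / width) ** 2)    # basis function values phi(x)
    return phi @ eta


centers = np.linspace(0.0, 1.0, 5)                      # assumed RBF centers on [0, 1]
eta = np.array([0.0, 2.0, 3.0, 2.0, 0.0])               # assumed weights for one hidden state
short_seg = [segment_mean(r, 4, eta, centers) for r in range(5)]   # segment with d = 4
long_seg = [segment_mean(r, 8, eta, centers) for r in range(9)]    # segment with d = 8
# Both segments trace the same curve: short_seg[k] equals long_seg[2 * k]
```

Because the mean depends on $r_t$ only through the ratio $r_t/d_t$, a longer stay in the state simply samples the same curve more densely.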


29.6 Linear dynamical systems (LDS)

In this section, we discuss linear-Gaussian state-space models (LG-SSMs), also called linear dynamical systems (LDSs). This is a special case of an SSM in which the transition function and observation function are both linear, and the process noise and observation noise are both Gaussian.
29.7. LDS: APPLICATIONS


29.6.1 Conditional independence properties 


The LDS graphical model is shown in Figure 29.1(a). This encodes the assumption that the hidden 
states are Markovian, and the observations are iid conditioned on the hidden states. All that remains 
is to specify the form of the conditional probability distributions of each node. 


29.6.2 Parameterization 


An LDS model is defined as follows: 


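$$p(\mathbf{z}_t \mid \mathbf{z}_{t-1}, \mathbf{u}_t) = \mathcal{N}(\mathbf{z}_t \mid \mathbf{F}_t \mathbf{z}_{t-1} + \mathbf{B}_t \mathbf{u}_t + \mathbf{b}_t, \mathbf{Q}_t) \tag{29.73}$$
$$p(\mathbf{y}_t \mid \mathbf{z}_t, \mathbf{u}_t) = \mathcal{N}(\mathbf{y}_t \mid \mathbf{H}_t \mathbf{z}_t + \mathbf{D}_t \mathbf{u}_t + \mathbf{d}_t, \mathbf{R}_t) \tag{29.74}$$

We often assume the bias (offset) terms are zero, in which case the model simplifies to

$$p(\mathbf{z}_t \mid \mathbf{z}_{t-1}, \mathbf{u}_t) = \mathcal{N}(\mathbf{z}_t \mid \mathbf{F}_t \mathbf{z}_{t-1} + \mathbf{B}_t \mathbf{u}_t, \mathbf{Q}_t) \tag{29.75}$$
$$p(\mathbf{y}_t \mid \mathbf{z}_t, \mathbf{u}_t) = \mathcal{N}(\mathbf{y}_t \mid \mathbf{H}_t \mathbf{z}_t + \mathbf{D}_t \mathbf{u}_t, \mathbf{R}_t) \tag{29.76}$$

Furthermore, if there are no inputs, the model further simplifies to

$$p(\mathbf{z}_t \mid \mathbf{z}_{t-1}) = \mathcal{N}(\mathbf{z}_t \mid \mathbf{F}_t \mathbf{z}_{t-1}, \mathbf{Q}_t) \tag{29.77}$$
$$p(\mathbf{y}_t \mid \mathbf{z}_t) = \mathcal{N}(\mathbf{y}_t \mid \mathbf{H}_t \mathbf{z}_t, \mathbf{R}_t) \tag{29.78}$$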


We can also write this as a structural equation model (Section 4.7.2):

$$\mathbf{z}_t = \mathbf{F}_t \mathbf{z}_{t-1} + \mathbf{q}_t \tag{29.79}$$
$$\mathbf{y}_t = \mathbf{H}_t \mathbf{z}_t + \mathbf{r}_t \tag{29.80}$$

where $\mathbf{q}_t \sim \mathcal{N}(\mathbf{0}, \mathbf{Q}_t)$ is the process noise, and $\mathbf{r}_t \sim \mathcal{N}(\mathbf{0}, \mathbf{R}_t)$ is the observation noise.

Typically we assume the parameters $\boldsymbol{\theta}_t = (\mathbf{F}_t, \mathbf{H}_t, \mathbf{B}_t, \mathbf{D}_t, \mathbf{Q}_t, \mathbf{R}_t)$ are independent of time, so the model is stationary. (We discuss how to learn the parameters in Section 29.8.) Given the parameters, we discuss how to perform online posterior inference of the latent states using the Kalman filter in Section 8.3.2, and offline inference using the Kalman smoother in Section 8.3.3.


29.7 LDS: Applications 


In this section, we discuss some applications of LDS models. 


29.7.1 Object tracking and state estimation 


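Consider an object moving in $\mathbb{R}^2$. Let the state be the position and velocity of the object, $\mathbf{z}_t = (u_t, v_t, \dot{u}_t, \dot{v}_t)$. (We use $u$ and $v$ for the two coordinates, to avoid confusion with the state and observation variables.) This is a continuous time dynamical system, but if we use Euler discretization with step size $\Delta$, the dynamics become

$$\underbrace{\begin{pmatrix} u_t \\ v_t \\ \dot{u}_t \\ \dot{v}_t \end{pmatrix}}_{\mathbf{z}_t} = \underbrace{\begin{pmatrix} 1 & 0 & \Delta & 0 \\ 0 & 1 & 0 & \Delta \\ 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \end{pmatrix}}_{\mathbf{F}} \underbrace{\begin{pmatrix} u_{t-1} \\ v_{t-1} \\ \dot{u}_{t-1} \\ \dot{v}_{t-1} \end{pmatrix}}_{\mathbf{z}_{t-1}} + \mathbf{q}_t \tag{29.81}$$

where $\mathbf{q}_t \sim \mathcal{N}(\mathbf{0}, \mathbf{Q})$. We will assume that the process noise is only added to the velocity components of the state, but not to the location. (This is known as a random accelerations model.) Thus $\mathbf{Q} = \mathrm{diag}(0, 0, q, q)$.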

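Now suppose that at each discrete time point we observe the location, corrupted by Gaussian noise. Thus the observation model becomes

$$\underbrace{\begin{pmatrix} y_{1,t} \\ y_{2,t} \end{pmatrix}}_{\mathbf{y}_t} = \underbrace{\begin{pmatrix} 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \end{pmatrix}}_{\mathbf{H}} \begin{pmatrix} u_t \\ v_t \\ \dot{u}_t \\ \dot{v}_t \end{pmatrix} + \mathbf{r}_t \tag{29.82}$$

where $\mathbf{r}_t \sim \mathcal{N}(\mathbf{0}, \mathbf{R})$ is the observation noise. We see that the observation matrix $\mathbf{H}$ simply "extracts" the relevant parts of the state vector.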

Suppose we sample a trajectory and corresponding set of noisy observations from this model, $(\mathbf{z}_{1:T}, \mathbf{y}_{1:T}) \sim p(\mathbf{z}, \mathbf{y} \mid \boldsymbol{\theta})$. (We use diagonal observation noise, $\mathbf{R} = \mathrm{diag}(\sigma_1^2, \sigma_2^2)$.) The results are shown in Figure 29.23(a). We can use the Kalman filter (Section 8.3.2) to compute $p(\mathbf{z}_t \mid \mathbf{y}_{1:t}, \boldsymbol{\theta})$ for each $t$. (We initialize the filter with a vague prior, namely a zero-mean Gaussian $p(\mathbf{z}_0)$ with a very large isotropic covariance.) The results are shown in Figure 29.23(b). We see that the posterior mean (red line) is close to the ground truth, but there is considerable uncertainty (shown by the confidence ellipses). To improve results, we can use the Kalman smoother (Section 8.3.3) to compute $p(\mathbf{z}_t \mid \mathbf{y}_{1:T}, \boldsymbol{\theta})$, where we condition on all the data, past and future. The results are shown in Figure 29.23(c). Now we see that the resulting estimate is smoother, and the uncertainty is reduced. (The uncertainty is larger at the edges because there is less information in the neighbors to condition on.)
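The tracking example can be sketched in NumPy as follows. The matrices follow (29.81) and (29.82); the step size, the noise levels, the initial state and the vague prior covariance `1e4 * I` are assumed values, and `kalman_filter` is a plain textbook predict-update loop rather than the implementation behind Figure 29.23.

```python
import numpy as np


def tracking_model(delta=1.0, q=0.001, r=1.0):
    """Random-accelerations model in 2d: state (u, v, u_dot, v_dot), observe (u, v)."""
    F = np.array([[1.0, 0.0, delta, 0.0],
                  [0.0, 1.0, 0.0, delta],
                  [0.0, 0.0, 1.0, 0.0],
                  [0.0, 0.0, 0.0, 1.0]])
    H = np.array([[1.0, 0.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0, 0.0]])
    Q = np.diag([0.0, 0.0, q, q])       # process noise on the velocities only
    R = np.diag([r, r])                 # diagonal observation noise
    return F, H, Q, R


def kalman_filter(ys, F, H, Q, R, mu0, Sigma0):
    """Return filtered means E[z_t | y_{1:t}] and covariances Cov[z_t | y_{1:t}]."""
    mu, Sigma = mu0, Sigma0
    means, covs = [], []
    for y in ys:
        mu_pred = F @ mu                          # predict step: p(z_t | y_{1:t-1})
        Sigma_pred = F @ Sigma @ F.T + Q
        S = H @ Sigma_pred @ H.T + R              # covariance of the predicted observation
        K = Sigma_pred @ H.T @ np.linalg.inv(S)   # Kalman gain
        mu = mu_pred + K @ (y - H @ mu_pred)      # update step with y_t
        Sigma = (np.eye(len(mu)) - K @ H) @ Sigma_pred
        means.append(mu)
        covs.append(Sigma)
    return np.array(means), np.array(covs)


# Simulate a trajectory and noisy observations, then filter with a vague prior
rng = np.random.default_rng(0)
F, H, Q, R = tracking_model()
z = np.array([8.0, 10.0, 1.0, 0.0])
zs, ys = [], []
for t in range(15):
    z = F @ z + np.sqrt(np.diag(Q)) * rng.standard_normal(4)
    zs.append(z)
    ys.append(H @ z + np.sqrt(np.diag(R)) * rng.standard_normal(2))
zs, ys = np.array(zs), np.array(ys)
means, covs = kalman_filter(ys, F, H, Q, R, mu0=np.zeros(4), Sigma0=1e4 * np.eye(4))
```

Note that the filtered covariances do not depend on the observed values, only on the model matrices, which is why the uncertainty ellipses in Figure 29.23 can be computed before any data arrive.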


29.7.2 Online Bayesian linear regression (recursive least squares) 


In Section 15.2.1, we discuss how to compute $p(\mathbf{w} \mid \sigma^2, \mathcal{D})$ for a linear regression model in batch mode, using a Gaussian prior of the form $p(\mathbf{w}) = \mathcal{N}(\mathbf{w} \mid \boldsymbol{\mu}, \boldsymbol{\Sigma})$. In this section, we discuss how to compute this posterior online, by repeatedly performing the following update:

$$p(\mathbf{w} \mid \mathcal{D}_{1:t}) \propto p(\mathcal{D}_t \mid \mathbf{w})\, p(\mathbf{w} \mid \mathcal{D}_{1:t-1}) \tag{29.83}$$
$$\propto p(\mathcal{D}_t \mid \mathbf{w})\, p(\mathcal{D}_{t-1} \mid \mathbf{w}) \cdots p(\mathcal{D}_1 \mid \mathbf{w})\, p(\mathbf{w}) \tag{29.84}$$

where $\mathcal{D}_t = (\mathbf{x}_t, y_t)$ is the $t$'th labeled example, and $\mathcal{D}_{1:t-1}$ are the first $t-1$ examples. (For brevity, we drop the conditioning on $\sigma^2$.) We see that the previous posterior, $p(\mathbf{w} \mid \mathcal{D}_{1:t-1})$, becomes the current prior, which gets updated by $\mathcal{D}_t$ to become the new posterior, $p(\mathbf{w} \mid \mathcal{D}_{1:t})$. This is an example of sequential Bayesian updating or online Bayesian inference. In the case of linear regression, this process is known as the recursive least squares or RLS algorithm.
[Figure 29.23 panels: (a) true state and observed emissions; (b) filtered means; (c) smoothed means.]

Figure 29.23: Illustration of Kalman filtering and smoothing for a linear dynamical system. (Repeated from Figure 8.8.) (a) Observations (green circles) are generated by an object moving to the right (true location denoted by blue squares). (b) Results of online Kalman filtering. Red cross is the posterior mean, circles are 95% confidence ellipses derived from the posterior covariance. (c) Same as (b), but using offline Kalman smoothing. The MSE in the trajectory for filtering is 3.13, and for smoothing is 1.71. Generated by kf_tracking.ipynb.

[Figure 29.24 panels: (a) graphical model with weights $\mathbf{w}_{t-1}, \mathbf{w}_t$, inputs $\mathbf{x}_{t-1}, \mathbf{x}_t$ and outputs $y_{t-1}, y_t$; (b) online estimates of $w_0$ and $w_1$ compared with their batch values.]

Figure 29.24: (a) A dynamic generalization of linear regression. (b) Illustration of the recursive least squares algorithm applied to the model $p(y \mid x, \mathbf{w}) = \mathcal{N}(y \mid w_0 + w_1 x, \sigma^2)$. We plot the marginal posterior of $w_0$ and $w_1$ vs number of data points. (Error bars represent $\mathbb{E}[w_j \mid y_{1:t}, x_{1:t}] \pm \sqrt{\mathbb{V}[w_j \mid y_{1:t}, x_{1:t}]}$.) After seeing all the data, we converge to the offline (batch) Bayes solution, represented by the horizontal lines. (Shading represents the marginal posterior variance.) Generated by kf_linreg.ipynb.
We can implement this method by using a linear Gaussian state-space model (Section 29.6). The basic idea is to let the hidden state represent the regression parameters, and to let the (time-varying) observation model $\mathbf{H}_t$ represent the current feature vector $\mathbf{x}_t$.³ That is, the observation model has the form

$$p(y_t \mid \mathbf{w}_t) = \mathcal{N}(y_t \mid \mathbf{H}_t \mathbf{z}_t, R_t) = \mathcal{N}(y_t \mid \mathbf{x}_t^\top \mathbf{w}_t, \sigma^2) \tag{29.85}$$

If we assume the regression parameters do not change, the dynamics model becomes

$$p(\mathbf{w}_t \mid \mathbf{w}_{t-1}) = \mathcal{N}(\mathbf{w}_t \mid \mathbf{w}_{t-1}, \mathbf{0}) = \delta(\mathbf{w}_t - \mathbf{w}_{t-1}) \tag{29.86}$$

(If we do let the parameters change over time, we get a so-called dynamic linear model [Har90; WH97; PPC09].) See Figure 29.24a for the model, and Supplementary Section 8.1.2 for a simplification of the Kalman filter equations when applied to this special case.

We show a 1d example in Figure 29.24b. We see that online inference converges to the exact batch 
(offline) posterior in a single pass over the data. 
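The claim that a single online pass recovers the batch posterior can be checked numerically. The sketch below runs the Kalman filter specialized to a static state ($\mathbf{F} = \mathbf{I}$, $\mathbf{Q} = \mathbf{0}$, $\mathbf{H}_t = \mathbf{x}_t^\top$) and compares it with the closed-form Gaussian posterior; the function names and the synthetic data are hypothetical, and the noise variance $\sigma^2$ is assumed known.

```python
import numpy as np


def rls_posterior(X, y, mu0, Sigma0, sigma2):
    """Online Bayesian linear regression: a Kalman filter with static state w
    (F = I, Q = 0) and observation matrix H_t = x_t^T."""
    mu = np.array(mu0, dtype=float)
    Sigma = np.array(Sigma0, dtype=float)
    for x_t, y_t in zip(X, y):
        s = x_t @ Sigma @ x_t + sigma2          # predictive variance of y_t
        k = Sigma @ x_t / s                     # Kalman gain
        mu = mu + k * (y_t - x_t @ mu)          # correct by the prediction error
        Sigma = Sigma - np.outer(k, x_t @ Sigma)
    return mu, Sigma


def batch_posterior(X, y, mu0, Sigma0, sigma2):
    """Closed-form posterior N(w | mu_N, Sigma_N) computed from all the data at once."""
    prec0 = np.linalg.inv(Sigma0)
    Sigma_N = np.linalg.inv(prec0 + X.T @ X / sigma2)
    mu_N = Sigma_N @ (prec0 @ mu0 + X.T @ y / sigma2)
    return mu_N, Sigma_N


# 1d example y = w0 + w1 x + noise, as in Figure 29.24b (the data values are assumptions)
rng = np.random.default_rng(1)
x = rng.uniform(-2.0, 2.0, size=40)
X = np.column_stack([np.ones_like(x), x])
y = 1.0 + 2.0 * x + rng.normal(0.0, 0.5, size=40)
mu_online, Sigma_online = rls_posterior(X, y, np.zeros(2), 10.0 * np.eye(2), 0.25)
mu_batch, Sigma_batch = batch_posterior(X, y, np.zeros(2), 10.0 * np.eye(2), 0.25)
print(np.allclose(mu_online, mu_batch), np.allclose(Sigma_online, Sigma_batch))  # True True
```

The online version never inverts a $D \times D$ matrix: each step only needs the scalar predictive variance `s`, which is what makes RLS attractive for streaming data.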

If we approximate the Kalman gain matrix by $\mathbf{K}_t \approx \eta_t \mathbf{I}$, we recover the least mean squares or LMS algorithm, where $\eta_t$ is the learning rate. In LMS, we need to adapt the learning rate to ensure
convergence to the MLE. Furthermore, the algorithm may require multiple passes through the data 
to converge to this global optimum. By contrast, the RLS algorithm automatically performs step-size 
adaptation, and converges to the optimal posterior in a single pass over the data. 

In Section 8.9.3, we extend this approach to perform online parameter estimation for logistic 
regression, and in Section 17.5.2, we extend this approach to perform online parameter estimation 
for MLPs. 


29.7.3 Adaptive filtering

Consider an autoregressive model of order $D$:

$$y_t = w_1 y_{t-1} + \cdots + w_D y_{t-D} + \epsilon_t \tag{29.87}$$

where $\epsilon_t \sim \mathcal{N}(0, 1)$. The problem of adaptive filtering is to estimate the parameters $w_{1:D}$ given the observations $y_{1:t}$.

We can cast this as inference in an LG-SSM by defining the observation matrix to be $\mathbf{H}_t = (y_{t-1}, \ldots, y_{t-D})$ and defining the state as $\mathbf{z}_t = \mathbf{w}$. We can also allow the parameters to evolve over time, similar to Section 29.7.2.
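Adaptive filtering then reduces to the same static-state Kalman update, with $\mathbf{H}_t = (y_{t-1}, \ldots, y_{t-D})$ built from lagged observations. This is a minimal sketch, assuming a fixed (non-drifting) weight vector, unit noise variance as in (29.87), and an isotropic Gaussian prior with variance 10 on the weights; the function name and the simulated AR(2) weights are hypothetical.

```python
import numpy as np


def adaptive_filter_ar(y, D, sigma2=1.0, prior_var=10.0):
    """Online estimate of AR(D) weights w_{1:D}: the weights are a static LG-SSM state
    and H_t = (y_{t-1}, ..., y_{t-D}) is the time-varying observation matrix."""
    y = np.asarray(y, dtype=float)
    mu = np.zeros(D)                      # posterior mean of w
    Sigma = prior_var * np.eye(D)         # posterior covariance of w
    for t in range(D, len(y)):
        h = y[t - D:t][::-1]              # (y_{t-1}, ..., y_{t-D})
        s = h @ Sigma @ h + sigma2        # predictive variance of y_t
        k = Sigma @ h / s                 # Kalman gain
        mu = mu + k * (y[t] - h @ mu)
        Sigma = Sigma - np.outer(k, h @ Sigma)
    return mu, Sigma


# Simulate an AR(2) process with assumed weights (0.6, -0.3) and unit noise
rng = np.random.default_rng(2)
T = 3000
y = np.zeros(T)
for t in range(2, T):
    y[t] = 0.6 * y[t - 1] - 0.3 * y[t - 2] + rng.standard_normal()
w_hat, _ = adaptive_filter_ar(y, D=2)
```

To let the weights drift, one would add a small process noise covariance to `Sigma` before each update, turning this into the dynamic linear model mentioned in Section 29.7.2.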


29.7.4 Timeseries forecasting 


In Section 29.12, we discuss how to use LDS models to perform time series forecasting.


3. It is tempting to think we can just set the input $\mathbf{u}_t$ to the covariates $\mathbf{x}_t$. Unfortunately this will not work, since the effect of the inputs is to add an offset term to the output in a way which is independent of the hidden state (parameters). That is, since we have $\mathbf{y}_t = \mathbf{H}_t \mathbf{z}_t + \mathbf{D}_t \mathbf{u}_t + \mathbf{d}_t + \mathbf{r}_t$, if we set $\mathbf{u}_t = \mathbf{x}_t$ then the features get multiplied by the constant LDS parameter $\mathbf{D}_t$ instead of the hidden state $\mathbf{z}_t$ containing the regression weights.
29.8. LDS: PARAMETER LEARNING


29.8 LDS: parameter learning 


There are many approaches for estimating the parameters of state-space models. (In the control theory community, this is known as system identification [Lju87].) In the case of linear dynamical systems, many of the methods are similar to techniques used to fit HMMs, discussed in Section 29.4. For example, we can use EM, SGD, spectral methods, or Bayesian methods, as we discuss below.


29.8.1 EM for LDS 


This section is coauthored with Xinglong Li. 


If we only observe the output sequence, we can compute ML or MAP estimates of the parameters using EM. The method is conceptually quite similar to the Baum-Welch algorithm for HMMs (Section 29.4.1), except we use Kalman smoothing (Section 8.3.3) instead of forwards-backwards in the E step, and use different calculations in the M step. The details can be found in [SS82; GH96a]. Here we extend these results to consider the case where the LDS may have an optional input sequence $\mathbf{u}_{1:T}$.

Our goal is to maximize the expected complete data log likelihood

$$\mathcal{Q}(\boldsymbol{\theta}, \boldsymbol{\theta}^{k-1}) = \mathbb{E}\left[\log p(\mathbf{z}_{1:T}, \mathbf{y}_{1:T} \mid \boldsymbol{\theta}) \mid \mathbf{y}_{1:T}, \mathbf{u}_{1:T}, \boldsymbol{\theta}^{k-1}\right] \tag{29.88}$$

where the expectations are taken wrt the parameters at the previous iteration $k-1$ of the algorithm. (We initialize with random parameters, or by first fitting a simpler model, such as factor analysis.) For brevity of notation, we assume that the bias terms are included in $\mathbf{D}$ and $\mathbf{B}$ (i.e., the last entry of $\mathbf{u}_t$ is 1).

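The log joint is given by the following:

$$\log p(\mathbf{z}_{1:T}, \mathbf{y}_{1:T} \mid \boldsymbol{\theta}) = -\frac{1}{2} \sum_{t=1}^{T} (\mathbf{y}_t - \mathbf{H}\mathbf{z}_t - \mathbf{D}\mathbf{u}_t)^\top \mathbf{R}^{-1} (\mathbf{y}_t - \mathbf{H}\mathbf{z}_t - \mathbf{D}\mathbf{u}_t) - \frac{T}{2} \log |\mathbf{R}| \tag{29.89}$$
$$\quad - \frac{1}{2} \sum_{t=2}^{T} (\mathbf{z}_t - \mathbf{F}\mathbf{z}_{t-1} - \mathbf{B}\mathbf{u}_t)^\top \mathbf{Q}^{-1} (\mathbf{z}_t - \mathbf{F}\mathbf{z}_{t-1} - \mathbf{B}\mathbf{u}_t) - \frac{T-1}{2} \log |\mathbf{Q}| \tag{29.90}$$
$$\quad - \frac{1}{2} (\mathbf{z}_1 - \mathbf{m}_1)^\top \mathbf{V}_1^{-1} (\mathbf{z}_1 - \mathbf{m}_1) - \frac{1}{2} \log |\mathbf{V}_1| + \text{const} \tag{29.91}$$

where the prior on the initial state is $p(\mathbf{z}_1) = \mathcal{N}(\mathbf{z}_1 \mid \mathbf{m}_1, \mathbf{V}_1)$.

In the E step, we can run the Kalman smoother to compute $\boldsymbol{\mu}_{t|T} = \mathbb{E}[\mathbf{z}_t \mid \mathbf{y}_{1:T}]$ and $\boldsymbol{\Sigma}_{t|T} = \mathrm{Cov}[\mathbf{z}_t \mid \mathbf{y}_{1:T}]$, from which we can compute $\hat{\mathbf{z}}_t = \boldsymbol{\mu}_{t|T}$ and

$$\mathbf{P}_t = \mathbb{E}[\mathbf{z}_t \mathbf{z}_t^\top \mid \mathbf{y}_{1:T}] = \boldsymbol{\Sigma}_{t|T} + \boldsymbol{\mu}_{t|T} \boldsymbol{\mu}_{t|T}^\top \tag{29.92}$$

We also need to compute the cross term

$$\mathbf{P}_{t,t-1} = \mathbb{E}[\mathbf{z}_t \mathbf{z}_{t-1}^\top \mid \mathbf{y}_{1:T}] = \boldsymbol{\Sigma}_{t,t-1|T} + \boldsymbol{\mu}_{t|T} \boldsymbol{\mu}_{t-1|T}^\top \tag{29.93}$$

where

$$\boldsymbol{\Sigma}_{t-1,t-2|T} = \boldsymbol{\Sigma}_{t-1|t-1} \mathbf{G}_{t-2}^\top + \mathbf{G}_{t-1} (\boldsymbol{\Sigma}_{t,t-1|T} - \mathbf{F} \boldsymbol{\Sigma}_{t-1|t-1}) \mathbf{G}_{t-2}^\top \tag{29.94}$$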
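where $\mathbf{G}_t$ is the backwards Kalman gain matrix defined in Equation (8.111). We initialize this using $\boldsymbol{\Sigma}_{T,T-1|T} = (\mathbf{I} - \mathbf{K}_T \mathbf{H}) \mathbf{F} \boldsymbol{\Sigma}_{T-1|T-1}$, where $\mathbf{K}_T$ is the forwards Kalman gain matrix defined in Equation (8.72).

We can derive the M step as follows, using standard matrix calculus. We denote $\mathbf{A}_{\text{out}} = [\mathbf{H}, \mathbf{D}]$, $\mathbf{A}_{\text{dyn}} = [\mathbf{F}, \mathbf{B}]$, $\hat{\mathbf{z}}_{\text{out},t} = [\hat{\mathbf{z}}_t^\top, \mathbf{u}_t^\top]^\top$, $\hat{\mathbf{z}}_{\text{dyn},t} = [\hat{\mathbf{z}}_{t-1}^\top, \mathbf{u}_t^\top]^\top$, and

$$\mathbf{P}_{\text{out},t} = \begin{pmatrix} \mathbf{P}_t & \hat{\mathbf{z}}_t \mathbf{u}_t^\top \\ \mathbf{u}_t \hat{\mathbf{z}}_t^\top & \mathbf{u}_t \mathbf{u}_t^\top \end{pmatrix}, \quad \mathbf{P}_{\text{dyn},t} = \begin{pmatrix} \mathbf{P}_{t-1} & \hat{\mathbf{z}}_{t-1} \mathbf{u}_t^\top \\ \mathbf{u}_t \hat{\mathbf{z}}_{t-1}^\top & \mathbf{u}_t \mathbf{u}_t^\top \end{pmatrix} \tag{29.95}$$

• Output matrices:

$$\frac{\partial \mathcal{Q}}{\partial \mathbf{A}_{\text{out}}} = \sum_{t=1}^{T} \mathbf{R}^{-1} \mathbf{y}_t \hat{\mathbf{z}}_{\text{out},t}^\top - \sum_{t=1}^{T} \mathbf{R}^{-1} \mathbf{A}_{\text{out}} \mathbf{P}_{\text{out},t} = \mathbf{0} \tag{29.96}$$
$$\mathbf{A}_{\text{out}}^{\text{new}} = \left(\sum_{t=1}^{T} \mathbf{y}_t \hat{\mathbf{z}}_{\text{out},t}^\top\right) \left(\sum_{t=1}^{T} \mathbf{P}_{\text{out},t}\right)^{-1} \tag{29.97}$$

• Output noise covariance:

$$\frac{\partial \mathcal{Q}(\mathbf{A}_{\text{out}} = \mathbf{A}_{\text{out}}^{\text{new}})}{\partial \mathbf{R}^{-1}} = \frac{T}{2} \mathbf{R} - \frac{1}{2} \sum_{t=1}^{T} \left( \mathbf{y}_t \mathbf{y}_t^\top - 2 \mathbf{A}_{\text{out}}^{\text{new}} \hat{\mathbf{z}}_{\text{out},t} \mathbf{y}_t^\top + \mathbf{A}_{\text{out}}^{\text{new}} \mathbf{P}_{\text{out},t} \mathbf{A}_{\text{out}}^{\text{new}\top} \right) = \mathbf{0} \tag{29.98}$$
$$\mathbf{R}^{\text{new}} = \frac{1}{T} \sum_{t=1}^{T} \left( \mathbf{y}_t \mathbf{y}_t^\top - \mathbf{A}_{\text{out}}^{\text{new}} \hat{\mathbf{z}}_{\text{out},t} \mathbf{y}_t^\top \right) \tag{29.99}$$

• System dynamics matrices:

$$\frac{\partial \mathcal{Q}}{\partial \mathbf{A}_{\text{dyn}}} = \sum_{t=2}^{T} \mathbf{Q}^{-1} \left[\mathbf{P}_{t,t-1}, \hat{\mathbf{z}}_t \mathbf{u}_t^\top\right] - \sum_{t=2}^{T} \mathbf{Q}^{-1} \mathbf{A}_{\text{dyn}} \mathbf{P}_{\text{dyn},t} = \mathbf{0} \tag{29.100}$$
$$\mathbf{A}_{\text{dyn}}^{\text{new}} = \left(\sum_{t=2}^{T} \left[\mathbf{P}_{t,t-1}, \hat{\mathbf{z}}_t \mathbf{u}_t^\top\right]\right) \left(\sum_{t=2}^{T} \mathbf{P}_{\text{dyn},t}\right)^{-1} \tag{29.101}$$

• State noise covariance:

$$\frac{\partial \mathcal{Q}(\mathbf{A}_{\text{dyn}} = \mathbf{A}_{\text{dyn}}^{\text{new}})}{\partial \mathbf{Q}^{-1}} = \frac{T-1}{2} \mathbf{Q} - \frac{1}{2} \sum_{t=2}^{T} \left( \mathbf{P}_t - 2 \mathbf{A}_{\text{dyn}}^{\text{new}} \left[\mathbf{P}_{t,t-1}, \hat{\mathbf{z}}_t \mathbf{u}_t^\top\right]^\top + \mathbf{A}_{\text{dyn}}^{\text{new}} \mathbf{P}_{\text{dyn},t} \mathbf{A}_{\text{dyn}}^{\text{new}\top} \right) \tag{29.102}$$
$$= \frac{T-1}{2} \mathbf{Q} - \frac{1}{2} \sum_{t=2}^{T} \left( \mathbf{P}_t - \mathbf{A}_{\text{dyn}}^{\text{new}} \left[\mathbf{P}_{t,t-1}, \hat{\mathbf{z}}_t \mathbf{u}_t^\top\right]^\top \right) = \mathbf{0} \tag{29.103}$$
$$\mathbf{Q}^{\text{new}} = \frac{1}{T-1} \sum_{t=2}^{T} \left( \mathbf{P}_t - \mathbf{A}_{\text{dyn}}^{\text{new}} \left[\mathbf{P}_{t,t-1}, \hat{\mathbf{z}}_t \mathbf{u}_t^\top\right]^\top \right) \tag{29.104}$$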


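• Initial state mean:

$$\frac{\partial \mathcal{Q}}{\partial \mathbf{m}_1} = \mathbf{V}_1^{-1} (\hat{\mathbf{z}}_1 - \mathbf{m}_1) = \mathbf{0} \tag{29.105}$$
$$\mathbf{m}_1^{\text{new}} = \hat{\mathbf{z}}_1 \tag{29.106}$$

• Initial state covariance:

$$\frac{\partial \mathcal{Q}}{\partial \mathbf{V}_1^{-1}} = \frac{1}{2} \mathbf{V}_1 - \frac{1}{2} \left( \mathbf{P}_1 - \hat{\mathbf{z}}_1 \mathbf{m}_1^\top - \mathbf{m}_1 \hat{\mathbf{z}}_1^\top + \mathbf{m}_1 \mathbf{m}_1^\top \right) = \mathbf{0} \tag{29.107}$$
$$\mathbf{V}_1^{\text{new}} = \mathbf{P}_1 - \hat{\mathbf{z}}_1 \hat{\mathbf{z}}_1^\top \tag{29.108}$$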
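As a sketch of how the output updates (29.97) and (29.99) look in code, the function below takes the smoothed moments $\hat{\mathbf{z}}_t$ and $\mathbf{P}_t$ (which would come from a Kalman smoother in the E step) together with the inputs, and returns $\mathbf{A}_{\text{out}}^{\text{new}} = [\mathbf{H}, \mathbf{D}]$ and $\mathbf{R}^{\text{new}}$. The function name and array layout are assumptions; the dynamics and initial-state updates follow the same pattern.

```python
import numpy as np


def m_step_output(ys, z_hat, P, us):
    """Update A_out = [H, D] via (29.97) and R via (29.99).

    ys[t], z_hat[t], P[t], us[t] hold y_t, E[z_t | y_{1:T}], E[z_t z_t^T | y_{1:T}] and u_t.
    """
    ny = len(ys[0])
    nx = len(z_hat[0]) + len(us[0])
    sum_y_x = np.zeros((ny, nx))        # sum_t y_t zhat_out,t^T
    sum_P_out = np.zeros((nx, nx))      # sum_t P_out,t
    for y, zh, Pt, u in zip(ys, z_hat, P, us):
        x_out = np.concatenate([zh, u])                        # zhat_out,t = [zhat_t; u_t]
        P_out = np.block([[Pt, np.outer(zh, u)],
                          [np.outer(u, zh), np.outer(u, u)]])  # P_out,t from (29.95)
        sum_y_x += np.outer(y, x_out)
        sum_P_out += P_out
    # A_out = sum_y_x @ inv(sum_P_out), computed with a linear solve
    A_out = np.linalg.solve(sum_P_out.T, sum_y_x.T).T
    R = np.zeros((ny, ny))
    for y, zh, u in zip(ys, z_hat, us):
        x_out = np.concatenate([zh, u])
        R += np.outer(y, y) - A_out @ np.outer(x_out, y)
    return A_out, R / len(ys)
```

As a sanity check, if the states were known exactly ($\mathbf{P}_t = \mathbf{z}_t\mathbf{z}_t^\top$) and the outputs noise-free, this recovers $[\mathbf{H}, \mathbf{D}]$ exactly and returns $\mathbf{R}^{\text{new}} = \mathbf{0}$.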


Note that computing these expected sufficient statistics in the inner loop of EM takes O(T) 
time, which can be expensive for long sequences. In [Mar10b], a faster method, known as ASOS 
(approximate second order statistics), is proposed. In this approach, various statistics are precomputed 
in a single pass over the sequence, and from then on, all iterations take constant time (independent 
of T). Alternatively, if we have multiple processors, we can perform Kalman smoothing in O(log T) 
time using parallel scan operations (Section 8.3.3.3). 


29.8.2 Subspace identification methods 


EM does not always give satisfactory results, because it is sensitive to the initial parameter estimates. One way to avoid this is to use a different approach known as subspace identification (SSID) [OM96; Kat05].

To understand this approach, let us initially assume there is no observation noise and no system noise. In this case, we have $\mathbf{z}_t = \mathbf{F}\mathbf{z}_{t-1}$ and $\mathbf{y}_t = \mathbf{H}\mathbf{z}_t$, and hence $\mathbf{y}_t = \mathbf{H}\mathbf{F}^{t-1}\mathbf{z}_1$. Consequently all the observations must be generated from a $\dim(\mathbf{z}_t)$-dimensional linear manifold or subspace. We can identify this subspace using PCA. Once we have an estimate of the $\mathbf{z}_t$'s, we can fit the model as if it were fully observed. We can either use these estimates in their own right, or use them to initialize EM. Several papers (e.g., [Smi+00; BK15]) have shown that initializing EM this way gives much better results than initializing EM at random, or just using SSID without EM.
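The noise-free argument can be illustrated directly: stacking the observations into a matrix gives rank $\dim(\mathbf{z}_t)$, and projecting onto the top right singular vectors yields state estimates from which $\mathbf{F}$ can be refit by least squares. This is a toy sketch under assumed dimensions and parameters, not a full SSID algorithm (practical methods work with block Hankel matrices of past and future outputs). The recovered states, and hence $\hat{\mathbf{F}}$, are only identified up to an invertible linear transformation, so we compare similarity invariants such as the trace and determinant.

```python
import numpy as np

rng = np.random.default_rng(0)
L, N, T = 2, 5, 50                                # latent dim, observed dim, length (assumed)
theta = 0.1
F = 0.99 * np.array([[np.cos(theta), -np.sin(theta)],
                     [np.sin(theta), np.cos(theta)]])   # damped rotation
H = rng.standard_normal((N, L))

# Noise-free rollout: z_{t+1} = F z_t, y_t = H z_t
z = np.array([1.0, 0.0])
Y = np.zeros((T, N))
for t in range(T):
    Y[t] = H @ z
    z = F @ z

# The T x N data matrix has rank L, so PCA/SVD finds the L-dimensional subspace
print(np.linalg.matrix_rank(Y))  # 2

# Estimated states: coordinates of y_t in the top-L right singular vectors
_, _, Vt = np.linalg.svd(Y, full_matrices=False)
Z_hat = Y @ Vt[:L].T

# Fit the dynamics as if the states were observed: Z_hat[1:] ~= Z_hat[:-1] @ A, so F_hat = A^T
A, *_ = np.linalg.lstsq(Z_hat[:-1], Z_hat[1:], rcond=None)
F_hat = A.T

# F_hat is similar to F, so similarity invariants agree
print(np.trace(F_hat), np.trace(F))
print(np.linalg.det(F_hat), np.linalg.det(F))
```

With noisy data the trailing singular values are no longer zero, and choosing $\dim(\mathbf{z}_t)$ becomes a judgment about where the spectrum drops off.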

Although the theory only works for noise-free data, we can try to estimate the system noise covariance $\mathbf{Q}$ from the residuals in predicting $\mathbf{z}_t$ from $\mathbf{z}_{t-1}$, and to estimate the observation noise covariance $\mathbf{R}$ from the residuals in predicting $\mathbf{y}_t$ from $\mathbf{z}_t$. We can either use these estimates in their own right, or use them to initialize EM. Because this method relies on taking an SVD, it is called a spectral estimation method. Similar methods can also be used for HMMs (see Section 29.4.3).


29.8.3 Ensuring stability of the dynamical system 


When estimating the dynamics matrix F, it is very useful to impose a constraint on its eigenvalues. 
To see why this is important, consider the case of no system noise. In this case, the hidden state at 
time t is given by 


$$\mathbf{z}_t = \mathbf{F}^t \mathbf{z}_0 = \mathbf{U} \boldsymbol{\Lambda}^t \mathbf{U}^{-1} \mathbf{z}_0 \tag{29.109}$$

where $\mathbf{U}$ is the matrix of eigenvectors for $\mathbf{F}$, and $\boldsymbol{\Lambda} = \mathrm{diag}(\lambda_i)$ contains the eigenvalues. If any $|\lambda_i| > 1$, then for large $t$, $\mathbf{z}_t$ will blow up in magnitude. Consequently, to ensure stability, it is useful to require that all the eigenvalues are less than 1 in magnitude [SBG07]. Of course, if all the eigenvalues are less than 1 in magnitude, then $\mathbb{E}[\mathbf{z}_t] \to \mathbf{0}$ as $t$ grows, so the state will return to the origin. Fortunately, when we add noise, the state becomes non-zero, so the model does not degenerate.
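A small numerical illustration of this argument, using assumed diagonal dynamics matrices: a spectral radius above 1 makes the noise-free state blow up, one below 1 makes it decay to zero, and adding process noise with covariance $\mathbf{Q}$ keeps the state covariance at a non-zero fixed point of $\boldsymbol{\Sigma} = \mathbf{F}\boldsymbol{\Sigma}\mathbf{F}^\top + \mathbf{Q}$. The helper names are hypothetical.

```python
import numpy as np


def spectral_radius(F):
    """Largest eigenvalue magnitude of F."""
    return np.max(np.abs(np.linalg.eigvals(F)))


def noise_free_norm(F, z0, t):
    """||F^t z0||, the size of the noise-free state after t steps."""
    return np.linalg.norm(np.linalg.matrix_power(F, t) @ z0)


def stationary_cov(F, Q, iters=2000):
    """Iterate Sigma <- F Sigma F^T + Q from Sigma = 0; converges if spectral_radius(F) < 1."""
    Sigma = np.zeros_like(Q)
    for _ in range(iters):
        Sigma = F @ Sigma @ F.T + Q
    return Sigma


F_unstable = np.diag([1.05, 0.5])
F_stable = np.diag([0.95, 0.5])
z0 = np.array([1.0, 1.0])
print(spectral_radius(F_unstable), noise_free_norm(F_unstable, z0, 200))  # grows like 1.05^200
print(spectral_radius(F_stable), noise_free_norm(F_stable, z0, 200))      # decays towards 0
# With process noise Q = I, the stable system keeps a non-zero stationary covariance
print(np.diag(stationary_cov(F_stable, np.eye(2))))   # approx [1/(1-0.95^2), 1/(1-0.5^2)]
```

Simply iterating the covariance recursion is the most transparent way to find the fixed point; a dedicated discrete Lyapunov solver would be preferable for large or slowly mixing systems.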


29.8.4 Bayesian LDS 


SSMs can be quite sensitive to their parameter values, which is a particular concern when they are 
used for forecasting applications (see Section 29.12.1), or when the latent states or parameters are 
interpreted for scientific purposes (see e.g., [AM-+16]). In such cases, it is wise to represent our 
uncertainty about the parameters by using Bayesian inference. 

There are various algorithms we can use to perform this task. For linear-Gaussian SSMs, it is possible to use variational Bayes EM [Bea03; BC07] (see Section 10.2.5), or blocked Gibbs sampling (see Section 29.8.4.1). Note, however, that $\boldsymbol{\theta}$ and $\mathbf{z}$ are highly correlated, so the mean field approximation can be inaccurate, and the blocked Gibbs method can mix slowly. It is also possible to use collapsed MCMC in which we marginalize out $\mathbf{z}_{1:T}$ and just work with $p(\boldsymbol{\theta} \mid \mathbf{y}_{1:T})$, which we can sample using HMC.


29.8.4.1 Blocked Gibbs sampling for LDS 
This section is written by Xinglong Li. 


In this section, we discuss blocked Gibbs sampling for LDS [CK94b; CMR05; FS07]. We alternate between sampling from $p(\mathbf{z}_{1:T} \mid \mathbf{y}_{1:T}, \boldsymbol{\theta})$ using the forwards-filter backwards-sampling algorithm (Section 8.2.8), and sampling from $p(\boldsymbol{\theta} \mid \mathbf{z}_{1:T}, \mathbf{y}_{1:T})$, which is easy to do if we use conjugate priors.

In more detail, we will consider the following linear Gaussian state space model with homogeneous 
parameters: 


$$p(z_t|z_{t-1}, u_t) = \mathcal{N}(z_t|F z_{t-1} + B u_t, Q) \tag{29.110}$$
$$p(y_t|z_t, u_t) = \mathcal{N}(y_t|H z_t + D u_t, R) \tag{29.111}$$


The set of all the parameters is $\theta = \{F, H, B, D, Q, R\}$. For simplicity, we assume that the regression coefficient matrices $B$ and $D$ include the intercept term (i.e., the last entry of $u_t$ is 1).


We use conjugate MNIW priors, as in Bayesian multivariate linear regression (Section 15.2.8). Specifically,


$$p(Q, [F, B]) = \mathrm{MNIW}(M_{z0}, V_{z0}, \nu_{q0}, \Psi_{q0}) \tag{29.112}$$
$$p(R, [H, D]) = \mathrm{MNIW}(M_{y0}, V_{y0}, \nu_{r0}, \Psi_{r0}) \tag{29.113}$$

Given $z_{1:T}$, $u_{1:T}$ and $y_{1:T}$, the posteriors are also MNIW. Specifically,

$$Q|z_{1:T}, u_{1:T} \sim \mathrm{IW}(\nu_{q1}, \Psi_{q1}) \tag{29.114}$$
$$[F, B]|Q, z_{1:T}, u_{1:T} \sim \mathcal{MN}(M_{z1}, Q, V_{z1}) \tag{29.115}$$


where the hyperparameters $\{M_{z1}, V_{z1}, \nu_{q1}, \Psi_{q1}\}$ of the posterior MNIW can be obtained by replacing $Y$, $A$, $X$ in Equation (15.100) with $z_{2:T}$, $[F, B]$, and $[z_{t-1}^\top, u_t^\top]_{t=2:T}$, respectively. Similarly,


$$R|z_{1:T}, u_{1:T}, y_{1:T} \sim \mathrm{IW}(\nu_{r1}, \Psi_{r1}) \tag{29.116}$$
$$[H, D]|R, z_{1:T}, u_{1:T}, y_{1:T} \sim \mathcal{MN}(M_{y1}, R, V_{y1}) \tag{29.117}$$
29.9. SWITCHING LINEAR DYNAMICAL SYSTEMS (SLDS)


[Figure 29.25, panels (a) and (b); panel (a) has nodes $u_{t-1}, u_t, m_{t-1}, m_t, z_{t-1}, z_t, y_{t-1}, y_t$]


Figure 29.25: (a) A switching SSM. Squares represent discrete random variables, circles represent continuous 
random variables. (b) Illustration of how the number of modes in the belief state of a switching SSM grows 
exponentially over time. We assume there are two binary states. 


and the hyperparameters $\{M_{y1}, V_{y1}, \nu_{r1}, \Psi_{r1}\}$ of the posterior MNIW can be obtained by replacing $Y$, $A$, $X$ in Equation (15.100) with $y_{1:T}$, $[H, D]$, and $[z_t^\top, u_t^\top]_{t=1:T}$, respectively.
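To make the $p(\theta|z_{1:T}, y_{1:T})$ step concrete, here is a minimal Python/NumPy sketch of the dynamics half of one Gibbs sweep: a conjugate matrix-normal inverse-Wishart update for $[F, B]$ and $Q$, followed by a joint draw. We assume the standard parameterization $\Sigma \sim \mathrm{IW}(\nu, \Psi)$, $A|\Sigma \sim \mathcal{MN}(M, \Sigma, V)$, which we take to match Equation (15.100) up to notation. The function names are hypothetical, the FFBS step is not shown, and a simulated trajectory stands in for the sampled $z_{1:T}$.

```python
import numpy as np

def mniw_posterior(Y, X, M0, V0, nu0, Psi0):
    """Conjugate MNIW update for Y = A X + E, with columns of E ~ N(0, Sigma).

    Y: (d, n) targets, X: (p, n) regressors.
    Prior: Sigma ~ IW(nu0, Psi0) and A | Sigma ~ MN(M0, Sigma, V0).
    """
    V0_inv = np.linalg.inv(V0)
    V1_inv = V0_inv + X @ X.T
    V1 = np.linalg.inv(V1_inv)
    M1 = (Y @ X.T + M0 @ V0_inv) @ V1
    nu1 = nu0 + Y.shape[1]
    Psi1 = Psi0 + Y @ Y.T + M0 @ V0_inv @ M0.T - M1 @ V1_inv @ M1.T
    return M1, V1, nu1, 0.5 * (Psi1 + Psi1.T)   # symmetrize against round-off

def sample_inv_wishart(nu, Psi, rng):
    """Sigma ~ IW(nu, Psi) for integer nu >= dim, via W = G G^T ~ Wishart(nu, Psi^{-1})."""
    L = np.linalg.cholesky(np.linalg.inv(Psi))
    G = L @ rng.standard_normal((Psi.shape[0], int(nu)))
    return np.linalg.inv(G @ G.T)

def sample_mniw(M1, V1, nu1, Psi1, rng):
    """Draw Sigma ~ IW(nu1, Psi1), then A | Sigma ~ MN(M1, Sigma, V1)."""
    Sigma = sample_inv_wishart(nu1, Psi1, rng)
    Z = rng.standard_normal(M1.shape)
    A = M1 + np.linalg.cholesky(Sigma) @ Z @ np.linalg.cholesky(V1).T
    return A, Sigma

# Stand-in for the z_{1:T} that FFBS would produce: simulate from known parameters.
rng = np.random.default_rng(0)
Nz, Nu, T = 2, 1, 2000
F_true = np.array([[0.9, 0.1], [-0.1, 0.8]])
B_true = np.array([[0.5], [0.0]])
Q_true = 0.1 * np.eye(Nz)
u = np.ones((T, Nu))                     # constant input plays the role of the intercept
z = np.zeros((T, Nz))
for t in range(1, T):
    z[t] = F_true @ z[t - 1] + B_true @ u[t] + rng.multivariate_normal(np.zeros(Nz), Q_true)

# Regress z_t on [z_{t-1}; u_t] for t = 2..T
Y = z[1:].T
X = np.vstack([z[:-1].T, u[1:].T])
p = Nz + Nu
M1, V1, nu1, Psi1 = mniw_posterior(Y, X, np.zeros((Nz, p)), 10.0 * np.eye(p), Nz + 2, np.eye(Nz))
FB, Q = sample_mniw(M1, V1, nu1, Psi1, rng)
F_hat, B_hat = FB[:, :Nz], FB[:, Nz:]
```

The observation half of the sweep, for $[H, D]$ and $R$, would reuse the same two functions with $Y = y_{1:T}$ and $X = [z_t; u_t]_{t=1:T}$.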


29.9 Switching linear dynamical systems (SLDS) 


Consider a state-space model in which the latent state has both a discrete latent variable, $m_t \in \{1, \ldots, K\}$, and a continuous latent variable, $z_t \in \mathbb{R}^{N_z}$. (A model with discrete and continuous latent variables is known as a hybrid system in control theory.) We assume the observed responses are continuous, $y_t \in \mathbb{R}^{N_y}$. We may also have continuous observed inputs $u_t \in \mathbb{R}^{N_u}$. The discrete variable can be used to represent different kinds of system dynamics or operating regimes (e.g., normal or abnormal), or different kinds of observation models (e.g., to handle outliers due to sensor noise or failures). If the system is linear-Gaussian, it is called a switching linear dynamical system (SLDS), a regime switching Markov model [Ham90; KN98], or a jump Markov linear system (JMLS) [DGK01].


29.9.1 Parameterization 


An SLDS model is defined as follows: 


$$p(m_t = k|m_{t-1} = j) = A_{jk} \tag{29.118}$$
$$p(z_t|z_{t-1}, m_t = k, u_t) = \mathcal{N}(z_t|F_k z_{t-1} + B_k u_t + b_k, Q_k) \tag{29.119}$$
$$p(y_t|z_t, m_t = k, u_t) = \mathcal{N}(y_t|H_k z_t + D_k u_t + d_k, R_k) \tag{29.120}$$


See Figure 29.25a for the DPGM representation. It is straightforward to make a nonlinear version 
of this model. 
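A minimal Python/NumPy sketch of ancestral sampling from Equations (29.118)-(29.120) may help make the generative process concrete. The array layout (regime-specific parameters stacked along the first axis), the function name, and the example parameter values are all assumptions for illustration.

```python
import numpy as np

def sample_slds(A, F, B, b, H, D, d, Q, R, u, m0, z0, rng):
    """Ancestral sampling from the SLDS of Equations (29.118)-(29.120).

    A: (K, K) with A[j, k] = p(m_t = k | m_{t-1} = j). Regime-specific parameters are
    stacked along axis 0: F, Q: (K, Nz, Nz), B: (K, Nz, Nu), b: (K, Nz),
    H: (K, Ny, Nz), D: (K, Ny, Nu), d: (K, Ny), R: (K, Ny, Ny). u: (T, Nu).
    m0, z0 are the given states just before the first step.
    """
    T = u.shape[0]
    K, Nz = b.shape
    Ny = d.shape[1]
    m = np.zeros(T, dtype=int)
    z = np.zeros((T, Nz))
    y = np.zeros((T, Ny))
    m_prev, z_prev = m0, z0
    for t in range(T):
        k = rng.choice(K, p=A[m_prev])                                            # Eq. (29.118)
        z[t] = rng.multivariate_normal(F[k] @ z_prev + B[k] @ u[t] + b[k], Q[k])  # Eq. (29.119)
        y[t] = rng.multivariate_normal(H[k] @ z[t] + D[k] @ u[t] + d[k], R[k])    # Eq. (29.120)
        m[t] = k
        m_prev, z_prev = k, z[t]
    return m, z, y

rng = np.random.default_rng(0)
K, Nz, Ny, Nu, T = 2, 2, 1, 1, 50
A = np.array([[0.95, 0.05], [0.10, 0.90]])          # sticky regimes
theta = 0.1
rot = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
F = np.stack([0.99 * rot, 0.5 * np.eye(Nz)])        # regime 0: slow rotation; regime 1: fast decay
B, b = np.zeros((K, Nz, Nu)), np.zeros((K, Nz))
H = np.tile(np.array([[1.0, 0.0]]), (K, 1, 1))      # observe the first coordinate in both regimes
D, d = np.zeros((K, Ny, Nu)), np.zeros((K, Ny))
Q = np.stack([0.01 * np.eye(Nz)] * K)
R = np.stack([0.1 * np.eye(Ny)] * K)
u = np.zeros((T, Nu))
m, z, y = sample_slds(A, F, B, b, H, D, d, Q, R, u, m0=0, z0=np.ones(Nz), rng=rng)
```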
29.9.2 Posterior inference


Unfortunately, exact inference in such switching models is intractable, even in the linear Gaussian case. To see why, suppose for simplicity that the latent discrete switching variable $m_t$ is binary, and that only the dynamics matrix $F$ depends on $m_t$, not the observation matrix $H$. Our initial belief state will be a mixture of 2 Gaussians, corresponding to $p(z_1|y_1, m_1 = 1)$ and $p(z_1|y_1, m_1 = 2)$. The one-step-ahead predictive density will be a mixture of 4 Gaussians, $p(z_2|y_1, m_1 = 1, m_2 = 1)$, $p(z_2|y_1, m_1 = 1, m_2 = 2)$, $p(z_2|y_1, m_1 = 2, m_2 = 1)$, and $p(z_2|y_1, m_1 = 2, m_2 = 2)$, obtained by passing each of the prior modes through the 2 possible transition models. The belief state at step 2 will also be a mixture of 4 Gaussians, obtained by updating each of the above distributions with $y_2$. At step 3, the belief state will be a mixture of 8 Gaussians, and so on. So we see there is an exponential explosion in the number of modes: with $K$ discrete states, the belief state at step $t$ is a mixture of $K^t$ Gaussians. Each sequence of discrete values corresponds to a different hypothesis (sometimes called a track), which can be represented as a tree, as shown in Figure 29.25b.

Various methods for approximate online inference have been proposed for this model, such as the 
following: 


• Prune off low probability trajectories in the discrete tree. This is widely used in multiple hypothesis tracking methods (see Section 29.9.3).

• Use particle filtering (Section 13.2), where we sample discrete trajectories and apply the Kalman filter to the continuous variables. See Section 13.4.1 for details.

• Use ADF (Section 8.9), where we approximate the exponentially large mixture of Gaussians with a smaller mixture of Gaussians. See Section 8.9.2 for details.

• Use structured variational inference, where we approximate the posterior as a product of chain-structured distributions, one over the discrete variables and one over the continuous variables, with variational “coupling” terms in between (see e.g., [GH98; PJD21; Wan+22]).
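The ADF option rests on a moment-matching step that replaces a mixture of Gaussians by a single Gaussian with the same mean and covariance; in GPB-style filters this is typically applied to groups of hypotheses, for example those that share the same current discrete state. A minimal NumPy sketch of that collapse step, with a hypothetical function name:

```python
import numpy as np

def collapse_gaussian_mixture(weights, means, covs):
    """Moment-match sum_k w_k N(mu_k, Sigma_k) with a single Gaussian N(mu, Sigma).

    weights: (K,), means: (K, D), covs: (K, D, D).
    """
    w = np.asarray(weights, dtype=float)
    w = w / w.sum()
    means = np.asarray(means, dtype=float)
    covs = np.asarray(covs, dtype=float)
    mu = w @ means                                        # mixture mean
    diffs = means - mu                                    # spread of the component means
    Sigma = (np.einsum("k,kij->ij", w, covs)              # average within-component covariance
             + np.einsum("k,ki,kj->ij", w, diffs, diffs))  # plus between-component covariance
    return mu, Sigma

# Two hypotheses about a 2d state, e.g. one per value of the previous discrete state
mu, Sigma = collapse_gaussian_mixture([0.25, 0.75],
                                      [[0.0, 0.0], [4.0, 0.0]],
                                      [np.eye(2), np.eye(2)])
print(mu, Sigma)   # mean (3, 0); covariance diag(4, 1)
```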


29.9.3 Application: Multi-target tracking


The problem of multi-target tracking frequently arises in engineering applications (especially in aerospace and defence). This is a very large topic (see e.g., [BSF88; BSL93; Vo+15] for details), but in this section, we show how switching LDS models (or their nonlinear extensions) can be used to tackle the problem.


29.9.3.1 Warmup 


In the simplest setting, we know there are $N$ objects we want to track, and each one generates its own uniquely identified observation. If we assume the objects are independent, we can apply Kalman filtering and smoothing in parallel, as shown in Figure 29.26. (In this example, each object follows a linear dynamical model with different initial random velocities, as in Section 29.7.1.)
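Because the objects are independent and, in this sketch, share the same dynamics and observation model, the filters can be vectorized over objects. The following NumPy code is not the code in kf_parallel.ipynb; it assumes a constant-velocity model with state $(u, v, \dot{u}, \dot{v})$, position-only observations, and illustrative noise levels.

```python
import numpy as np

def kalman_filter_batch(ys, F, H, Q, R, mu0, Sigma0):
    """Independent Kalman filters for N objects that share one linear-Gaussian model.

    ys: (T, N, Dy) observations; mu0: (N, Dz) prior means; Sigma0: (Dz, Dz) prior covariance.
    Because F, H, Q, R and Sigma0 are shared, the covariance recursion does not depend on
    the data and is computed once per step for all objects.
    """
    T, N, _ = ys.shape
    Dz = F.shape[0]
    mus = np.zeros((T, N, Dz))
    Sigmas = np.zeros((T, Dz, Dz))
    mu, Sigma = mu0, Sigma0
    for t in range(T):
        mu_pred = mu @ F.T                               # predict step, all objects at once
        Sigma_pred = F @ Sigma @ F.T + Q
        S = H @ Sigma_pred @ H.T + R                     # innovation covariance
        K = Sigma_pred @ H.T @ np.linalg.inv(S)          # shared Kalman gain
        mu = mu_pred + (ys[t] - mu_pred @ H.T) @ K.T     # update step, all objects at once
        Sigma = (np.eye(Dz) - K @ H) @ Sigma_pred
        mus[t], Sigmas[t] = mu, Sigma
    return mus, Sigmas

rng = np.random.default_rng(0)
T, N = 100, 4
F = np.array([[1., 0., 1., 0.],
              [0., 1., 0., 1.],
              [0., 0., 1., 0.],
              [0., 0., 0., 1.]])        # constant-velocity model, state (u, v, du, dv)
H = np.array([[1., 0., 0., 0.],
              [0., 1., 0., 0.]])        # observe position only
Q, R = 1e-3 * np.eye(4), 1.0 * np.eye(2)
z = np.concatenate([np.zeros((N, 2)), rng.normal(size=(N, 2))], axis=1)  # random initial velocities
zs, ys = [], []
for t in range(T):
    z = z @ F.T + rng.multivariate_normal(np.zeros(4), Q, size=N)
    zs.append(z)
    ys.append(z @ H.T + rng.multivariate_normal(np.zeros(2), R, size=N))
zs, ys = np.array(zs), np.array(ys)     # (T, N, 4) and (T, N, 2)

mus, Sigmas = kalman_filter_batch(ys, F, H, Q, R, mu0=np.zeros((N, 4)), Sigma0=10.0 * np.eye(4))
err_obs = np.mean((ys - zs[..., :2]) ** 2)              # raw measurement error
err_filt = np.mean((mus[..., :2] - zs[..., :2]) ** 2)   # filtered position error
print(err_obs, err_filt)
```

Sharing one covariance across objects is a design convenience of this sketch; objects with different noise levels would need a batched covariance as well.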


29.9.3.2 Data association


More generally, at each step we may observe $M$ measurements, e.g., “blips” on a radar screen. We can have $M < N$ due to occlusion or missed detections. We can have $M > N$ due to clutter or false


[Figure 29.26(a): panels Data, Filtered Posterior and Smoothed Posterior, each showing Trajectories 1-4]


Figure 29.26: Illustration of Kalman filtering and smoothing for tracking multiple moving objects. Generated by kf_parallel.ipynb.


Figure 29.27: A model for tracking two objects in the presence of data-association ambiguity. We observe 3, 1 and 2 detections at time steps $t-1$, $t$ and $t+1$. The hidden variable $m_t$ encodes the association between the observations and the hidden causes.


alarms. Or we can have $M = N$. In any case, we need to figure out the correspondence between the $M$ detections $y_{t,m}$ and the $N$ objects $z_{t,i}$. This is called the problem of data association, and it arises in many application domains.

We can model this problem by augmenting the state space with discrete variables $m_t$ that represent the association matrix between the observations, $y_{t,1:M_t}$, and the sources, $z_{t,1:N}$. See Figure 29.27 for an illustration, where we have $N = 2$ objects, but a variable number $M_t$ of observations per time step.

As we mentioned in Section 29.9.2, inference in such hybrid (discrete-continuous) models is intractable, due to the exponential number of posterior modes. In the sections below, we briefly mention a few approximate inference methods.
29.9.3.3 Nearest neighbor approximation using Hungarian algorithm


A common way to perform approximate inference in this model is to compute an $N \times M$ weight matrix, where $W_{im}$ measures the “compatibility” between object $i$ and measurement $m$, typically based on how close $m$ is to where the model thinks $i$ is (the so-called nearest neighbor data association heuristic).

We can make this into a square matrix by adding dummy background objects, which can explain all the false alarms, and adding dummy observations, which can explain all the missed detections. We can then compute the maximal weight bipartite matching using the Hungarian algorithm, which takes $O(\max(N, M)^3)$ time (see e.g., [BDM09]).

Conditional on knowing the assignments of measurements to tracks, we can perform the usual 
Bayesian state update procedure (e.g., based on Kalman filtering). Note that objects that are assigned 
to dummy observations do not perform a measurement update, so their state estimate is just based 
on forwards prediction from the dynamics model. 
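The padding construction can be sketched with SciPy's `scipy.optimize.linear_sum_assignment`, which solves the same assignment problem as the Hungarian algorithm (recent SciPy versions use a Jonker-Volgenant-style solver). We minimize costs $-W_{im}$ rather than maximize weights; the squared-distance cost and the single gate cost for dummy pairings are illustrative assumptions, and `associate` is a hypothetical helper.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def associate(pred_pos, meas, gate_cost):
    """Nearest-neighbor data association with dummy objects and dummy observations.

    pred_pos: (N, 2) predicted object positions; meas: (M, 2) detections.
    Returns a dict mapping object i to a measurement index, or None for a missed detection.
    """
    N, M = len(pred_pos), len(meas)
    # Rows: N real objects + M dummy background objects (explain false alarms).
    # Columns: M real measurements + N dummy observations (explain missed detections).
    cost = np.full((N + M, M + N), float(gate_cost))
    cost[:N, :M] = ((pred_pos[:, None, :] - meas[None, :, :]) ** 2).sum(axis=-1)
    cost[N:, M:] = 0.0                        # dummy-to-dummy pairings are free
    rows, cols = linear_sum_assignment(cost)  # min total cost = max total weight
    return {int(i): (int(j) if j < M else None) for i, j in zip(rows, cols) if i < N}

pred_pos = np.array([[0.0, 0.0], [10.0, 10.0]])
meas = np.array([[10.2, 9.9], [0.1, -0.1], [50.0, 50.0]])   # the last detection is clutter
print(associate(pred_pos, meas, gate_cost=4.0))             # {0: 1, 1: 0}
```

Objects mapped to `None` would then skip the Kalman measurement update, as described above.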


29.9.3.4 Other approximate inference schemes 


The Hungarian algorithm can be slow (since it is cubic in the number of measurements), and can give poor results since it relies on hard assignment. Better performance can be obtained by using loopy belief propagation (Section 9.3). The basic idea is to approximately marginalize out the unknown assignment variables, rather than perform a MAP estimate. This is known as the SPADA method (sum-product algorithm for data association) [WL14b; Mey+18].

The cost of each iteration of the iterative procedure is $O(NM)$. Furthermore, [WL14b] proved this will always converge in a finite number of steps, and [Von13] showed that the corresponding solution will in fact be the global optimum. The SPADA method is more efficient, and more accurate, than earlier heuristic methods, such as JPDA (joint probabilistic data association) [BSWT11; Vo+15].
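A minimal NumPy sketch of the message-passing iteration, in the form we understand from [WL14b], is given below. It assumes the association weights $\psi_{im}$ have been normalized so that a missed detection and a false alarm each have weight 1; each sweep touches every object-measurement pair once, hence the $O(NM)$ cost. The function name and conventions are ours.

```python
import numpy as np

def bp_data_association(psi, n_iters=200, tol=1e-10):
    """Loopy BP for marginal association probabilities (sketch of the SPADA iteration).

    psi[i, m] >= 0: weight for object i generating measurement m, normalized so that
    a missed detection (object side) and a false alarm (measurement side) have weight 1.
    Returns P of shape (N, M + 1): P[i, 0] = p(object i missed), P[i, m + 1] = p(i -> m).
    """
    psi = np.asarray(psi, dtype=float)
    nu = np.ones_like(psi)                   # nu[i, m]: message measurement m -> object i
    for _ in range(n_iters):
        w = psi * nu
        # object i -> measurement m: exclude m itself from the sum over measurements
        mu = psi / (1.0 + w.sum(axis=1, keepdims=True) - w)
        # measurement m -> object i: exclude i itself from the sum over objects
        nu_new = 1.0 / (1.0 + mu.sum(axis=0, keepdims=True) - mu)
        converged = np.max(np.abs(nu_new - nu)) < tol
        nu = nu_new
        if converged:
            break
    w = psi * nu
    Z = 1.0 + w.sum(axis=1, keepdims=True)
    return np.hstack([1.0 / Z, w / Z])

# Two objects competing for a single measurement (a tree, so BP is exact here)
P = bp_data_association([[2.0], [3.0]])
print(P)   # rows: [2/3, 1/3] and [1/2, 1/2]
```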


It is also possible to use sequential Monte Carlo methods to solve data association and tracking. See Section 13.2 for a general discussion of SMC, and [RAG04; Wan+17b] for a review of specific techniques for this model family.


29.9.3.5 Handling an unknown number of targets


In general, we do not know the true number of targets $N$, so we have to deal with a variable-sized state space. This is an example of an open world model (see Section 4.6.5), which differs from the standard closed world assumption where we know how many objects of interest there are.


For example, suppose at each time step we get two “blips” on our radar screen, representing the presence of an object at a given location. These measurements are not tagged with the object that generated them, so the data looks like Figure 29.28(a). In Figure 29.28(b-c) we show two different hypotheses about the underlying object trajectories that could have generated this data. However, how can we know there are two objects? Maybe there are more, but some are just not detected. Maybe there are fewer, and some observations are false alarms due to background clutter. One such more complex hypothesis is shown in Figure 29.28(d). Figuring out what is going on in problems such as this is known as multiple hypothesis tracking.


A common approximate solution to this is to create new objects whenever an observation cannot be “explained” (i.e., generated with high likelihood) by any existing objects, and to prune out old objects that have not been detected in a while (in order to keep the computational cost bounded). Sets
Figure 29.28: Illustration of multi-target tracking in 2d over 5 time steps. (a) We observe 2 measurements per time step. (b-c) Possible hypotheses about the underlying object tracks. (d) A more complex hypothesis in which the red track stops at step 3, the dashed red track starts at step 4, the dotted blue track has a detection failure at step 3, and one of the measurements at step 3 is a false alarm. Adapted from Figure 15.8 of [RN19].


whose size and content are both random are called random finite sets. An elegant mathematical 
framework for dealing with such objects is described in [Mah07; Mah13; Vo+15]. 
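The birth/death heuristic described above can be sketched as follows. This is a toy illustration, not a method from the cited works: a Euclidean gate stands in for the likelihood test, tracks simply snap to their detections instead of running a Kalman update, and `update_tracks`, `gate` and `max_misses` are hypothetical names.

```python
import numpy as np

def update_tracks(tracks, meas, gate=3.0, max_misses=2):
    """One step of a simple birth/death track-management heuristic (hypothetical helper).

    tracks: list of dicts {"pos": array, "misses": int}; meas: (M, 2) detections.
    A detection within `gate` of the nearest unclaimed track is explained by it;
    otherwise it spawns a new track. Tracks missed more than `max_misses` steps
    in a row are pruned.
    """
    claimed = set()
    for y in meas:
        dists = [np.inf if i in claimed else np.linalg.norm(tr["pos"] - y)
                 for i, tr in enumerate(tracks)]
        i = int(np.argmin(dists)) if dists else -1
        if i >= 0 and dists[i] < gate:
            tracks[i]["pos"] = y             # crude update: snap to the detection
            tracks[i]["misses"] = 0
            claimed.add(i)
        else:
            tracks.append({"pos": y, "misses": 0})        # birth
            claimed.add(len(tracks) - 1)
    for i, tr in enumerate(tracks):
        if i not in claimed:
            tr["misses"] += 1
    return [tr for tr in tracks if tr["misses"] <= max_misses]   # death

tracks = update_tracks([], np.array([[0.0, 0.0], [10.0, 10.0]]))   # two births
for _ in range(3):
    # The object near (10, 10) is never seen again; a new one appears at (30, 30)
    tracks = update_tracks(tracks, np.array([[0.5, 0.0], [30.0, 30.0]]))
print([tr["pos"].tolist() for tr in tracks])   # [[0.5, 0.0], [30.0, 30.0]]
```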


29.10 Nonlinear SSMs 


In this section, we consider SSMs with nonlinear transition and/or observation functions, and additive 
Gaussian noise. That is, we assume the model has the following form 


$$z_t = f(z_{t-1}, u_t) + q_t \tag{29.121}$$
$$q_t \sim \mathcal{N}(0, Q_t) \tag{29.122}$$
$$y_t = h(z_t, u_t) + r_t \tag{29.123}$$
$$r_t \sim \mathcal{N}(0, R_t) \tag{29.124}$$


This is called a nonlinear dynamical system (NLDS), or nonlinear Gaussian SSM (NLG-SSM).


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 




Figure 29.29: Illustration of a bearings-only tracking problem. Adapted from Figure 2.1 of [CP20b].


29.10.1 Example: object tracking and state estimation 


In Section 8.5.4.1 we give an example of a 2d tracking problem where the motion model is nonlinear, 
but the observation model is linear. 

Here we consider an example where the motion model is linear, but the observation model is nonlinear. In particular, suppose we use the same 2d linear dynamics as in Section 29.7.1, where the state space contains the position and velocity of the object, $z_t = (u_t, v_t, \dot{u}_t, \dot{v}_t)$. (We use $u$ and $v$ for the two coordinates, to avoid confusion with the state and observation variables.) Instead of directly observing the location, suppose we have a bearings only tracking problem, in which we just observe the angle to the target:

$$y_t = \tan^{-1}\left(\frac{v_t - s_y}{u_t - s_x}\right) + r_t \tag{29.125}$$

where $(s_x, s_y)$ is the position of the measurement sensor. See Figure 29.29 for an illustration. This nonlinear observation model prevents the use of the Kalman filter, but we can still apply approximate inference methods, as we discuss below.
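An extended Kalman filter needs the Jacobian of the observation function in Equation (29.125). A NumPy sketch, assuming the state ordering $z_t = (u_t, v_t, \dot{u}_t, \dot{v}_t)$ and using `arctan2` in place of $\tan^{-1}$ of the ratio; the helper names are hypothetical.

```python
import numpy as np

def bearing(z, sensor):
    """Noise-free bearing h(z) from the sensor (s_x, s_y) to the target at (u, v).

    Assumes the state ordering z = (u, v, du, dv). arctan2 resolves the quadrant,
    which tan^{-1} of the ratio alone would not.
    """
    return np.arctan2(z[1] - sensor[1], z[0] - sensor[0])

def bearing_jacobian(z, sensor):
    """Row vector dh/dz, as used when linearizing h in an extended Kalman filter."""
    du, dv = z[0] - sensor[0], z[1] - sensor[1]
    r2 = du ** 2 + dv ** 2
    return np.array([-dv / r2, du / r2, 0.0, 0.0])   # velocity does not affect the bearing

z = np.array([3.0, 4.0, 1.0, -1.0])
sensor = np.array([0.0, 0.0])
print(bearing(z, sensor), bearing_jacobian(z, sensor))   # about 0.927, and [-0.16, 0.12, 0, 0]
```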


29.10.2 Posterior inference 


Inferring the states of an NLDS model is in general computationally difficult. Fortunately, there are a variety of approximate inference schemes that can be used, such as the extended Kalman filter (Section 8.5.2), the unscented Kalman filter (Section 8.6.2), etc.


29.11 Non-Gaussian SSMs


In this section, we consider SSMs in which the transition and observation noise is non-Gaussian. The transition and observation functions can be linear or nonlinear. We can represent this as a
Figure 29.30: Samples from a 2d LDS with 5 Poisson likelihood terms. Generated by poisson_lds_example.ipynb.
Figure 29.31: Latent state trajectory (blue lines) and dynamics matrix A (arrows) for (left) true model and (right) estimated model. The star marks the start of the trajectory. Generated by poisson_lds_example.ipynb.


probabilistic model as follows: 


$$p(z_t|z_{t-1}, u_t) = p(z_t|f(z_{t-1}, u_t)) \tag{29.126}$$
$$p(y_t|z_t, u_t) = p(y_t|h(z_t, u_t)) \tag{29.127}$$


This is called a non-Gaussian SSM (NGSSM).
29.11.1 Example: Spike train modeling


In this section we consider an SSM with linear-Gaussian latent dynamics and a Poisson likelihood. Such models are widely used in neuroscience for modeling neural spike trains (see e.g., [Pan+10; Mac+11]). This is an example of an exponential family state-space model (see e.g., [Vid99; Hel17]).

We consider a simple example where the model has 2 continuous latent variables, and we set the 
dynamics matrix A to a random rotation matrix. The observation model has the form 


$$p(y_t|z_t) = \prod_{d=1}^{D} \mathrm{Poi}(y_{td}|\exp(w_d^\top z_t)) \tag{29.128}$$


where $w_d$ is a random vector, and we use $D = 5$ observations per time step. Some samples from this model are shown in Figure 29.30.
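Samples like those in Figure 29.30 can be generated along the following lines. This NumPy sketch is not poisson_lds_example.ipynb; the process noise level, the initial state distribution and the Gaussian draws for the $w_d$ are assumptions.

```python
import numpy as np

def sample_poisson_lds(T=100, D=5, q=0.01, seed=0):
    """Sample a 2d linear-Gaussian LDS with Poisson observations, Equation (29.128)."""
    rng = np.random.default_rng(seed)
    angle = rng.uniform(0.0, 2.0 * np.pi)            # random rotation for the dynamics
    A = np.array([[np.cos(angle), -np.sin(angle)],
                  [np.sin(angle),  np.cos(angle)]])
    W = rng.normal(size=(D, 2))                      # row d is the weight vector w_d
    z = np.zeros((T, 2))
    z[0] = rng.normal(size=2)
    for t in range(1, T):
        z[t] = A @ z[t - 1] + np.sqrt(q) * rng.normal(size=2)
    rates = np.exp(z @ W.T)                          # (T, D) Poisson means exp(w_d^T z_t)
    y = rng.poisson(rates)                           # spike counts
    return A, W, z, y

A, W, z, y = sample_poisson_lds()
```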

We can fit this model by using EM, where in the E step we approximate $p(y_t|z_t)$ using a Laplace approximation, after which we can use the Kalman smoother to compute $p(z_{1:T}|y_{1:T})$. In the M step, we optimize the expected complete data log likelihood, similar to Section 29.8.1. We show the result in Figure 29.31, where we compare the parameters $A$ and the posterior trajectory $\mathbb{E}[z_t|y_{1:T}]$ using the true model and the estimated model. We see good agreement.


29.11.2 Example: Stochastic volatility models 


In finance, it is common to model the log-returns, $y_t = \log(p_t/p_{t-1})$, where $p_t$ is the price of some asset at time $t$. A common model for this problem, known as a stochastic volatility model (see e.g., [KSC98]), has the following form:


$$y_t = u_t^\top \beta + \exp(z_t/2)\, r_t \tag{29.129}$$
$$z_t = \mu + \rho(z_{t-1} - \mu) + \sigma q_t \tag{29.130}$$
$$r_t \sim \mathcal{N}(0, 1) \tag{29.131}$$
$$q_t \sim \mathcal{N}(0, 1) \tag{29.132}$$


We see that the dynamical model is a first-order autoregressive process. We typically require that $|\rho| < 1$, to ensure the system is stationary. The observation model is Gaussian, but can be replaced by a heavy-tailed distribution such as a Student t distribution.
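A minimal NumPy simulation of Equations (29.129)-(29.132), assuming no covariates ($u_t^\top\beta = 0$) and illustrative values for $\mu$, $\rho$ and $\sigma$; the function name is hypothetical.

```python
import numpy as np

def sample_stochastic_volatility(T, mu=-1.0, rho=0.95, sigma=0.2, seed=0):
    """Simulate Equations (29.129)-(29.132) with u_t^T beta = 0."""
    rng = np.random.default_rng(seed)
    z = np.zeros(T)
    z[0] = mu + sigma / np.sqrt(1.0 - rho ** 2) * rng.normal()   # start in the stationary distribution
    for t in range(1, T):
        z[t] = mu + rho * (z[t - 1] - mu) + sigma * rng.normal()   # AR(1) log-volatility
    y = np.exp(z / 2.0) * rng.normal(size=T)                       # returns with time-varying scale
    return z, y

z, y = sample_stochastic_volatility(5000)
```

Because $|\rho| < 1$, the log-volatility fluctuates around $\mu$, producing the bursts of large and small returns typical of financial data.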


We can capture longer range temporal correlation by using a higher order autoregressive process. To do this, we just expand the state space to contain the past $K$ values. For example, if $K = 2$ we have


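$$\begin{pmatrix} z_t - \mu \\ z_{t-1} - \mu \end{pmatrix} = \begin{pmatrix} \rho_1 & \rho_2 \\ 1 & 0 \end{pmatrix} \begin{pmatrix} z_{t-1} - \mu \\ z_{t-2} - \mu \end{pmatrix} + \begin{pmatrix} q_t \\ 0 \end{pmatrix} \tag{29.133}$$

where $q_t \sim \mathcal{N}(0, \sigma^2)$. Thus we have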


$$z_t = \mu + \rho_1(z_{t-1} - \mu) + \rho_2(z_{t-2} - \mu) + q_t \tag{29.134}$$
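The equivalence of Equations (29.133) and (29.134) can be checked numerically by driving both recursions with the same noise. A NumPy sketch with illustrative parameter values (the function name is hypothetical):

```python
import numpy as np

def ar2_two_ways(mu, rho1, rho2, sigma, T, seed=0):
    """Simulate an AR(2) process via the companion form (29.133) and via (29.134)."""
    rng = np.random.default_rng(seed)
    q = sigma * rng.normal(size=T)           # shared noise draws q_t ~ N(0, sigma^2)
    F = np.array([[rho1, rho2],
                  [1.0, 0.0]])               # companion (transition) matrix
    x = np.zeros(2)                          # x_t = (z_t - mu, z_{t-1} - mu); start at the mean
    z_ss = np.zeros(T)
    for t in range(T):
        x = F @ x + np.array([q[t], 0.0])
        z_ss[t] = mu + x[0]
    z_direct = np.zeros(T)
    prev1, prev2 = mu, mu                    # same initial conditions: z_0 = z_{-1} = mu
    for t in range(T):
        z_direct[t] = mu + rho1 * (prev1 - mu) + rho2 * (prev2 - mu) + q[t]
        prev1, prev2 = z_direct[t], prev1
    return F, z_ss, z_direct

F, z_ss, z_direct = ar2_two_ways(mu=1.0, rho1=0.5, rho2=0.3, sigma=0.1, T=200)
print(np.allclose(z_ss, z_direct))            # True: both forms give the same path
print(np.max(np.abs(np.linalg.eigvals(F))))   # below 1, so this AR(2) process is stationary
```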
29.11.3 Posterior inference


Inferring the states of an NGSSM model is in general computationally difficult. Fortunately, there 
are a variety of approximate inference schemes that can be used, which we discuss in Chapter 8 and 
Chapter 13. 


29.12 Structural time series models 


In this section, we discuss time series forecasting, which is the problem of computing the predictive distribution over future observations given the data up until the present, i.e., computing $p(y_{t+h}|y_{1:t})$. (The model may optionally be conditioned on known future inputs, to get $p(y_{t+h}|y_{1:t}, u_{1:t+h})$.) There are many approaches to this problem (see e.g., [HA21]), but in this section, we focus on structural time series (STS) models, which are defined in terms of linear-Gaussian SSMs.

Many classical time series methods, such as the ARMA (autoregressive moving average) method, 
can be represented as STS models (see e.g., [Har90; Sim06; CK07; Fra08; DK12; Sar13; PFW21; 
Tri21]). However, the STS approach has much more flexibility. For example, we can create nonlinear, 
non-Gaussian and even hierarchical extensions, as we discuss below. 


29.12.1 Introduction 


The basic idea of an STS model is to represent the observed scalar time series as a sum of C individual 
components: 


$$f(t) = f_1(t) + f_2(t) + \cdots + f_C(t) + \epsilon_t \tag{29.135}$$


where $\epsilon_t \sim \mathcal{N}(0, \sigma^2)$. For example, we might have a seasonal component that causes the observed values to oscillate up and down, and a growth component that causes the observed values to get larger over time. Each latent process $f_c(t)$ is modeled by a linear Gaussian state-space model, which (in this context) is also called a dynamic linear model (DLM). Since these are linear, we can combine them all together into a single LG-SSM. In particular, in the case of scalar observations, the model has the form


$$p(z_t|z_{t-1}, \theta) = \mathcal{N}(z_t|F z_{t-1}, Q) \tag{29.136}$$
$$p(y_t|z_t, \theta) = \mathcal{N}(y_t|H z_t + \beta^\top u_t, \sigma_y^2) \tag{29.137}$$
where $F$ and $Q$ are block structured matrices, with one block per component. The vector $H$ then adds up all the relevant pieces from each component to generate the overall mean. Note that the matrices $F$ and $H$ are fixed sparse matrices which can be derived from the form of the corresponding components of the model, as we discuss below. So the only model parameters are the variance terms, $Q$ and $\sigma_y^2$, and the optional regression coefficients $\beta$.⁴ We can generalize this to vector-valued observations as follows:
$$p(z_t|z_{t-1}, \theta) = \mathcal{N}(z_t|F z_{t-1}, Q) \tag{29.138}$$
$$p(y_t|z_t, \theta) = \mathcal{N}(y_t|H z_t + D u_t, R) \tag{29.139}$$


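4. In the statistics community, the notation often follows [DK12], who write the dynamics as $\alpha_t = T_t \alpha_{t-1} + c_t + R_t \eta_t$, and the observations as $y_t = Z_t \alpha_t + \beta^\top x_t + H_t \epsilon_t$, where $\eta_t \sim \mathcal{N}(0, I)$ and $\epsilon_t \sim \mathcal{N}(0, I)$.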
29.12.2 Structural building blocks


In this section, we discuss the building blocks of common STS models. 


29.12.2.1 Local level model 


The simplest latent dynamical process is known as the local level model. It assumes the observations $y_t \in \mathbb{R}$ are generated by a Gaussian with (latent) mean $\mu_t$, which evolves over time according to a random walk:


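$$y_t = \mu_t + \epsilon_{y,t}, \quad \epsilon_{y,t} \sim \mathcal{N}(0, \sigma_y^2) \tag{29.140}$$
$$\mu_t = \mu_{t-1} + \epsilon_{\mu,t}, \quad \epsilon_{\mu,t} \sim \mathcal{N}(0, \sigma_\mu^2) \tag{29.141}$$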
We also assume $\mu_1 \sim \mathcal{N}(0, \sigma_\mu^2)$. Hence the latent mean at any future step has distribution $\mu_t \sim \mathcal{N}(0, t\sigma_\mu^2)$, so the variance grows with time. We can also use an autoregressive (AR) process, $\mu_t = \rho\mu_{t-1} + \epsilon_{\mu,t}$, where $|\rho| < 1$. This has the stationary distribution $\mu_\infty \sim \mathcal{N}\left(0, \frac{\sigma_\mu^2}{1-\rho^2}\right)$, so the uncertainty grows to a finite asymptote instead of unboundedly.
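Both variance behaviors follow from the recursion $\mathrm{Var}[\mu_t] = \rho^2\,\mathrm{Var}[\mu_{t-1}] + \sigma_\mu^2$. A short NumPy sketch (the function name is hypothetical):

```python
import numpy as np

def latent_mean_variance(T, sigma_mu=1.0, rho=1.0):
    """Var[mu_t] for t = 1..T, where mu_t = rho * mu_{t-1} + eps_t and mu_1 ~ N(0, sigma_mu^2).

    rho = 1 is the local level (random walk) model; |rho| < 1 is the AR(1) variant.
    """
    var = np.zeros(T)
    var[0] = sigma_mu ** 2
    for t in range(1, T):
        var[t] = rho ** 2 * var[t - 1] + sigma_mu ** 2
    return var

print(latent_mean_variance(10)[-1])              # 10.0 = t * sigma_mu^2 at t = 10
print(latent_mean_variance(500, rho=0.9)[-1])    # close to 1 / (1 - 0.81), about 5.26
```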


29.12.2.2 Local linear model 


Many time series exhibit linear trends upwards or downwards, at least locally. We can model this by letting the level $\mu_t$ change by an amount $\delta_{t-1}$ (representing the slope of the line over an interval $\Delta t = 1$) at each step:

$$\mu_t = \mu_{t-1} + \delta_{t-1} + \epsilon_{\mu,t} \tag{29.142}$$

The slope itself also follows a random walk,

$$\delta_t = \delta_{t-1} + \epsilon_{\delta,t} \tag{29.143}$$

where $\epsilon_{\delta,t} \sim \mathcal{N}(0, \sigma_\delta^2)$. This is called a local linear trend model.
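We can combine these two processes by defining the following dynamics model:

$$\underbrace{\begin{pmatrix} \mu_t \\ \delta_t \end{pmatrix}}_{z_t} = \underbrace{\begin{pmatrix} 1 & 1 \\ 0 & 1 \end{pmatrix}}_{F} \underbrace{\begin{pmatrix} \mu_{t-1} \\ \delta_{t-1} \end{pmatrix}}_{z_{t-1}} + \begin{pmatrix} \epsilon_{\mu,t} \\ \epsilon_{\delta,t} \end{pmatrix} \tag{29.144}$$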


For the emission model we have

$$y_t = \underbrace{\begin{pmatrix} 1 & 0 \end{pmatrix}}_{H} z_t + \epsilon_{y,t} \tag{29.145}$$
We can also use an autoregressive model for the slope, i.e.,

$$\delta_t = D + \rho(\delta_{t-1} - D) + \epsilon_{\delta,t} \tag{29.146}$$

where $D$ is the long-run slope to which $\delta$ will revert. This is called a “semilocal linear trend” model, and is useful for longer term forecasts.
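Since the noise terms have zero mean, the $h$-step-ahead forecast mean of the local linear trend model is $H F^h z_t = \mu_t + h\delta_t$, a straight line. A NumPy check with hypothetical state values:

```python
import numpy as np

F = np.array([[1.0, 1.0],
              [0.0, 1.0]])      # transition matrix of Equation (29.144)
H = np.array([1.0, 0.0])        # emission vector of Equation (29.145)

def llt_forecast_mean(z_t, h):
    """E[y_{t+h} | z_t] = H F^h z_t; the zero-mean noise terms drop out."""
    return H @ np.linalg.matrix_power(F, h) @ z_t

z_t = np.array([10.0, 0.5])     # level mu_t = 10, slope delta_t = 0.5
print([float(llt_forecast_mean(z_t, h)) for h in range(1, 5)])   # [10.5, 11.0, 11.5, 12.0]
```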
29.12. STRUCTURAL TIME SERIES MODELS
Figure 29.32: (a) A BSTS model with local linear trend and linear regression on inputs. The observed output is $y_t$. The latent state vector is defined by $z_t = (\mu_t, \delta_t)$. The (static) parameters are $\theta = (\sigma_y, \sigma_\mu, \sigma_\delta, \beta)$. The covariates are $u_t$. (b) Adding a latent seasonal process (with $S = 4$ seasons). Parameter nodes are omitted for clarity.


29.12.2.3 Adding covariates 


We can easily include covariates $u_t$ into the model, to increase prediction accuracy. If we use a linear model, we have


$$y_t = \mu_t + \beta^\top u_t + \epsilon_{y,t} \tag{29.147}$$


See Figure 29.32a for an illustration of the local linear trend model with covariates. Note that, when forecasting into the future, we will need some way to predict the input values of future $u_{t+h}$; a simple approach is just to assume future inputs are the same as the present, $u_{t+h} = u_t$.


29.12.2.4 Modelling seasonality 


Many time series also exhibit seasonality, i.e., they fluctuate periodically. This can be modeled by creating a latent process consisting of a series of offset terms, $s_t$. To model cyclicity, we ensure that these sum to zero (on average) over a complete cycle of $S$ steps:


$$s_t = -\sum_{k=1}^{S-1} s_{t-k} + \epsilon_{s,t}, \quad \epsilon_{s,t} \sim \mathcal{N}(0, \sigma_s^2) \tag{29.148}$$

For example, for $S = 4$, we have $s_t = -(s_{t-1} + s_{t-2} + s_{t-3}) + \epsilon_{s,t}$. We can convert this to a first-order model by stacking the last $S - 1$ seasons into the state vector, as shown in Figure 29.32b.
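A NumPy sketch of the stacked seasonal transition matrix, with a noise-free rollout showing the period-$S$ behavior; the helper name and initial values are hypothetical.

```python
import numpy as np

def seasonal_transition(S):
    """(S-1) x (S-1) transition matrix for the state (s_t, s_{t-1}, ..., s_{t-S+2})."""
    F = np.zeros((S - 1, S - 1))
    F[0, :] = -1.0                 # s_t = -(s_{t-1} + ... + s_{t-S+1}), noise omitted
    F[1:, :-1] = np.eye(S - 2)     # shift the other seasons down by one slot
    return F

F = seasonal_transition(4)         # the top-left 3 x 3 block of Equation (29.149)
x = np.array([1.0, 2.0, -0.5])     # (s_0, s_{-1}, s_{-2})
seq = []
for _ in range(8):                 # noise-free rollout
    x = F @ x
    seq.append(float(x[0]))
print(seq)   # [-2.5, -0.5, 2.0, 1.0, -2.5, -0.5, 2.0, 1.0]: period 4, each cycle sums to 0
```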
29.12.2.5 Adding it all up


We can combine the various latent processes (local level, linear trend, and seasonal cycles) into a 
single linear-Gaussian SSM, because the sparse graph structure can be encoded by sparse matrices. 
More precisely, the transition model becomes 


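$$\underbrace{\begin{pmatrix} s_t \\ s_{t-1} \\ s_{t-2} \\ \mu_t \\ \delta_t \end{pmatrix}}_{z_t} = \underbrace{\begin{pmatrix} -1 & -1 & -1 & 0 & 0 \\ 1 & 0 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 & 0 \\ 0 & 0 & 0 & 1 & 1 \\ 0 & 0 & 0 & 0 & 1 \end{pmatrix}}_{F} \underbrace{\begin{pmatrix} s_{t-1} \\ s_{t-2} \\ s_{t-3} \\ \mu_{t-1} \\ \delta_{t-1} \end{pmatrix}}_{z_{t-1}} + \mathcal{N}\left(0, \mathrm{diag}([\sigma_s^2, 0, 0, \sigma_\mu^2, \sigma_\delta^2])\right) \tag{29.149}$$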


Having defined the model, we can use the Kalman filter to compute $p(z_t|y_{1:t})$, and then make predictions forwards in time by rolling forward in latent space, and then predicting the outputs:


$$p(y_{t+h}|y_{1:t}) = \int p(y_{t+h}|z_{t+h})\, p(z_{t+h}|y_{1:t})\, dz_{t+h} \tag{29.150}$$


This can be computed in closed form, as explained in Section 8.3.2. 
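Equations (29.149) and (29.150) can be combined into a small forecasting sketch: build the block-structured $F$, $H$ and $Q$, then roll a filtered belief forward with the Kalman prediction step. The filtered mean and covariance at time $t$ and all noise levels below are hypothetical values, and the helper names are ours.

```python
import numpy as np

def build_sts(S, sigma_s, sigma_mu, sigma_delta):
    """Seasonal + local linear trend STS with state (s_t, ..., s_{t-S+2}, mu_t, delta_t)."""
    n = S + 1                                    # (S - 1) seasonal slots + level + slope
    F = np.zeros((n, n))
    F[0, :S - 1] = -1.0                          # seasonal block, Equation (29.148)
    F[1:S - 1, :S - 2] = np.eye(S - 2)
    F[S - 1:, S - 1:] = [[1.0, 1.0], [0.0, 1.0]] # local linear trend block, Equation (29.144)
    Q = np.diag([sigma_s ** 2] + [0.0] * (S - 2) + [sigma_mu ** 2, sigma_delta ** 2])
    H = np.zeros(n)
    H[0] = 1.0                                   # current seasonal offset s_t
    H[S - 1] = 1.0                               # current level mu_t
    return F, H, Q

def forecast(m, P, F, H, Q, sigma_y, horizon):
    """Roll a filtered belief N(m, P) forward; return means/variances of y_{t+1..t+horizon}."""
    means, variances = [], []
    for _ in range(horizon):
        m = F @ m
        P = F @ P @ F.T + Q
        means.append(H @ m)
        variances.append(H @ P @ H + sigma_y ** 2)
    return np.array(means), np.array(variances)

F, H, Q = build_sts(S=4, sigma_s=0.1, sigma_mu=0.1, sigma_delta=0.01)
m = np.array([1.0, 2.0, -0.5, 10.0, 0.5])        # hypothetical filtered mean at time t
P = 0.01 * np.eye(5)                             # hypothetical filtered covariance
means, variances = forecast(m, P, F, H, Q, sigma_y=0.2, horizon=4)
print(means)   # [ 8.  10.5 13.5 13. ]: seasonal pattern on top of a rising trend
```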


29.12.3 Model fitting 


Once we have specified the form of the model, we need to learn the model parameters, $\theta = (D, R, Q)$, since $F$ and $H$ are fixed to the values specified by the structural blocks, and $B = 0$. Common approaches are based on maximum likelihood estimation (see Section 29.8), and Bayesian inference (see Section 29.8.4). The latter approach is known as Bayesian structural time series or BSTS modeling [SV14; QRJN18], and often uses the following conjugate prior:


$$p(\theta) = \mathrm{MNIW}(R, D)\, \mathrm{IW}(Q) \tag{29.151}$$


Alternatively, if there are a large number of covariates, we may use a sparsity-promoting prior (e.g., 
spike and slab, Section 15.2.4) for the regression coefficients D. 


29.12.4 Forecasting 


Once the parameters have been estimated on a historical dataset, we can perform inference on a new time series to compute $p(z_t|y_{1:t}, u_{1:t}, \theta)$ using the Kalman filter (Section 8.3.2). Given the current posterior, we can then “roll forward” in time to forecast future observations $h$ steps ahead by computing $p(y_{t+h}|y_{1:t}, u_{1:t+h}, \theta)$, as in Section 8.3.2.3. If the parameters are uncertain, we can sample from the posterior, $p(\theta|y_{1:t}, u_{1:t})$, and then perform Monte Carlo averaging of the forecasts.
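When forecasts from several posterior parameter samples are averaged, the resulting predictive distribution is a mixture. A NumPy sketch that summarizes an equally weighted mixture of Gaussian forecasts by its mean and variance via the law of total variance; the function name is hypothetical.

```python
import numpy as np

def combine_forecasts(means, variances):
    """Equally weighted mixture of per-sample Gaussian forecasts N(means[s], variances[s]).

    Row s comes from parameters theta_s ~ p(theta | y_{1:t}, u_{1:t}); columns index the
    forecast horizon. Returns the mixture mean and variance (law of total variance).
    """
    means = np.asarray(means, dtype=float)
    variances = np.asarray(variances, dtype=float)
    mean = means.mean(axis=0)
    var = variances.mean(axis=0) + means.var(axis=0)   # E[Var] + Var[E]
    return mean, var

mean, var = combine_forecasts([[1.0, 2.0], [3.0, 2.0]], [[1.0, 1.0], [1.0, 1.0]])
print(mean, var)   # [2. 2.] [2. 1.]: disagreement between samples widens the first forecast
```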


29.12.5 Examples 


In this section, we give various examples of STS models.
Figure 29.33: (a) CO2 levels from Mauna Loa. In orange we plot predictions for the most recent 10 years. (b) Underlying components for the STS model which was fit to Figure 29.33a. Generated by sts.ipynb.


29.12.5.1 Example: forecasting CO2 levels from Mauna Loa 


In this section, we fit an STS model to the monthly atmospheric CO2 readings from the Mauna Loa observatory in Hawaii.⁵ The data is from January 1966 to February 2019. We combine a local linear trend model with a seasonal model, where we assume the periodicity is $S = 12$, since the data is monthly (see Figure 29.33a). We fit the model to all the data except for the last 10 years using variational Bayes. The resulting posterior mean and standard deviations for the parameters are $\sigma_y = 0.169 \pm 0.008$, $\sigma_\mu = 0.159 \pm 0.017$, $\sigma_\delta = 0.009 \pm 0.003$, $\sigma_s = 0.038 \pm 0.008$. We can sample 10 parameter vectors from the posterior and then plug them in to create a distribution over forecasts. The results are shown in orange in Figure 29.33a. Finally, in Figure 29.33b, we plot the posterior mean values of the two latent components (linear trend and current seasonal value) over time. We see how the model has successfully decomposed the observed signal into a sum of two simpler signals. (See also Main Section 18.8.1 where we model this data using a GP.)


29.12.5.2 Example: forecasting (real-valued) electricity usage 


In this section, we consider a more complex example: forecasting electricity demand in Victoria, Australia, as a function of the previous value and the external temperature. (Remember that January is summer in Australia!) The hourly data from the first six weeks of 2014 is shown in Figure 29.34a.⁶

We fit an STS to this using 4 components: a seasonal hourly effect (period 24), a seasonal daily effect (period 7, with 24 steps per season), a linear regression on the temperature, and an autoregressive term on the observations themselves. We fit the model with variational inference. (This takes about a minute on a GPU.) We then draw 10 posterior samples and show the posterior predictive forecasts in Figure 29.34b. We see that the results are reasonable, but there is also considerable uncertainty.

We plot the individual components in Figure 29.35. Note that they have different vertical scales, 
reflecting their relative importance. We see that the regression on the external temperature is the 
most important effect. However, the hour of day effect is also quite significant, even after accounting 
for external temperature. The autoregressive effect is the most uncertain one, since it is responsible 


5. For details, see https://blog.tensorflow.org/2019/03/structural-time-series-modeling-in.html.
6. The data is from https://github.com/robjhyndman/fpp2-package.
Figure 29.34: (a) Hourly temperature and electricity demand in Victoria, Australia in 2014. (b) Electricity
Figure 29.35: Components of the electricity forecasts. Generated by sts.ipynb.
Figure 29.36: Anomaly detection in a time series. We plot the observed electricity data in blue and the predictions in orange. In gray, we plot the z-score at time $t$, given by $z_t = (y_t - \mu_t)/\sigma_t$, where $p(y_t|y_{1:t-1}, u_{1:t}) = \mathcal{N}(\mu_t, \sigma_t^2)$. Anomalous observations are defined as points where $z_t > 3$ and are marked with red crosses. Generated by sts.ipynb.


for modeling all of the residual variation in the data beyond what is accounted for by the observation 
noise. 

We can also use the model for anomaly detection. To do this, we compute the one-step-ahead predictive distributions, $p(y_t|y_{1:t-1}, u_{1:t})$, for each timestep $t$, and then flag all timesteps where the observation is improbable. The results are shown in Figure 29.36.
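Given the one-step-ahead predictive means and standard deviations, the flagging rule takes a few lines of NumPy. Flagging both tails ($|z_t| > 3$) is our assumption; the caption of Figure 29.36 states the rule as $z_t > 3$. The function name is hypothetical.

```python
import numpy as np

def flag_anomalies(y, pred_mean, pred_std, threshold=3.0):
    """z-scores of observations under one-step-ahead Gaussian predictions, plus flags."""
    z = (np.asarray(y, dtype=float) - pred_mean) / pred_std
    return z, np.abs(z) > threshold          # two-sided test (an assumption)

z, flags = flag_anomalies([0.5, -1.0, 7.0], pred_mean=np.zeros(3), pred_std=np.full(3, 2.0))
print(z, flags)   # [ 0.25 -0.5   3.5 ] [False False  True]
```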


29.12.5.3 Example: forecasting (integer valued) sales 


In Section 29.12.5.2, we used a linear Gaussian STS model to forecast electricity demand. However, for some problems, we have integer valued observations, e.g., for neural spike data (see Main Section 29.11.1), RNA-Seq data [LJY19], sales data, etc. Here we focus on the case of sales data, where $y_t \in \{0, 1, 2, \ldots\}$ is the number of units of some item that are sold on a given day. Predicting future values of $y_t$ is important for many businesses. (This problem is known as demand forecasting.)

We assume the observed counts are due to some latent demand, $z_t \in \mathbb{R}$. Hence we can use a model similar to Main Section 29.11.1, with a Poisson likelihood, except the linear dynamics are given by an STS model. In [SSF16; See+17], they consider a likelihood of the form $y_t \sim \mathrm{Poi}(y_t|g(d_t))$, where $d_t = z_t + u_t^\top w$ is the instantaneous latent demand, $u_t$ are the covariates that encode seasonal indicators (e.g., temporal distance from holidays), and $g(d) = e^d$ or $\log(1 + e^d)$ is the transfer function. The dynamics model is a local random walk term, $z_t = z_{t-1} + \alpha\mathcal{N}(0, 1)$, to capture serial correlation in the data.

However, sometimes we observe zero counts, $y_t = 0$, not because there is no demand, but because there is no supply (i.e., we are out of stock). If we do not model this properly, we may incorrectly infer that $z_t = 0$, thus underestimating demand, which may result in not ordering enough inventory for the future, further compounding the error.

One solution is to use a zero-inflated Poisson (ZIP) model [Lam92] for the likelihood. This is a mixture model of the form $p(y_t|d_t) = p_0\, \mathbb{I}(y_t = 0) + (1 - p_0)\,\mathrm{Poi}(y_t|e^{d_t})$, where $p_0$ is the probability of the first mixture component. It is also common to use a (possibly zero-inflated) negative binomial model (Section 2.2.1.4) as the likelihood. This is used in [Cha14; Sal+19b] for the demand forecasting problem. The disadvantage of these likelihoods is that they are not log-concave for $y_t = 0$, which
Figure 29.37: Visualization of a probabilistic demand forecast for a hypothetical product. Note the sudden spike near the Christmas holiday in December 2013. The black line denotes the actual demand. Green lines denote the model samples in the training range, while the red lines show the actual probabilistic forecast on data unseen by the model. The red bars at the bottom indicate out-of-stock events which can explain the observed zeros. From Figure 1 of [Bös+17]. Used with kind permission of Tim Januschowski.


Figure 29.38: Illustration of a hierarchical state-space model. 


22 complicates posterior inference. In particular, the Laplace approximation is a poor choice, since it 
22 may find a saddle point. In [SSF 16], they tackle this using a log-concave multi-stage likelihood, 
34 in which y = 0 is emitted with probability o(d?); otherwise y+ = 1 is emitted with probability (dt); 
33 otherwise y; = is emitted with probability Poi(d?). This generalizes the scheme in [SOB12]. 
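A quick numerical check makes the log-concavity issue concrete. This sketch, with an arbitrarily chosen $p_0 = 0.3$, evaluates the ZIP log-likelihood as a function of $d$ (using the Poisson rate $e^{d}$ from the mixture formula) and estimates its curvature by central finite differences on a grid.

```python
import math

# Minimal sketch; p0 and the grid are illustrative choices.

def log_poisson(y, rate):
    return y * math.log(rate) - rate - math.lgamma(y + 1)

def zip_log_lik(y, d, p0):
    """log p(y | d) for the zero-inflated Poisson with rate e^d."""
    rate = math.exp(d)
    if y == 0:
        # Both mixture components can produce a zero.
        return math.log(p0 + (1 - p0) * math.exp(-rate))
    return math.log(1 - p0) + log_poisson(y, rate)

def second_derivative(f, x, h=1e-3):
    """Central finite-difference estimate of f''(x)."""
    return (f(x + h) - 2 * f(x) + f(x - h)) / h**2

p0 = 0.3
grid = [i / 10 for i in range(-50, 51)]
curv_zero = [second_derivative(lambda d: zip_log_lik(0, d, p0), d) for d in grid]
curv_five = [second_derivative(lambda d: zip_log_lik(5, d, p0), d) for d in grid]
print(max(curv_zero) > 0)  # True: log p(0 | d) is not concave in d
print(max(curv_five) < 0)  # True: log p(5 | d) = const + 5d - e^d is concave
```

For $y_t > 0$ the log-likelihood is $y_t d - e^{d}$ plus a constant, which is concave; the positive curvature appears only at $y_t = 0$, where both mixture components can explain the observation.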


29.12.5.4 Example: hierarchical SSM for electoral panel data


Suppose we perform a survey for the US presidential elections. Let $N_t^j$ be the number of people who vote at time $t$ in state $j$, and let $y_t^j$ be the number of those people who vote Democrat. (We assume the remaining $N_t^j - y_t^j$ vote Republican.) It is natural to want to model the dependencies in this data both across time (longitudinally) and across space (this is an example of panel data).


We can do this using a hierarchical SSM, as illustrated in Figure 29.38. The top level Markov chain, $z_t^0$, models national-level trends, and the state-specific chains, $z_t^j$, model local “random effects”. In practice we would usually also include covariates at the national level, $u_t^0$, and state level, $u_t^j$.
29.12. STRUCTURAL TIME SERIES MODELS


Figure 29.39: Twin network state-space model for estimating causal impact of an intervention that occurs just after time step $n = 2$. We have $m = 4$ actual observations, denoted $y_{1:4}$. We cut the incoming arcs to $z_3$ since we assume $z_{3:T}$ comes from a different distribution, namely the post-intervention distribution. However, in the counterfactual world, shown at the bottom of the figure (with tilde symbols), we assume the distributions are the same as in the past, so information flows along the chain uninterrupted.


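Thus the model becomes

$$y_t^j \sim \text{Bin}(y_t^j | \pi_t^j, N_t^j) \tag{29.152}$$
$$\pi_t^j = \sigma\left( (z_t^0)^\top u_t^0 + (z_t^j)^\top u_t^j \right) \tag{29.153}$$
$$z_t^0 = z_{t-1}^0 + \mathcal{N}(0, \sigma^2 I) \tag{29.154}$$
$$z_t^j = z_{t-1}^j + \mathcal{N}(0, \tau^2 I) \tag{29.155}$$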


For more details, see [Lin13b]. 
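As a sanity check on the generative process in Equations (29.152)-(29.155), here is a minimal forward sampler. The intercept-only covariates, the noise scales `sigma` and `tau`, and the constant number of voters per state are illustrative assumptions.

```python
import math
import random

# Minimal sketch; helper names, covariates and noise scales are illustrative assumptions.

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def sample_binomial(n, p, rng):
    """Sum of n Bernoulli(p) draws; adequate for the modest counts used here."""
    return sum(rng.random() < p for _ in range(n))

def simulate_panel(T, J, N, u_nat, u_state, sigma=0.05, tau=0.1, seed=0):
    """Sample y[t][j] ~ Bin(pi[t][j], N[t][j]) from the hierarchical SSM.

    u_nat[t] is the national covariate vector u_t^0 and u_state[t][j] the state
    covariate vector u_t^j; both latent chains are Gaussian random walks.
    """
    rng = random.Random(seed)
    z_nat = [0.0] * len(u_nat[0])                            # national chain z_t^0
    z_st = [[0.0] * len(u_state[0][0]) for _ in range(J)]    # state chains z_t^j
    y = []
    for t in range(T):
        z_nat = [v + rng.gauss(0.0, sigma) for v in z_nat]
        z_st = [[v + rng.gauss(0.0, tau) for v in zj] for zj in z_st]
        national = sum(a * b for a, b in zip(z_nat, u_nat[t]))
        row = []
        for j in range(J):
            local = sum(a * b for a, b in zip(z_st[j], u_state[t][j]))
            row.append(sample_binomial(N[t][j], sigmoid(national + local), rng))
        y.append(row)
    return y

T, J = 10, 3
N = [[100] * J for _ in range(T)]
u_nat = [[1.0] for _ in range(T)]                        # intercept-only national covariate
u_state = [[[1.0] for _ in range(J)] for _ in range(T)]  # intercept-only state covariates
y = simulate_panel(T, J, N, u_nat, u_state)
```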


29.12.6 Causal impact of a time series intervention 


In this section, we discuss how to perform counterfactual reasoning about the effect of an intervention given some observational (non-experimental) time series data. (We discuss counterfactuals in more detail in Main Section 4.7.4.) For example, suppose $y_t$ is the click through rate (CTR) of the web page of some company at time $t$. The company launches an ad campaign at time $n$, and observes outcomes $y_{1:n}$ before the intervention and $y_{n+1:m}$ after the intervention. A natural question to ask is: what would the CTR have been had the company not run the ad campaign?

To answer this question, we will use a structural time series (STS) model (see Section 29.12 for details). An STS model is a linear-Gaussian state-space model, where arrows have a natural causal interpretation in terms of the arrow of time; thus an STS is a kind of structural equation model, and hence a structural causal model (see Main Section 4.7). The use of an SCM allows us to infer the latent state of the noise variables given the observed data; we can then “roll back time” to the point of intervention, where we explore an alternative “fork in the road” from the one we actually took by “rolling forward in time” in a new version of the model, using the twin network approach to counterfactual inference (see Main Section 4.7.4). This approach is known as “causal impact”, and was developed by econometricians at Google [Bro+15].

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


29.12.6.1 Basic idea 


To explain the idea in more detail, consider the twin network in Figure 29.39. The intervention occurs after time $n = 2$, and there are $m = 4$ observations in total. We observe 2 data points before the intervention, $y_{1:2}$, and 2 data points afterwards, $y_{3:4}$. We assume observations are generated by latent states $z_{1:4}$, which evolve over time. The states are subject to exogenous noise terms, which can represent any set of unmodeled factors, such as the state of the economy. In addition, we have exogenous covariates, $x_{1:m}$.

To predict what would have happened if we had not performed the intervention (an event denoted by $a = 0$), we replicate the part of the model that occurs after the intervention, and use it to make forecasts. The goal is to compute the counterfactual distribution, $p(\tilde{y}_{n+1:m} | y_{1:n}, x_{1:m})$, where $\tilde{y}_t$ represents counterfactual outcomes if the action had been $a = 0$. We can compute this counterfactual distribution as follows:

\begin{align}
p(\tilde{y}_{n+1:m} | y_{1:n}, x_{1:m}) = \int & p(\tilde{y}_{n+1:m} | \tilde{z}_{n+1:m}, x_{n+1:m}, \theta)\, p(\tilde{z}_{n+1:m} | z_n, \theta) \,\times \tag{29.156} \\
& p(z_n, \theta | x_{1:n}, y_{1:n})\, d\theta\, dz_n\, d\tilde{z}_{n+1:m} \tag{29.157}
\end{align}

where

$$p(z_n, \theta | x_{1:n}, y_{1:n}) = p(z_n | x_{1:n}, y_{1:n}, \theta)\, p(\theta | x_{1:n}, y_{1:n}) \tag{29.158}$$

For linear Gaussian SSMs, the term $p(z_n | x_{1:n}, y_{1:n}, \theta)$ can be computed using Kalman filtering (Main Section 8.3.2), and the term $p(\theta | x_{1:n}, y_{1:n})$ can be computed using MCMC or variational inference.
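For the local level model, the Kalman filtering step and the forward rollout of the twin network are short enough to write out directly. This sketch assumes the parameters $\theta = (q, r)$ are known and plugs them in, rather than averaging over $p(\theta | x_{1:n}, y_{1:n})$ with MCMC or VI as described above; it also omits covariates, and the prior $z_0 \sim \mathcal{N}(0, 10)$ is an assumption.

```python
import math
import random

# Minimal plug-in sketch; helper names, the prior on z_0, and (q, r) are illustrative assumptions.

def kalman_filter_local_level(y, q, r, m0=0.0, v0=10.0):
    """Mean and variance of p(z_n | y_{1:n}) for z_t = z_{t-1} + N(0, q), y_t = z_t + N(0, r)."""
    m, v = m0, v0
    for obs in y:
        v = v + q                 # predict: p(z_t | y_{1:t-1})
        k = v / (v + r)           # Kalman gain
        m = m + k * (obs - m)     # update with y_t
        v = (1.0 - k) * v
    return m, v

def sample_counterfactual(y_pre, horizon, q, r, num_samples=1000, seed=0):
    """Sample counterfactual outcomes for the `horizon` steps after the intervention."""
    rng = random.Random(seed)
    m, v = kalman_filter_local_level(y_pre, q, r)
    paths = []
    for _ in range(num_samples):
        z = rng.gauss(m, math.sqrt(v))                     # z_n ~ p(z_n | y_{1:n}, theta)
        path = []
        for _ in range(horizon):
            z = z + rng.gauss(0.0, math.sqrt(q))           # roll the untreated latent state forward
            path.append(z + rng.gauss(0.0, math.sqrt(r)))  # counterfactual observation
        paths.append(path)
    return paths

y_pre = [5.0] * 50
m, v = kalman_filter_local_level(y_pre, q=0.01, r=0.1)
cf = sample_counterfactual(y_pre, horizon=5, q=0.01, r=0.1)
```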


We can use samples from the above posterior predictive distribution to compute a Monte Carlo approximation to the distribution of the treatment effect per time step, $\tau_t^s = y_t - \tilde{y}_t^s$, where the $s$ index refers to posterior samples. We can also approximate the distribution of the cumulative causal impact using $\sigma_t^s = \sum_{i=n+1}^{t} \tau_i^s$. (There will be uncertainty in these quantities arising both from epistemic uncertainty, about the true parameters controlling the model, and aleatoric uncertainty, due to system and observation noise.)
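Given the counterfactual samples, the pointwise and cumulative effects are simple list manipulations. The helper names below are illustrative, and the interval uses a crude nearest-rank rule rather than interpolated quantiles.

```python
from itertools import accumulate

# Minimal sketch; helper names and the nearest-rank interval are illustrative choices.

def causal_effects(y_post, cf_paths):
    """For each sample s, return tau_t^s = y_t - y~_t^s and the running sum sigma_t^s."""
    taus = [[y - y_cf for y, y_cf in zip(y_post, path)] for path in cf_paths]
    sigmas = [list(accumulate(tau)) for tau in taus]
    return taus, sigmas

def summarize(samples, lo=0.025, hi=0.975):
    """Per-time-step posterior mean and a crude nearest-rank central interval."""
    out = []
    for t in range(len(samples[0])):
        col = sorted(s[t] for s in samples)
        n = len(col)
        out.append((sum(col) / n, col[round(lo * (n - 1))], col[round(hi * (n - 1))]))
    return out

# Toy example: two observed post-intervention outcomes and two counterfactual samples.
taus, sigmas = causal_effects([3.0, 3.0], [[1.0, 2.0], [2.0, 1.0]])
print(taus)    # [[2.0, 1.0], [1.0, 2.0]]
print(sigmas)  # [[2.0, 3.0], [1.0, 3.0]]
```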


The validity of the method is based on 3 assumptions: (1) Predictability: we assume that the outcome can be adequately predicted by our model given the data at hand. (We can check this by using backcasting, in which we make predictions on part of the historical data.) (2) Unaffectedness: we assume that the intervention does not change future covariates $x_{n+1:m}$. (We can potentially check this by running the method with each of the covariates as an outcome variable.) (3) Stability: we assume that, had the intervention not taken place, the model for the outcome in the pre-treatment period would have continued in the post-treatment period. (We can check this by seeing if we predict an effect if the treatment is shifted earlier in time.)
Figure 29.40: A graphical model representation of the local level causal impact model. The dotted line
represents the time n at which an intervention occurs. Adapted from Figure 2 of [Bro+15]. Used with kind 
permission of Kay Brodersen. 


29.12.6.2 Example 


As a concrete example, let us assume we have a local level model and we use linear regression to model the dependence on the covariates, as in Section 29.12.2.3. That is,

$$y_t = \mu_t + \beta^\top x_t + \mathcal{N}(0, \sigma_y^2) \tag{29.159}$$
$$\mu_t = \mu_{t-1} + \delta_t + \mathcal{N}(0, \sigma_\mu^2) \tag{29.160}$$
$$\delta_t = \delta_{t-1} + \mathcal{N}(0, \sigma_\delta^2) \tag{29.161}$$

See the graphical model in Figure 29.40. The static parameters of the model are $\theta = (\beta, \sigma_y^2, \sigma_\mu^2, \sigma_\delta^2)$; the other terms are state or observation variables. (Note that we are free to use any kind of STS model; the local level model is just a simple default.)

The use of a linear combination of other “donor” time series is similar in spirit to the concept of a “synthetic control” [Aba; Shi+21]. However, we do not restrict ourselves to a convex combination of donors. Furthermore, when we have many covariates, we can use a spike-and-slab prior (Main Section 15.2.4) or horseshoe prior (Main Section 15.2.6) to select the relevant ones.

We now give a simple example using synthetic data. We create 3 random sequences of covariates, $x_t \in \mathbb{R}^3$, and define the observed output to be given by Equation (29.159), with regression weights $\beta = (2, 3, 0)$. At time step $n = 70$, we perform an artificial intervention by adding $\Delta_t$ to $y_t$, where $\Delta_t$ starts off at 5 and drops down to 0 over a period of 10 steps. This simulates the kind of transient lift one often sees when performing marketing campaigns. The data is shown in Figure 29.41(a). We see that there seems to be a small increase in the blue curve $y$ at around time step 70, but it is hard to tell because of the noise.
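The synthetic dataset can be generated along the following lines. The text does not say how the covariates were produced or how noisy the data are, so the Gaussian random-walk covariates, the noise scales, and the linear decay of $\Delta_t$ from 5 to 0 are assumptions for illustration, not the settings used in causal_impact.ipynb.

```python
import random

# Minimal sketch; covariate process, noise scales and decay shape are illustrative assumptions.

def lift_profile(t, n=70, lift=5.0, ramp=10):
    """Delta_t: zero before n, equal to `lift` at t = n, decaying linearly to 0 by t = n + ramp."""
    if t < n:
        return 0.0
    return max(0.0, lift * (1.0 - (t - n) / ramp))

def make_synthetic(T=100, n=70, beta=(2.0, 3.0, 0.0), obs_sd=1.0, level_sd=0.1, seed=0):
    """Covariates x_t in R^3 and outcomes y_t = mu_t + beta^T x_t + noise + Delta_t."""
    rng = random.Random(seed)
    # Covariates as Gaussian random walks (the text only says they are random).
    x = [[rng.gauss(0.0, 1.0) for _ in range(3)]]
    for _ in range(T - 1):
        x.append([v + rng.gauss(0.0, 1.0) for v in x[-1]])
    mu, y = 0.0, []
    for t in range(T):
        mu += rng.gauss(0.0, level_sd)                      # local level
        y_t = mu + sum(b * v for b, v in zip(beta, x[t])) + rng.gauss(0.0, obs_sd)
        y.append(y_t + lift_profile(t, n))                  # artificial intervention
    return x, y

x, y = make_synthetic()
```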

We fit a local level STS model using variational inference, and then make the counterfactual forecast shown in the top row of Figure 29.41(b). Now we see more clearly that, had the process continued without the intervention, the counterfactual outcome would have been smaller. We can therefore estimate the instantaneous and cumulative causal impact, shown in the middle and bottom rows of Figure 29.41(b). Furthermore, we find that the posterior mean estimate for the regression coefficient is $\hat{\beta} = (1.4927632, 1.945029, -0.10319362)$. We see that the third input variable is mostly ignored, as desired.

Figure 29.41: (a) Some simulated time series data which we use to estimate the causal impact of some intervention, which occurs at time $n = 70$, illustrated by the dotted line. The blue curve shows the observed outcomes $y$; the other curves are covariates (inputs). (b) Output of causal inference. Top row: observed vs predicted outcomes. Middle row: estimate of causal effect $\tau_t$ at each time step. Bottom row: cumulative causal effect, $\sigma_t$, up to each time step. Generated by causal_impact.ipynb.


29.12.7 Prophet 


Prophet [TL18a] is a popular time series forecasting library from Facebook. It fits a generalized additive model of the form

$$y(t) = g(t) + s(t) + h(t) + w^\top x(t) + \epsilon(t) \tag{29.162}$$

where $g(t)$ is a trend function, $s(t)$ is a seasonal fluctuation (modeled using linear regression applied to a sinusoidal basis set), $h(t)$ is an optional set of sparse “holiday effects”, $x(t)$ are an optional set of (possibly lagged) covariates, $w$ are the regression coefficients, and $\epsilon(t)$ is the residual noise term, assumed to be iid Gaussian.

Prophet is a regression model, not an auto-regressive model, since it predicts the time series $y_{1:T}$ given the time stamp $t$ and the covariates $x_{1:T}$, but without conditioning on past observations of $y$. To model the dependence on time, the trend function is assumed to be a piecewise linear trend with $S$ changepoints, uniformly spaced in time. (See Main Section 29.5.6 for a discussion of changepoint detection.) That is, the model has the form

$$g(t) = (k + a(t)^\top \delta)\, t + (m + a(t)^\top \gamma) \tag{29.163}$$

where $k$ is the growth rate, $m$ is the offset, $a_j(t) = \mathbb{I}(t \geq s_j)$, where $s_j$ is the time of the $j$'th changepoint, $\delta_j \sim \text{Laplace}(\tau)$ is the magnitude of the change, and $\gamma_j = -s_j \delta_j$ to make the function continuous. The Laplace prior on $\delta$ ensures the MAP parameter estimate is sparse, so the change in growth rate at most changepoints is 0.
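Equation (29.163) and its continuity constraint can be checked directly. This is a standalone sketch of the trend formula, not a call into the Prophet library; the growth rate, offset, and single changepoint are made-up values.

```python
# Minimal sketch of the piecewise-linear trend; the function name and values are illustrative.

def prophet_trend(t, k, m, changepoints, deltas):
    """Piecewise-linear trend g(t) = (k + a(t)^T delta) t + (m + a(t)^T gamma)."""
    a = [1.0 if t >= s else 0.0 for s in changepoints]       # a_j(t) = I(t >= s_j)
    gamma = [-s * d for s, d in zip(changepoints, deltas)]    # gamma_j = -s_j delta_j
    rate = k + sum(aj * dj for aj, dj in zip(a, deltas))
    offset = m + sum(aj * gj for aj, gj in zip(a, gamma))
    return rate * t + offset

# Growth rate 1 before t = 5, then 3 afterwards; the curve stays continuous at t = 5.
print([prophet_trend(t, 1.0, 0.0, [5.0], [2.0]) for t in (4, 5, 6)])  # [4.0, 5.0, 8.0]
```

Without the $\gamma_j$ offsets, switching on $\delta_j$ at $t = s_j$ would add a jump of $\delta_j s_j$ to the trend; $\gamma_j = -s_j \delta_j$ cancels exactly that jump.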

For an interactive visualization of how Prophet works, see http://prophet.mbrouns.com/.


29.12.8 Neural forecasting methods 


Classical time series methods work well when there is little data (e.g., short sequences, or few 
covariates). However, in some cases, we have a lot of data. For example, we might have a single, but 
very long sequence, such as in anomaly detection from real-time sensors [Ahm+17]. Or we may have 
multiple, related sequences, such as sales of related products [Sal+19b]. In both cases, larger data
means we can afford to fit more complex parametric models. Neural networks are a natural choice, 
because of their flexibility. Until recently, their performance in forecasting tasks was not competitive 
with classical methods, but this has recently started to change, as described in [Ben+20; LZ20]. 

A common benchmark in the univariate time series forecasting literature is the M4 forecasting competition [MSA18], which requires participants to make forecasts on many different kinds of (univariate) time series (without covariates). This was won by a hybrid RNN-classical method called ES-RNN [Smy20]. The exponential smoothing (ES) part allows data-efficient adaptation to the observed past of the current time series; the recurrent neural network (RNN) part allows for learning of nonlinear components from multiple related time series. (This is known as a local+global model, since the ES part is “trained” just on the local time series, whereas the RNN is trained on a global dataset of related time series.)
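To make the “local” component concrete, here is a minimal level-only exponential smoothing recursion. ES-RNN uses a richer exponential smoothing formulation with seasonal terms whose coefficients are fitted per series, so treat this as a simplified sketch; the smoothing constant `alpha` is an assumed value.

```python
# Minimal level-only sketch; the function name and alpha are illustrative assumptions.

def simple_exponential_smoothing(y, alpha=0.3):
    """Levels l_t = alpha * y_t + (1 - alpha) * l_{t-1}, initialized at l_1 = y_1.

    The h-step-ahead forecast from this model is flat: yhat_{T+h} = l_T.
    """
    level = y[0]
    levels = [level]
    for obs in y[1:]:
        level = alpha * obs + (1 - alpha) * level
        levels.append(level)
    return levels

print(simple_exponential_smoothing([0.0, 10.0, 10.0], alpha=0.5))  # [0.0, 5.0, 7.5]
```

Because the level only needs the series' own history, it adapts quickly from a handful of points, which is the data-efficiency argument for pairing it with a globally trained RNN.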

In [Ran+18], they adopt a different approach for combining RNNs and classical methods, called DeepSSM. In particular, they train a single RNN to predict the parameters of a state-space model (see Main Section 29.1). In more detail, let $x^n_{1:T}$ represent the $n$'th time series, and let $\theta^n_t$ represent the non-stationary parameters of a linear-trend SSM model (see Section 29.12.1). We train an RNN to compute $\theta^n_t = f(c^n_{1:t}; \phi)$, where $c^n_{1:t}$ are the covariates and $\phi$ are the RNN parameters shared across all sequences. We can use the predicted parameters to compute the log likelihood of the sequence, $L_n = \log p(x^n_{1:T} | c^n_{1:T}, \theta^n_{1:T})$, using the Kalman filter. These two modules can be combined to allow for end-to-end training of $\phi$ to maximize $\sum_n L_n$.

In [Wan+19c], they propose a different hybrid model known as Deep Factors. The idea is to 
represent each time series (or its latent function, for non-Gaussian data) as a weighted sum of a 
global time series, coming from a neural model, and a stochastic local model, such as an SSM or GP. 
The DeepGLO (global-local) approach of [SYD19] proposes a related hybrid method, where the 
global model uses matrix factorization to learn shared factors. This is then combined with temporal 
convolutional networks. 

It is also possible to train a purely neural model, without resorting to classical methods. For 
example, the N-BEATS model of [Ore+20] trains a residual network to predict the weights of a set 
of basis functions, corresponding to a polynomial trend and a periodic signal. The weights for the 
basis functions are predicted for each window of input using the neural network. Another approach 
is the DeepAR model of [Sal+19b], which fits a single RNN to a large number of time series. The 
original paper used integer (count) time series, modeled with a negative binomial likelihood function. 
This is a unimodal distribution, which may not be suitable for all tasks. More flexible forms, such as mixtures of Gaussians, have also been proposed [Muk+18]. A popular alternative is to use quantile regression [Koe05], in which the model is trained to predict quantiles of the distribution. For example, [Gas+19] proposed SQF-RNN, which uses splines to represent the quantile function. They used the continuous ranked probability score (CRPS) as the loss function. This is a proper scoring rule, but it is less sensitive to outliers, and more “distance aware”, than log loss.
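The two losses mentioned above can be written in a few lines. The sketch below uses the standard pinball loss for a single quantile level and the identity $\text{CRPS}(F, y) = \mathbb{E}|X - y| - \frac{1}{2}\mathbb{E}|X - X'|$ with $X, X' \sim F$; representing the forecast distribution by samples is an assumption made for simplicity.

```python
# Minimal sketch; function names are illustrative, and the forecast is given as samples.

def pinball_loss(y, q_pred, alpha):
    """Quantile (pinball) loss for a predicted alpha-quantile q_pred of outcome y."""
    diff = y - q_pred
    return max(alpha * diff, (alpha - 1.0) * diff)

def crps_from_samples(samples, y):
    """Sample estimate of CRPS(F, y) = E|X - y| - 0.5 E|X - X'|, with X, X' ~ F.

    This simple O(n^2) estimator averages over all pairs, including X = X'.
    """
    n = len(samples)
    term1 = sum(abs(x - y) for x in samples) / n
    term2 = sum(abs(a - b) for a in samples for b in samples) / (n * n)
    return term1 - 0.5 * term2

print(crps_from_samples([0.0], 3.0))  # 3.0: for a point forecast, CRPS is the absolute error
```

The “distance aware” property is visible here: moving the forecast samples further from $y$ increases the first term linearly, whereas log loss only sees the density at the observed value.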

The above methods all predict a single output (per time step). If there are multiple simultaneous 
observations, it is best to try to model their interdependencies. In [Sal+19a], they use a (low-rank)
Gaussian copula for this, and in [Tou+19], they use a nonparametric copula. 

In [Wen+17], they simultaneously predict quantiles for multiple steps ahead using dilated causal 
convolution (or an RNN). They call their method MQ-CNN. In [WT19], they extend this to predict 
the full quantile function, taking as input the desired quantile level $\alpha$, rather than prespecifying
a fixed set of levels. They also use a copula to learn the dependencies among multiple univariate 
marginals. 


29.13 Deep SSMs


Traditional state-space models assume linear dynamics and linear observation models, both with additive Gaussian noise. This is obviously very limiting. In this section, we allow the dynamics and/or observation model to be modeled by nonlinear and/or non-Markovian deep neural networks; we call these deep state-space models, also known as dynamical variational autoencoders.
29.13. DEEP SSMS


(To be consistent with the literature on VAEs, we denote the observations by $x_t$ instead of $y_t$.) For a detailed review, see [Ged+20; Gir+21].


29.13.1 Deep Markov models 


Suppose we create an SSM in which we use a deep neural network for the dynamics model and/or observation model; the result is called a deep Markov model [KSS17] or stochastic RNN [BO14; Fra+16]. (This is not quite the same as a variational RNN, which we explain in Section 29.13.4.)

We can fit a DMM using SVI (Section 10.3.1). The key is to infer the posterior over the latents. From the first-order Markov properties, the exact posterior is given by

$$p(z_{1:T} | x_{1:T}) = \prod_{t=1}^{T} p(z_t | z_{t-1}, x_{1:T}) = \prod_{t=1}^{T} p(z_t | z_{t-1}, x_{t:T}) \tag{29.164}$$

where we define $p(z_1 | z_0, x_{1:T}) = p(z_1 | x_{1:T})$, and the cancelation follows since $z_t \perp x_{1:t-1} | z_{t-1}$, as pointed out in [KSS17].

In general, it is intractable to compute $p(z_{1:T} | x_{1:T})$, so we approximate it with an inference network. There are many choices for $q$. A simple one is a fully factorized model, $q(z_{1:T}) = \prod_t q(z_t | x_{1:t})$. This is illustrated in Figure 29.42a. Since $z_t$ only depends on past data, $x_{1:t}$ (which is accumulated in the RNN hidden state $h_t$), we can use this inference network at run time for online inference. However, for training the model offline, we can use a more accurate posterior by using

$$q(z_{1:T} | x_{1:T}) = \prod_{t=1}^{T} q(z_t | z_{t-1}, x_{1:T}) = \prod_{t=1}^{T} q(z_t | z_{t-1}, x_{1:t-1}, x_{t:T}) \tag{29.165}$$

Note that the dependence on past observations $x_{1:t-1}$ is already captured by $z_{t-1}$, as in Equation (29.164). The dependencies on future observations, $x_{t:T}$, can be summarized by a backwards RNN, as shown in Figure 29.42b. Thus


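$$q(z_{1:T}, h_{1:T} | x_{1:T}) = \prod_{t=T}^{1} \mathbb{I}(h_t = f(h_{t+1}, x_t)) \prod_{t=1}^{T} q(z_t | z_{t-1}, h_t) \tag{29.166}$$

Given a fully factored $q(z_{1:T})$, we can compute the ELBO as follows.

\begin{align}
\log p(x_{1:T}) &= \log \sum_{z_{1:T}} p(x_{1:T}, z_{1:T}) \tag{29.167} \\
&= \log \sum_{z_{1:T}} q(z_{1:T}) \frac{p(x_{1:T}, z_{1:T})}{q(z_{1:T})} \tag{29.168} \\
&= \log \mathbb{E}_{q(z_{1:T})}\left[ \prod_{t=1}^{T} \frac{p(x_t | z_t)\, p(z_t | z_{t-1})}{q(z_t)} \right] \tag{29.169} \\
&\geq \mathbb{E}_{q(z_{1:T})}\left[ \sum_{t=1}^{T} \log p(x_t | z_t) + \log p(z_t | z_{t-1}) - \log q(z_t) \right] \tag{29.170} \\
&= \sum_{t=1}^{T} \Big( \mathbb{E}_{q(z_t)}[\log p(x_t | z_t)] - \mathbb{E}_{q(z_{t-1})}\left[ D_{\mathbb{KL}}(q(z_t) \,\|\, p(z_t | z_{t-1})) \right] \Big) \tag{29.171}
\end{align}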
Figure 29.42: Inference networks for deep Markov model. (a) Fully factorized causal posterior $q(z_{1:T} | x_{1:T}) = \prod_t q(z_t | x_{1:t})$. The past observations $x_{1:t}$ are stored in the RNN hidden state $h_t$. (b) Markovian posterior $q(z_{1:T} | x_{1:T}) = \prod_t q(z_t | z_{t-1}, x_{t:T})$. The future observations $x_{t:T}$ are stored in the RNN hidden state $h_t$.


Figure 29.43: Recurrent state-space models. (a) Prior is first-order Markov, $p(z_t | z_{t-1})$, but observation model is not Markovian, $p(x_t | h_t) = p(x_t | z_{1:t})$, where $h_t$ summarizes $z_{1:t}$. (b) Prior model is no longer first-order Markov either, $p(z_t | h_{t-1}) = p(z_t | z_{1:t-1})$. Diamonds are deterministic nodes, circles are stochastic.


If we assume that the variational posteriors are jointly Gaussian, we can use the reparameterization trick to use posterior samples to compute stochastic gradients of the ELBO. Furthermore, since we assumed a Gaussian prior, the KL term can be computed analytically.
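The two ingredients just mentioned, reparameterized sampling and the analytic Gaussian KL, look like this for diagonal Gaussians. In a real DMM the means and variances would be outputs of neural networks and gradients would come from an autodiff library such as JAX; the linear transition $\mathcal{N}(a z_{t-1}, s_2 I)$ used in `kl_term_estimate` is a hypothetical stand-in for the learned dynamics model.

```python
import math
import random

# Minimal sketch; helper names and the linear transition are illustrative assumptions.

def reparam_sample(mu, log_var, rng):
    """z = mu + sigma * eps with eps ~ N(0, I): a deterministic, differentiable function of (mu, log_var)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]

def kl_diag_gauss(mu_q, var_q, mu_p, var_p):
    """Closed-form KL(N(mu_q, diag(var_q)) || N(mu_p, diag(var_p)))."""
    return 0.5 * sum(
        math.log(vp / vq) + (vq + (mq - mp) ** 2) / vp - 1.0
        for mq, vq, mp, vp in zip(mu_q, var_q, mu_p, var_p)
    )

def kl_term_estimate(mu_prev, log_var_prev, mu_t, var_t, a=0.9, s2=0.1, seed=0):
    """One-sample estimate of E_{q(z_{t-1})}[KL(q(z_t) || p(z_t | z_{t-1}))], assuming a
    hypothetical transition p(z_t | z_{t-1}) = N(a * z_{t-1}, s2 * I)."""
    rng = random.Random(seed)
    z_prev = reparam_sample(mu_prev, log_var_prev, rng)   # reparameterized sample of z_{t-1}
    prior_mu = [a * z for z in z_prev]
    return kl_diag_gauss(mu_t, var_t, prior_mu, [s2] * len(mu_t))

print(kl_diag_gauss([1.0], [1.0], [0.0], [1.0]))  # 0.5
```

Only the outer expectation over $z_{t-1}$ needs a Monte Carlo sample; the inner KL is exact, which typically gives lower-variance gradients than sampling both.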


29.13.2 Recurrent SSM


In a DMM, the observation model $p(x_t | z_t)$ is first-order Markov, as is the dynamics model $p(z_t | z_{t-1})$. We can modify the model so that it captures long-range dependencies by adding deterministic hidden states as well. We can make the observation model depend on $z_{1:t}$ instead of just $z_t$ by using $p(x_t | h_t)$, where $h_t = f(h_{t-1}, z_t)$, so $h_t$ records all the stochastic choices. This is illustrated in Figure 29.43a. We can also make the dynamical prior depend on $z_{1:t-1}$ by replacing $p(z_t | z_{t-1})$ with $p(z_t | h_{t-1})$, as is illustrated in Figure 29.43b. This is known as a recurrent SSM.

We can derive an inference network for an RSSM similar to the one we used for DMMs, except now we use a standard forwards RNN to compute $q(z_t | z_{1:t-1}, x_{1:t})$.
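The difference between the DMM and the RSSM is easiest to see in ancestral sampling. In the sketch below, `f_det`, `prior_mean`, and `obs_mean` are hypothetical scalar stand-ins for the neural networks, and the noise scales are arbitrary; the point is only the order of the updates: the prior for $z_t$ reads $h_{t-1}$, and the observation reads $h_t$, which depends on all of $z_{1:t}$.

```python
import math
import random

# Hypothetical scalar stand-ins for the neural networks of a recurrent SSM.

def f_det(h_prev, z):
    """Deterministic update h_t = f(h_{t-1}, z_t)."""
    return math.tanh(0.8 * h_prev + z)

def prior_mean(h_prev):
    """Mean of the dynamical prior p(z_t | h_{t-1})."""
    return 0.5 * h_prev

def obs_mean(h):
    """Mean of the observation model p(x_t | h_t)."""
    return 2.0 * h

def sample_rssm(T, z_sd=0.5, x_sd=0.1, seed=0):
    rng = random.Random(seed)
    h = 0.0
    zs, hs, xs = [], [], []
    for _ in range(T):
        z = rng.gauss(prior_mean(h), z_sd)   # z_t ~ p(z_t | h_{t-1})
        h = f_det(h, z)                      # h_t summarizes z_{1:t}
        x = rng.gauss(obs_mean(h), x_sd)     # x_t ~ p(x_t | h_t)
        zs.append(z)
        hs.append(h)
        xs.append(x)
    return zs, hs, xs

zs, hs, xs = sample_rssm(20)
```

Replacing `prior_mean(h)` with a function of the previous `z` alone, and `obs_mean(h)` with a function of the current `z` alone, recovers the first-order Markov DMM.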


(a) Standard variational bound (b) Observation overshooting (c) Latent overshooting

Figure 29.44: Unrolling schemes for SSMs. The label $z_{i|j}$ is shorthand for $p(z_i | x_{1:j})$. Solid lines denote the generative process, dashed lines the inference process. Arrows pointing at shaded circles represent log-likelihood loss terms. Wavy arrows indicate KL divergence loss terms. (a) Standard 1 step reconstruction of the observations. (b) Observation overshooting tries to predict future observations by unrolling in latent space. (c) Latent overshooting predicts future latent states and penalizes their KL divergence, but does not need to care about future observations. Adapted from Figure 3 of [Haf+19].


29.13.3 Improving multi-step predictions 


In Figure 29.44(a), we show the loss terms involved in the ELBO. In particular, the wavy edge $z_{t|t} \rightsquigarrow z_{t|t-1}$ corresponds to $\mathbb{E}_{q(z_{t-1})}[D_{\mathbb{KL}}(q(z_t) \,\|\, p(z_t | z_{t-1}))]$, and the solid edge $z_{t|t} \to x_t$ corresponds to $\mathbb{E}_{q(z_t)}[\log p(x_t | z_t)]$. We see that the dynamics model, $p(z_t | z_{t-1})$, is only ever penalized in terms of how it differs from the one-step-ahead posterior $q(z_t)$, which can hurt the ability of the model to make long-term predictions.

One solution to this is to make multi-step forward predictions using the dynamics model, and use 
these to reconstruct future observations, and add these errors as extra loss terms. This is called 
observation overshooting [Amo+18], and is illustrated in Figure 29.44(b). 

A faster approach, proposed in [Haf+19], is to apply a similar idea but in latent space. More precisely, let us compute the multi-step prediction model, by repeatedly applying the transition model and integrating out the intermediate states to get $p(z_t | z_{t-d})$. We can then compute the ELBO for this as follows:

\begin{align}
\log p_d(x_{1:T}) &\triangleq \log \int \prod_{t=1}^{T} p(z_t | z_{t-d})\, p(x_t | z_t)\, dz_{1:T} \tag{29.172} \\
&\geq \sum_{t=1}^{T} \Big( \mathbb{E}_{q(z_t)}[\log p(x_t | z_t)] - \mathbb{E}_{p(z_{t-1} | z_{t-d})\, q(z_{t-d})}\left[ D_{\mathbb{KL}}(q(z_t) \,\|\, p(z_t | z_{t-1})) \right] \Big) \tag{29.173}
\end{align}
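When the transition model is linear-Gaussian, the multi-step prior $p(z_t | z_{t-d})$ can be computed exactly by applying the one-step transition $d$ times. The scalar model $z_t = a z_{t-1} + \mathcal{N}(0, q)$ below is an assumption chosen so the integral is tractable; with a learned nonlinear transition, as in [Haf+19], the intermediate states generally have to be handled approximately, e.g. by sampling.

```python
# Minimal sketch for a hypothetical scalar linear-Gaussian transition.

def multi_step_prior(z_start, d, a, q):
    """Mean and variance of p(z_t | z_{t-d} = z_start) for z_t = a * z_{t-1} + N(0, q).

    Each loop iteration applies one transition and integrates out the previous
    state analytically, which is possible because the model is linear-Gaussian.
    """
    mean, var = z_start, 0.0
    for _ in range(d):
        mean = a * mean
        var = a * a * var + q
    return mean, var

# Agrees with the closed form: mean a^d z, variance q * (1 + a^2 + ... + a^(2(d-1))).
print(multi_step_prior(1.0, 2, 0.5, 1.0))  # (0.25, 1.25)
```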


To train the model so it is good at predicting at different future horizon depths $d$, we can average the above over all $1 \leq d \leq D$. However, for computational reasons, we can instead just average the KL terms, using weights $\beta_d$. This is called latent overshooting [Haf+19], and is illustrated in Figure 29.44(c). The new objective becomes

\begin{align}
\frac{1}{D} \sum_{d=1}^{D} \log p_d(x_{1:T}) \geq \sum_{t=1}^{T} \Big( & \mathbb{E}_{q(z_t)}[\log p(x_t | z_t)] \tag{29.174} \\
& - \frac{1}{D} \sum_{d=1}^{D} \beta_d\, \mathbb{E}_{p(z_{t-1} | z_{t-d})\, q(z_{t-d})}\left[ D_{\mathbb{KL}}(q(z_t) \,\|\, p(z_t | z_{t-1})) \right] \Big) \tag{29.175}
\end{align}

Figure 29.45: Variational RNN. (a) Generative model. (b) Inference model. The diamond-shaped nodes are deterministic.

29.13.4 Variational RNNs
A variational RNN (VRNN) [Chu+15] is similar to a recurrent SSM except the hidden states are generated conditional on all past hidden states and all past observations, rather than just the past hidden states. This is a more expressive model, but is slower to use for forecasting, since unrolling into the future requires generating observations $x_{t+1}, x_{t+2}, \ldots$ to “feed into” the hidden states, which control the dynamics. This makes the model less useful for forecasting and model-based RL (see Section 35.4.5.2).
More precisely, the generative model is as follows:

$$p(x_{1:T}, z_{1:T}, h_{1:T}) = \prod_{t=1}^{T} p(z_t | h_{t-1}, x_{t-1})\, \mathbb{I}(h_t = f(h_{t-1}, x_{t-1}, z_t))\, p(x_t | h_t) \tag{29.176}$$

where $p(z_1 | h_0, x_0) = p(z_1)$ and $h_1 = f(h_0, x_0, z_1) = f(z_1)$. Thus $h_t = (z_{1:t}, x_{1:t-1})$ is a summary of the past observations and past and current stochastic latent samples. If we marginalize out these deterministic hidden nodes, we see that the dynamical prior on the stochastic latents is $p(z_t | h_{t-1}, x_{t-1}) = p(z_t | z_{1:t-1}, x_{1:t-1})$, whereas in a DMM, it is $p(z_t | z_{t-1})$, and in an RSSM, it is $p(z_t | z_{1:t-1})$. See Figure 29.45a for an illustration.


We can train VRNNs using SVI. In [Chu+15], they use the following inference network:

$$q(z_{1:T}, h_{1:T} | x_{1:T}) = \prod_{t=1}^{T} \mathbb{I}(h_t = f(h_{t-1}, z_{t-1}, x_t))\, q(z_t | h_t) \tag{29.177}$$

Thus $h_t = (z_{1:t-1}, x_{1:t})$. Marginalizing out these deterministic nodes, we see that the filtered posterior has the form $q(z_{1:T} | x_{1:T}) = \prod_t q(z_t | z_{1:t-1}, x_{1:t})$. See Figure 29.45b for an illustration. (We can also optionally replace $x_t$ with the output of a bidirectional RNN to get the smoothed posterior, $q(z_{1:T} | x_{1:T}) = \prod_t q(z_t | z_{1:t-1}, x_{1:T})$.)

This approach was used in [DF18] to generate simple videos of moving objects (e.g., a robot 
pushing a block); they call their method stochastic video generation or SVG. This was scaled 
up in [Vil+19], using simpler but larger architectures. 
30 Graph learning


30.1 Introduction 


Graphs are a very common way to represent data. In this chapter we discuss probability models for graphs. In Section 30.2, we assume the graph structure G is known, but we want to “explain” it in terms of a set of meaningful latent features; for this we use various kinds of latent variable models. In Section 30.3, we assume the graph structure G is unknown and needs to be inferred from correlated data, $x_n \in \mathbb{R}^D$; for this, we will use probabilistic graphical models with unknown topology. (See also Section 16.3.6, where we discuss graph neural networks, for performing supervised learning using graph-structured data.)


30.2 Latent variable models for graphs 


Graphs arise in many application areas, such as modeling social networks, protein-protein interaction 
networks, or patterns of disease transmission between people or animals. There are usually two 
primary goals when analysing such data: first, try to discover some “interesting structure” in the 
graph, such as clusters or communities; second, try to predict which links might occur in the future 
(e.g., who will make friends with whom). In this section, we focus on the former. More precisely, we 
will consider a variety of latent variable models for observed graphs. 


30.2.1 Stochastic block model 


In Figure 30.1(a) we show a directed graph on 9 nodes. There is no apparent structure. However, if we look more deeply, we see it is possible to partition the nodes into three groups or blocks, $B_1 = \{1, 4, 6\}$, $B_2 = \{2, 3, 5, 8\}$, and $B_3 = \{7, 9\}$, such that most of the connections go from nodes in $B_1$ to $B_2$, or from $B_2$ to $B_3$, or from $B_3$ to $B_1$. This is illustrated in Figure 30.1(b).

The problem is easier to understand if we plot the adjacency matrices. Figure 30.2(a) shows the 
matrix for the graph with the nodes in their original ordering. Figure 30.2(b) shows the matrix for 
the graph with the nodes in their permuted ordering. It is clear that there is block structure. 

We can make a generative model of block structured graphs as follows. First, for every node, sample a latent block $q_i \sim \text{Cat}(\pi)$, where $\pi_k$ is the probability of choosing block $k$, for $k = 1:K$. Second, choose the probability of connecting group $a$ to group $b$, for all pairs of groups; let us denote this probability by $\eta_{a,b}$. This can come from a beta prior. Finally, generate each edge $R_{ij}$ using the following model:

$$p(R_{ij} = r | q_i = a, q_j = b, \eta) = \text{Ber}(r | \eta_{a,b}) \tag{30.1}$$

Figure 30.1: (a) A directed graph. (b) The same graph, with the nodes partitioned into 3 groups, making the block structure more apparent.

Figure 30.2: (a) Adjacency matrix for the graph in Figure 30.1(a). (b) Rows and columns are shown permuted to show the block structure. We also show how the stochastic block model can generate this graph. From Figure 1 of [Kem+06]. Used with kind permission of Charles Kemp.


This is called the stochastic block model [NS01]. Figure 30.4(a) illustrates the model as a DGM, and Figure 30.2 illustrates how this model can be used to cluster the nodes in our example.

Note that this is quite different from a conventional clustering problem. For example, we see that all the nodes in block 3 are grouped together, even though there are no connections between them. What they share is the property that they “like to” connect to nodes in block 1, and to receive connections from nodes in block 2. Figure 30.3 illustrates the power of the model for generating many different kinds of graph structure. For example, some social networks have hierarchical structure, which can be modeled by clustering people into different social strata, whereas others consist of a set of cliques.
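The generative process of the stochastic block model translates directly into a sampler. In this sketch the connection probabilities `eta` are fixed by hand (rather than drawn from a beta prior) to mimic the ring-like pattern of the 9-node example, and self-loops are excluded; both are illustrative choices.

```python
import random

# Minimal sketch; the helper name, eta values and no-self-loop rule are illustrative assumptions.

def sample_sbm(num_nodes, pi, eta, seed=0):
    """Sample q_i ~ Cat(pi) for each node, then each directed edge R_ij ~ Ber(eta[q_i][q_j])."""
    rng = random.Random(seed)
    K = len(pi)
    q = rng.choices(range(K), weights=pi, k=num_nodes)
    R = [[0] * num_nodes for _ in range(num_nodes)]
    for i in range(num_nodes):
        for j in range(num_nodes):
            if i != j and rng.random() < eta[q[i]][q[j]]:   # self-loops are excluded here
                R[i][j] = 1
    return q, R

# Ring-shaped block structure like the example above (0-indexed blocks):
# block 0 -> block 1 -> block 2 -> block 0 with high probability, other pairs rarely.
eta = [[0.05, 0.9, 0.05],
       [0.05, 0.05, 0.9],
       [0.9, 0.05, 0.05]]
q, R = sample_sbm(9, [1 / 3, 1 / 3, 1 / 3], eta)
```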

Figure 30.3: Some examples of graphs generated using the stochastic block model with different kinds of connectivity patterns between the blocks. The abstract graph (between blocks) represents a ring, a dominance hierarchy, a common-cause structure, and a common-effect structure. From Figure 4 of [Kem+10]. Used with kind permission of Charles Kemp.


Unlike a standard mixture model, it is not possible to fit this model using exact EM, because all the latent $q_i$ variables become correlated. However, one can use variational EM [Air+08], collapsed Gibbs sampling [Kem+06], etc. We omit the details (which are similar to the LDA case).

In [Kem+06], they lifted the restriction that the number of blocks $K$ be fixed, by replacing the Dirichlet prior on $\pi$ by a Dirichlet process (see Section 31.2). This is known as the infinite relational model. See Section 30.2.3 for details.

If we have features associated with each node, we can make a discriminative version of this model, for example by defining

$$p(R_{ij} = r | q_i = a, q_j = b, x_i, x_j, \theta) = \text{Ber}(r | \sigma(w_{a,b}^\top f(x_i, x_j))) \tag{30.2}$$

where $f(x_i, x_j)$ is some way of combining the feature vectors, and the sigmoid $\sigma$ maps the linear score to a probability. For example, we could use concatenation, $[x_i, x_j]$, or elementwise product $x_i \odot x_j$, as in supervised LDA. The overall model is like a relational extension of the mixture of experts model.


30.2.2 Mixed membership stochastic block model 


In [Air+08], they lifted the restriction that each node only belong to one cluster. That is, they 
replaced q; € {1,...,&} with m; € Sx. This is known as the mixed membership stochastic 
block model, and is similar in spirit to fuzzy clustering or soft clustering. Note that Tik is 
not the same as p(z; = k|D); the former represents ontological uncertainty (to what degree does 
each object belong to a cluster) whereas the latter represents epistemological uncertainty (which 
cluster does an object belong to). If we want to combine epistemological and ontological uncertainty, 
we can compute p(7;|D). 


Figure 30.5: (a) Who-likes-whom graph for Sampson’s monks. (b) Mixed membership of each monk in one of three groups. From Figures 2-3 of [Air+08]. Used with kind permission of Edo Airoldi.

In more detail, the generative process is as follows. First, each node picks a distribution over blocks, $\pi_i \sim \text{Dir}(\alpha)$. Second, choose the probability of connecting group $a$ to group $b$, for all pairs of groups, $\eta_{a,b} \sim \text{Beta}(\alpha, \beta)$. Third, for each edge, sample two discrete variables, one for each direction:

$$q_{i \to j} \sim \text{Cat}(\pi_i), \quad q_{i \leftarrow j} \sim \text{Cat}(\pi_j) \tag{30.3}$$

Finally, generate each edge $R_{ij}$ using the following model:

$$p(R_{ij} = r | q_{i \to j} = a, q_{i \leftarrow j} = b, \eta) = \text{Ber}(r | \eta_{a,b}) \tag{30.4}$$

See Figure 30.4(b) for the DGM.

Unlike the regular stochastic block model, each node can play a different role, depending on who it is connecting to. As an illustration of this, we will consider a data set that is widely used in the social networks analysis literature. The data concerns who-likes-whom amongst a group of 18 monks. It was collected by hand in 1968 by Sampson [Sam68] over a period of months. (These days, in the era of social media such as Facebook, a social network with only 18 people is trivially small, but the methods we are discussing can be made to scale.) Figure 30.5(a) plots the raw data, and Figure 30.5(b) plots $\mathbb{E}[\pi_i]$ for each monk, where $K = 3$. We see that most of the monks belong to one of the three clusters, known as the “young turks”, the “outcasts” and the “loyal opposition”. However, some individuals, notably monk 15, belong to two clusters; Sampson called these monks the “waverers”. It is interesting to see that the model can recover the same kinds of insights as Sampson derived by hand.

30.2. LATENT VARIABLE MODELS FOR GRAPHS

One prevalent problem in social network analysis is missing data. For example, if $R_{ij} = 0$, it may be due to the fact that persons $i$ and $j$ have not had an opportunity to interact, or that data is not available for that interaction, as opposed to the fact that these people don’t want to interact. In other words, absence of evidence is not evidence of absence. We can model this by modifying the observation model so that with probability $\rho$ we generate a 0 from the background model, and we only force the model to explain observed 0s with probability $1 - \rho$. In other words, we robustify the observation model to allow for outliers, as follows:


$$p(R_{ij} = r \mid q_{i \to j} = a, q_{i \leftarrow j} = b, \eta) = \rho\, \delta_0(r) + (1 - \rho)\, \mathrm{Ber}(r \mid \eta_{a,b}) \qquad (30.5)$$


See [Air+08] for details. 
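To see what the robustified model in Equation (30.5) buys us, the following sketch (with a hypothetical helper `robust_edge_prob`) evaluates it for a single pair of blocks. With $\eta_{a,b} = 0.9$ and $\rho = 0.3$, an observed 0 has probability $0.3 + 0.7 \times 0.1 = 0.37$ instead of $0.1$, so a missing link between strongly connected blocks is penalized much less.

```python
def robust_edge_prob(r, eta_ab, rho):
    """p(R_ij = r | q_{i->j} = a, q_{i<-j} = b, eta) under the robust model (30.5).

    With probability rho the entry comes from a background model that always
    emits 0; otherwise it is Bernoulli with the block probability eta_ab.
    """
    if r not in (0, 1):
        raise ValueError("r must be 0 or 1")
    delta0 = 1.0 if r == 0 else 0.0                  # background model: delta_0(r)
    bernoulli = eta_ab if r == 1 else 1.0 - eta_ab   # Ber(r | eta_ab)
    return rho * delta0 + (1.0 - rho) * bernoulli

print(robust_edge_prob(0, eta_ab=0.9, rho=0.3), robust_edge_prob(1, eta_ab=0.9, rho=0.3))
```

Setting `rho=0` recovers the original Bernoulli observation model.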


30.2.3 Infinite relational model 


The stochastic block model is defined for graphs, in which each pair of nodes may or may not have an edge. We can easily extend this to hyper-graphs, which is useful for modeling relational data. For example, suppose we want to model a family tree. We might write $R_1(i, j, k) = 1$ if adults $i$ and $j$ are the parents of child $k$, where $R_1$ is the “parent-of” relation. Here $i$ and $j$ are entities of type $T^1$ (adults), and $k$ is an entity of type $T^2$ (child), so the type signature of $R_1$ is $T^1 \times T^1 \times T^2 \to \{0, 1\}$.

To define the probability of relations holding between entities, we can associate a latent cluster variable $q_i^t \in \{1, \ldots, K_t\}$ with each entity $i$ of each type $t$. We then define the probability of the relation holding between specific entities by looking up the probability of the relation holding between the corresponding entity clusters. Continuing our example above, we have


$$p(R_1(i, j, k) \mid q_i^1 = a, q_j^1 = b, q_k^2 = c, \eta) = \mathrm{Ber}(R_1(i, j, k) \mid \eta_{a,b,c}) \qquad (30.6)$$
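As a sketch of how Equation (30.6) turns cluster assignments into a likelihood, the following NumPy function looks up $\eta_{q_i^1, q_j^1, q_k^2}$ for every triple by broadcasting and sums the Bernoulli log probabilities. The function name and array layout are our own assumptions, and every entry of `eta` is assumed to lie strictly between 0 and 1.

```python
import numpy as np

def irm_bernoulli_loglik(R, q1, q2, eta):
    """Log-likelihood of a binary ternary relation under Eq. (30.6) (sketch).

    R:   (n1, n1, n2) binary array with R[i, j, k] = R_1(i, j, k).
    q1:  (n1,) integer cluster ids for entities of type T^1 (e.g., adults).
    q2:  (n2,) integer cluster ids for entities of type T^2 (e.g., children).
    eta: (K1, K1, K2) Bernoulli parameters eta[a, b, c], assumed in (0, 1).
    """
    # p[i, j, k] = eta[q1[i], q1[j], q2[k]] via broadcasting of the index arrays.
    p = eta[q1[:, None, None], q1[None, :, None], q2[None, None, :]]
    return float(np.sum(R * np.log(p) + (1 - R) * np.log1p(-p)))

eta = np.array([[0.9, 0.1], [0.2, 0.8]]).reshape(2, 2, 1)
R = np.array([[1, 0], [0, 1]]).reshape(2, 2, 1)
print(irm_bernoulli_loglik(R, np.array([0, 1]), np.array([0]), eta))
```

When several relations are modeled jointly and treated as conditionally independent given the cluster assignments, the total log-likelihood is simply the sum of one such term per relation.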


We can also have real-valued relations, where each edge has a weight. For example, we can write 


$$p(R_1(i, j, k) \mid q_i^1 = a, q_j^1 = b, q_k^2 = c, \mu) = \mathcal{N}(R_1(i, j, k) \mid \mu_{a,b,c}, \sigma^2) \qquad (30.7)$$


where $\mu_{a,b,c}$ captures the average response for that group of clusters. We can also add entity-specific offset terms:


$$p(R_1(i, j, k) \mid q_i^1 = a, q_j^1 = b, q_k^2 = c, \mu) = \mathcal{N}(R_1(i, j, k) \mid \mu_{a,b,c} + \mu_i + \mu_j + \mu_k, \sigma^2) \qquad (30.8)$$


This model was proposed in [BBM07], where it was fit using an alternating minimization procedure.

If we allow the number of clusters $K_t$ for each type of entity to be unbounded, by using a Dirichlet process, the model is called the infinite relational model (IRM) [Kem+06], also known as an infinite hidden relational model (IHRM) [Xu+06]. We can fit this model with variational Bayes [Xu+06; Xu+07] or collapsed Gibbs sampling [Kem+06]. Rather than go into algorithmic detail, we just sketch some interesting applications.
Figure 30.6: Illustration of an ontology learned by IRM applied to the Unified Medical Language System. The boxes represent 7 of the 14 concept clusters. Predicates that belong to the same cluster are grouped together, and associated with edges to which they pertain. All links with weight above 0.8 have been included. From Figure 9 of [Kem+10]. Used with kind permission of Charles Kemp.


30.2.3.1 Learning ontologies


An ontology refers to an organisation of knowledge. In AI, ontologies are often built by hand (see e.g., [RN10]), but it is interesting to try to learn them from data. In [Kem+06], they show how this can be done using the IRM.

The data comes from the Unified Medical Language System [McC03], which defines a semantic network with 135 concepts (such as “disease or syndrome”, “diagnostic procedure”, “animal”), and 49 binary predicates (such as “affects”, “prevents”). We can represent this as a ternary relation $R: T^1 \times T^1 \times T^2 \to \{0, 1\}$, where $T^1$ is the set of concepts and $T^2$ is the set of binary predicates. The result is a 3d cube. We can then apply the IRM to partition the cube into regions of roughly homogeneous response. The system found 14 concept clusters and 21 predicate clusters. Some of these are shown in Figure 30.6. The system learns, for example, that biological functions affect organisms (since $\eta_{a,b,c} \approx 1$, where $a$ represents the biological function cluster, $b$ represents the organism cluster, and $c$ represents the affects cluster).

30.2.3.2 Clustering based on relations and features


We can also use IRM to cluster objects based on their relations and their features. For example, [Kem+06] consider a political dataset (from 1965) consisting of 14 countries, 54 binary predicates representing interaction types between countries (e.g., “sends tourists to”, “economic aid”), and 90 features (e.g., “communist”, “monarchy”). To create a binary dataset, real-valued features were thresholded at their mean, and categorical variables were dummy-encoded. The data has 3 types: $T^1$ represents countries, $T^2$ represents interactions, and $T^3$ represents features. We have two relations: $R^1: T^1 \times T^1 \times T^2 \to \{0, 1\}$, and $R^2: T^1 \times T^3 \to \{0, 1\}$. (This problem therefore combines aspects of both the biclustering model and the ontology discovery model.) When given multiple relations, the IRM treats them as conditionally independent. In this case, we have

$$p(R^1, R^2 \mid q^1, q^2, q^3, \theta) = p(R^1 \mid q^1, q^2, \theta)\, p(R^2 \mid q^1, q^3, \theta) \qquad (30.9)$$

The results are shown in Figure 30.7. The IRM divides the 90 features into 5 clusters, the first of which contains “noncommunist”, which captures one of the most important aspects of this Cold War era dataset. It also clusters the 14 countries into 5 clusters, reflecting natural geo-political groupings (e.g., US and UK, or the Communist Bloc), and the 54 predicates into 18 clusters, reflecting similar relationships (e.g., “negative behavior” and “accusations”).

Figure 30.7: Illustration of IRM applied to some political data containing features and pairwise interactions. Top row (a): the partition of the countries into 5 clusters and the features into 5 clusters. Every second column is labelled with the name of the corresponding feature. Small squares at bottom (b-i): these are 8 of the 18 clusters of interaction types. From Figure 6 of [Kem+06]. Used with kind permission of Charles Kemp.


30.3 Graphical model structure learning 


In this section, we discuss how to learn the structure of a probabilistic graphical model given sample observations of some or all of its nodes. That is, the input is an $N \times D$ data matrix, and the output is a graph $G$ (directed or undirected) with $N_G$ nodes. (Usually $N_G = D$, but we also consider the case where we learn extra latent nodes that are not present in the input.)
Figure 30.8: A sparse undirected Gaussian graphical model learned using graphical lasso applied to some flow cytometry data (from [Sac+05]), which measures the phosphorylation status of 11 proteins. The sparsity level is controlled by $\lambda$. (a) $\lambda = 36$. (b) $\lambda = 27$. (c) $\lambda = 7$. (d) $\lambda = 0$. Adapted from Figure 17.5 of [HTF09]. Generated by ggm_lasso_demo.ipynb.

30.3.1 Applications 


There are three main reasons to perform structure learning for PGMs: understanding, prediction, and causal inference (which involves both understanding and prediction), as we summarize below.

Learning sparse PGMs can be useful for gaining an understanding of multiple interacting variables. For example, consider a problem that arises in systems biology: we measure the phosphorylation status of some proteins in a cell [Sac+05] and want to infer how they interact. Figure 30.8 gives an example of a graph structure that was learned from this data, using a method called graphical lasso [FHT08; MH12], which is explained in Supplementary Section 30.3.2. As another example, [Smi+06] showed that one can recover the neural “wiring diagram” of a certain kind of bird from multivariate time-series EEG data. The recovered structure closely matched the known functional connectivity of this part of the bird brain.
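The graphical lasso estimates a sparse precision matrix $\Theta$ by minimizing $-\log\det\Theta + \mathrm{tr}(S\Theta) + \lambda \sum_{i \neq j} |\Theta_{ij}|$, where $S$ is the empirical covariance; zeros in $\Theta$ correspond to missing edges in the undirected graph. The sketch below solves this objective with ADMM rather than the block coordinate descent of [FHT08]. It is an illustration under our own assumptions (only off-diagonal entries penalized, fixed step size `rho`, fixed iteration count), not the code behind Figure 30.8.

```python
import numpy as np

def graphical_lasso_admm(S, lam, rho=1.0, num_iters=500):
    """ADMM sketch for the graphical lasso with an off-diagonal L1 penalty.

    Splits Theta = Z; returns Z, whose off-diagonal zeros encode the graph.
    """
    D = S.shape[0]
    Z = np.eye(D)
    U = np.zeros((D, D))
    off = ~np.eye(D, dtype=bool)
    for _ in range(num_iters):
        # Theta-update: solve rho*Theta - Theta^{-1} = rho*(Z - U) - S
        # eigenvalue-by-eigenvalue, which keeps Theta positive definite.
        evals, Q = np.linalg.eigh(rho * (Z - U) - S)
        theta_evals = (evals + np.sqrt(evals**2 + 4 * rho)) / (2 * rho)
        Theta = (Q * theta_evals) @ Q.T
        # Z-update: soft-threshold the off-diagonal entries, keep the diagonal.
        A = Theta + U
        Z = A.copy()
        Z[off] = np.sign(A[off]) * np.maximum(np.abs(A[off]) - lam / rho, 0.0)
        # Scaled dual update.
        U = U + Theta - Z
    return Z
```

Larger values of `lam` drive more off-diagonal entries exactly to zero, which is the effect shown across the panels of Figure 30.8.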

In some cases, we are not interested in interpreting the graph structure; we just want to use it to make predictions. One example of this is in financial portfolio management, where accurate models of the covariance between large numbers of different stocks are important. [CW07] show that by learning a sparse graph, and then using this as the basis of a trading strategy, it is possible to outperform (i.e., make more money than) methods that do not exploit sparse graphs. Another example is predicting traffic jams on the freeway. [Hor+05] describe a deployed system called JamBayes for predicting traffic flow in the Seattle area, using a directed graphical model whose structure was learned from data.

Structure learning is also an important prerequisite for causal inference. In particular, to predict the effects of interventions on a system, or to perform counterfactual reasoning, we need to know the structural causal model (SCM), as we discuss in Section 4.7. An SCM is a kind of directed graphical model where the relationships between nodes are deterministic (functional), except for stochastic root (exogenous) variables. Consequently, one can use techniques for learning DAG structures as a way to learn SCMs, if we make some assumptions about (lack of) confounders. This is called causal discovery. See Supplementary Section 30.4 for details.


30.3.2 Methods 


There are many different methods for learning PGM graph structures. See Supplementary Chapter 30 
for details. 
31 Non-parametric Bayesian models


This chapter is written by Vinayak Rao. 


31.1 Introduction 


The defining characteristic of a parametric model is that the objects being modeled, whether 
regression or classification functions, probability densities, or something more modern like graphs or 
shapes, are indexed by a finite-dimensional parameter vector. For instance, neural networks have a 
fixed number of parameters, independent of the dataset. In a parametric Bayesian model, a prior 
probability distribution on these parameters is used to define a prior distribution on the objects of 
interest. By contrast, in a Bayesian nonparametric (BNP) model (also called a non-parametric
Bayesian model) we directly place prior distributions on the objects of interest, such as functions, 
graphs, probability distributions, etc. This is usually done via some kind of stochastic process, 
which is a probability distribution over a potentially infinite set of random variables. 

One example is a Gaussian process. As explained in Chapter 18, this defines a probability distribution over an unknown function $f: \mathcal{X} \to \mathbb{R}$, such that $f(X) = (f(x_1), \ldots, f(x_N))$ is jointly Gaussian for any finite set of values $X = \{x_n \in \mathcal{X}\}_{n=1}^N$, i.e., $p(f(X)) = \mathcal{N}(f(X) \mid \mu(X), K(X))$, where $\mu(X) = [\mu(x_1), \ldots, \mu(x_N)]$ is the mean, $K(X) = [\mathcal{K}(x_i, x_j)]$ is the $N \times N$ Gram matrix, and $\mathcal{K}$ is a positive definite kernel function. The complexity of the posterior over functions can grow with the amount of data, avoiding underfitting, since we maintain a full posterior distribution over the infinite set of unknown “parameters” (i.e., function evaluations at all points $x \in \mathcal{X}$). But by taking a Bayesian approach, we avoid overfitting this infinitely flexible model. Despite involving infinite-parameter objects, practitioners are often only interested in inferences on a finite training dataset and predictions on a finite test dataset. This often allows these models to be surprisingly tractable. We can also define probability distributions over probability distributions, as well as other kinds of objects, as we discuss in the sections below. For more details, see e.g., [Hjo+10; GV17].
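To make the finite-dimensional view concrete, the following sketch draws $f(X)$ for a finite set of inputs from $\mathcal{N}(\mu(X), K(X))$ using a Cholesky factor of the Gram matrix. We assume a zero mean function and a squared-exponential kernel with unit length scale; the small `jitter` added to the diagonal is a standard numerical safeguard, not part of the model.

```python
import numpy as np

def rbf_kernel(x1, x2, length_scale=1.0):
    """Squared-exponential kernel exp(-(x - x')^2 / (2 l^2)) for 1d inputs."""
    d = x1[:, None] - x2[None, :]
    return np.exp(-0.5 * (d / length_scale) ** 2)

def sample_gp_prior(X, num_samples=3, jitter=1e-6, seed=0):
    """Draw samples of f(X) ~ N(mu(X), K(X)) with mu = 0 (an assumption)."""
    rng = np.random.default_rng(seed)
    X = np.asarray(X, dtype=float)
    mu = np.zeros(len(X))
    K = rbf_kernel(X, X) + jitter * np.eye(len(X))  # N x N Gram matrix
    L = np.linalg.cholesky(K)                       # K = L L^T
    eps = rng.standard_normal((len(X), num_samples))
    return (mu[:, None] + L @ eps).T                # shape (num_samples, N)

F = sample_gp_prior(np.linspace(-3, 3, 50), num_samples=3)
```

Only the $N$ function values we ask for are ever represented, which is why the infinite-dimensional object stays tractable in practice.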


31.2 Dirichlet processes 
A Dirichlet process (DP) is a nonparametric probability distribution over probability distributions, and is useful as a flexible prior for unsupervised learning tasks like clustering and density modeling [Fer73]. We give more details in the sections below.


Figure 31.1: Partitions of the unit square. (left) One possible partition into $K = 3$ regions, and (center) a refined partition into $K = 4$ regions. In both figures, the shading of cell $T_k$ is proportional to $G(T_k)$, resulting from the same realization of a Dirichlet process. (right) An ‘infinite partition’ of the unit square. The Dirichlet process can informally be viewed as an infinite-dimensional Dirichlet distribution defined on this.

31.2.1 Definition of a DP

Let $G$ be a probability distribution or a probability measure (we will use the latter terminology in this chapter) on some space $\Theta$. Recall that a probability measure is a function that assigns values to subsets $T \subseteq \Theta$ satisfying the usual axioms of probability: $0 \le G(T) \le 1$, $G(\Theta) = \int_\Theta G(d\theta) = 1$, and for disjoint subsets $T_1, \ldots, T_K$ of $\Theta$, $G(T_1 \cup \cdots \cup T_K) = \sum_{k=1}^K G(T_k)$. Bayesian unsupervised learning now seeks to place a prior on the probability measure $G$.

We have already seen examples of parametric priors over probability measures. As a simple example, consider a Gaussian distribution $\mathcal{N}(\cdot \mid \mu, \sigma^2)$: this is a probability measure on $\Theta$, and by placing priors on the parameters $\mu$ and $\sigma^2$, we have a parametric prior on probability measures. Mixture models form more flexible priors, allowing multimodality and asymmetry, and are parametrized by the probabilities of the mixture components, as well as their parameters. DPs directly define a probability distribution over probability measures $G$.

A Dirichlet process is specified by a positive real number $\alpha$, called the concentration parameter, and a probability measure $H$, called the base measure. We write a random measure drawn from a DP as $G \sim \mathrm{DP}(\alpha, H)$. $H$ is typically a standard probability measure on $\Theta$, and forms the mean of the Dirichlet process. That is, if $G \sim \mathrm{DP}(\alpha, H)$, then for any subset $T$ of $\Theta$, $\mathbb{E}[G(T)] = H(T)$. The parameter $\alpha$ measures how concentrated the Dirichlet process is around $H$, with $\mathbb{V}[G(T)] = \frac{H(T)(1 - H(T))}{\alpha + 1}$. If $\Theta$ is $\mathbb{R}^2$, then setting $H$ to the bivariate normal $\mathcal{N}(0, I_2)$ and $\alpha$ to a large value implies a prior belief that $G$ sampled from $\mathrm{DP}(\alpha, H)$ is close to the normal, whereas a small $\alpha$ represents a relatively uninformative prior.

We now define the Dirichlet process more precisely. Let $(T_1, \ldots, T_K)$ be a finite partition of $\Theta$, that is, $(T_1, \ldots, T_K)$ are disjoint sets whose union is $\Theta$. For a probability measure $G$, let $(G(T_1), \ldots, G(T_K))$ be the vector of probabilities of the elements of this partition. Then $\mathrm{DP}(\alpha, H)$ is a prior over probability measures $G$ satisfying the following requirement: for any finite partition, the associated vector of probabilities has the following joint Dirichlet distribution:

$$(G(T_1), \ldots, G(T_K)) \sim \mathrm{Dir}(\alpha H(T_1), \ldots, \alpha H(T_K)) \qquad (31.1)$$
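Equation (31.1) implies that for a single cell $T$, $G(T) \sim \mathrm{Beta}(\alpha H(T), \alpha(1 - H(T)))$, so $\mathbb{E}[G(T)] = H(T)$ and $\mathbb{V}[G(T)] = H(T)(1 - H(T))/(\alpha + 1)$. The following Monte Carlo sketch checks this for a hypothetical three-cell partition; the cell masses `H_probs` are made up for illustration.

```python
import numpy as np

def sample_partition_probs(alpha, H_probs, num_samples, seed=0):
    """Sample (G(T_1), ..., G(T_K)) ~ Dir(alpha H(T_1), ..., alpha H(T_K))."""
    rng = np.random.default_rng(seed)
    return rng.dirichlet(alpha * np.asarray(H_probs), size=num_samples)

H_probs = np.array([0.2, 0.3, 0.5])  # hypothetical base-measure masses H(T_k)
for alpha in [1.0, 100.0]:
    G = sample_partition_probs(alpha, H_probs, num_samples=50_000)
    # The mean stays at H for every alpha; the spread shrinks as alpha grows.
    print(alpha, G.mean(axis=0).round(3), G.var(axis=0).round(4))
```

Increasing $\alpha$ leaves the mean unchanged and shrinks the variance roughly by a factor of 50 here, which is the sense in which $\alpha$ concentrates $G$ around $H$.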

Just like the Gaussian process, the DP is defined implicitly through a set of finite-dimensional distributions, in this case through the distribution of $G$ projected onto any finite partition. The finite-dimensional distributions are consistent in the following sense: if $T_{11}$ and $T_{12}$ form a partition of $T_1$, then one can sample $G(T_1)$ in two ways: directly, by sampling

$$(G(T_1), \ldots, G(T_K)) \sim \mathrm{Dir}(\alpha H(T_1), \ldots, \alpha H(T_K)) \qquad (31.2)$$
or, indirectly, by sampling 

$$(G(T_{11}), G(T_{12}), G(T_2), \ldots, G(T_K)) \sim \mathrm{Dir}(\alpha H(T_{11}), \alpha H(T_{12}), \alpha H(T_2), \ldots, \alpha H(T_K)) \qquad (31.3)$$


and then setting $G(T_1) = G(T_{11}) + G(T_{12})$. From the properties of the Dirichlet distribution, $G(T_1)$ sampled either way follows the same distribution. This consistency property implies, via Kolmogorov’s extension theorem [Kal06], that underlying all finite-dimensional probability vectors for different partitions is a single infinite-dimensional vector that we could informally write as


$$(G(d\theta_1), \ldots, G(d\theta_\infty)) \sim \mathrm{Dir}(\alpha H(d\theta_1), \ldots, \alpha H(d\theta_\infty)) \qquad (31.4)$$


Very roughly, this ‘infinite-dimensional Dirichlet distribution’ is the Dirichlet Process. Figure 31.1 
sketches this out. 
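The consistency property can also be checked numerically. In this sketch, with made-up base-measure masses, we sample $G(T_1)$ directly from the coarse partition and indirectly by summing the two pieces of a refined partition. Both should follow $\mathrm{Beta}(\alpha H(T_1), \alpha H(T_2))$, here with mean 0.4 and variance $0.4 \times 0.6 / 3 = 0.08$.

```python
import numpy as np

rng = np.random.default_rng(1)
alpha = 2.0
# Hypothetical masses: T_1 is split into T_11 and T_12; T_2 is left untouched.
H_11, H_12, H_2 = 0.1, 0.3, 0.6
n = 100_000

# Direct: (G(T_1), G(T_2)) ~ Dir(alpha H(T_1), alpha H(T_2)), as in Eq. (31.2).
direct = rng.dirichlet(alpha * np.array([H_11 + H_12, H_2]), size=n)[:, 0]
# Indirect: sample the refined partition (Eq. 31.3), then G(T_1) = G(T_11) + G(T_12).
fine = rng.dirichlet(alpha * np.array([H_11, H_12, H_2]), size=n)
indirect = fine[:, 0] + fine[:, 1]

# Compare moments and a few quantiles of the two samples.
print(direct.mean(), indirect.mean(), direct.var(), indirect.var())
print(np.quantile(direct, [0.1, 0.5, 0.9]), np.quantile(indirect, [0.1, 0.5, 0.9]))
```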

Why is the Dirichlet process, defined in this indirect fashion, useful to practitioners? The answer has to do with conjugacy properties that it inherits from the Dirichlet distribution. One of the simplest unsupervised learning problems seeks to learn an unknown probability distribution $G$ from i.i.d. samples $\{\theta_1, \ldots, \theta_N\}$ drawn from it. Consider placing a DP prior on the unknown $G$. Then, given the data, one is interested in the posterior distribution over $G$, representing our updated beliefs about $G$. For a partition $(T_1, \ldots, T_K)$ of $\Theta$, the index $z$ of the cell that an observation falls into follows a multinoulli distribution:


$$z \sim \mathrm{Cat}(G(T_1), \ldots, G(T_K)) \qquad (31.5)$$


Under a DP prior on $G$, $(G(T_1), \ldots, G(T_K))$ follows a Dirichlet distribution (Equation (31.1)). From Dirichlet-multinomial conjugacy, the posterior for $(G(T_1), \ldots, G(T_K))$ given the observations is


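$$(G(T_1), \ldots, G(T_K)) \mid \{\theta_1, \ldots, \theta_N\} \sim \mathrm{Dir}\left(\alpha H(T_1) + \sum_{i=1}^{N} \mathbb{I}(\theta_i \in T_1), \ldots, \alpha H(T_K) + \sum_{i=1}^{N} \mathbb{I}(\theta_i \in T_K)\right) \qquad (31.6)$$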


This is true for any finite partition, so that following our earlier definition, the posterior over G itself 
is a Dirichlet process, and it is easy to see that: 


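$$G \mid \theta_1, \ldots, \theta_N, \alpha, H \sim \mathrm{DP}\left(\alpha + N, \frac{1}{\alpha + N}\left(\alpha H + \sum_{i=1}^{N} \delta_{\theta_i}\right)\right) \qquad (31.7)$$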


Thus we see that the DP prior on G is a conjugate prior for i.i.d. observations from G, with the 
posterior distribution over G also a Dirichlet process with concentration parameter a + N, and base 
measure a convex combination of the original base measure H and the empirical distribution of 
the observations. Note that as N increases, the influence of the original base measure H starts to 
wane, and the posterior base measure becomes closer and closer to the empirical distribution of the 
observations. At the same time, the concentration parameter increases, suggesting that the posterior 
distribution concentrates around the empirical distribution. 
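For a finite partition, the posterior update in Equation (31.6) is just a matter of adding cell counts to the prior Dirichlet parameters. A minimal sketch, where the function name and the input format `cell_ids` (the index of the cell containing each observation) are our own assumptions:

```python
import numpy as np

def dp_posterior_partition(alpha, H_probs, cell_ids):
    """Posterior Dirichlet parameters for (G(T_1), ..., G(T_K)) given data.

    H_probs:  (K,) base-measure masses H(T_k), summing to one.
    cell_ids: (N,) index k of the cell T_k that contains each observation.
    """
    counts = np.bincount(cell_ids, minlength=len(H_probs))
    params = alpha * np.asarray(H_probs) + counts
    # Posterior mean (alpha H(T_k) + N_k) / (alpha + N): a convex combination of
    # the base measure and the empirical cell frequencies.
    return params, params / params.sum()

params, post_mean = dp_posterior_partition(2.0, [0.5, 0.5], np.array([0, 0, 0, 1]))
```

For example, with $\alpha = 2$, $H(T_1) = H(T_2) = 0.5$ and cell counts $(3, 1)$, the posterior parameters are $(4, 2)$ and the posterior mean is $(2/3, 1/3)$, i.e., weight $\alpha/(\alpha + N) = 1/3$ on the base measure and $2/3$ on the empirical frequencies $(0.75, 0.25)$.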
Figure 31.2: Realizations from a Dirichlet process when $\Theta$ is (a) the real line, and (b) the unit square. Also shown are the base measures $H$. In reality, the number of atoms is infinite for both cases.

Figure 31.3: Illustration of the stick breaking construction. (a) We have a unit length stick, which we break at a random point $\beta_1$; the length of the piece we keep is called $\pi_1$; we then recursively break off pieces of the remaining stick, to generate $\pi_2, \pi_3, \ldots$. From Figure 2.22 of [Sud06]. Used with kind permission of Erik Sudderth. (b) Samples of $\pi_k$ from this process for different values of $\alpha$ (panels: $\alpha = 1$ and $\alpha = 2$). Generated by stick_breaking_demo.ipynb.

31.2.2 Stick breaking construction of the DP

Our discussion so far has been very abstract, with no indication of how to sample either the random measure $G$ or observations from $G$. We address the first question here, giving a constructive definition of the DP known as the stick-breaking construction [Set94].

We first mention that probability measures $G$ sampled from a DP are discrete with probability one (see Figure 31.2), taking the form

$$G(\theta) = \sum_{k=1}^{\infty} \pi_k\, \delta_{\theta_k}(\theta) \qquad (31.8)$$


Thus $G$ consists of an infinite number of atoms, the $k$th atom located at $\theta_k$ and having weight $\pi_k$. Informally, this follows from Equation (31.4), which represents the DP as an infinite-dimensional but infinitely-sparse Dirichlet distribution (recall that as its parameters become smaller, a Dirichlet distribution concentrates on sparse distributions that are dominated by a few components).

For a DP, the locations $\theta_k$ of the atoms are drawn independently from the base measure $H$, whereas the concentration parameter $\alpha$ controls the distribution of the weights $\pi_k$. Observe that the infinite sequence of weights $(\pi_1, \pi_2, \ldots)$ must add up to one, since $G$ is a probability measure. The weights can be simulated by the following process, sketched in Figure 31.3 and known as the stick-breaking process. Start with a stick of length 1 representing the total probability mass, and sequentially break off a random $\mathrm{Beta}(1, \alpha)$ distributed fraction of the remaining stick. The $k$th break forms $\pi_k$. In equations, for $k = 1, 2, \ldots$,


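$$\beta_k \sim \mathrm{Beta}(1, \alpha), \quad \theta_k \sim H, \qquad (31.9)$$

$$\pi_k = \beta_k \prod_{l=1}^{k-1} (1 - \beta_l) = \beta_k \left(1 - \sum_{l=1}^{k-1} \pi_l\right) \qquad (31.10)$$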


Then, setting $G(\theta) = \sum_{k=1}^{\infty} \pi_k \delta_{\theta_k}(\theta)$, one can show that $G \sim \mathrm{DP}(\alpha, H)$. The distribution over the weights is often denoted by

$$\pi \sim \mathrm{GEM}(\alpha), \qquad (31.11)$$


where GEM stands for Griffiths, Engen and McCloskey (this term is due to [Ewe90]). Some samples 
from this process are shown in Figure 31.3. 

We note that since the number of atoms is infinite, one cannot exactly simulate from a DP in finite time. However, the weights from the GEM distribution are stochastically ordered, having decreasing averages, and the truncation error resulting from terminating after a finite number of steps quickly becomes negligible [IJ01]. Nevertheless, we will see in the next section that it is possible to simulate samples and make predictions from a DP-distributed probability measure $G$ without any truncation error. This exploits the conjugacy property of the DP.
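A truncated version of the stick-breaking construction is easy to simulate. The sketch below draws the first `num_atoms` weights from Equations (31.9)-(31.10), draws atoms from a user-supplied base sampler (a standard normal in the usage line, an arbitrary choice), and returns the unassigned mass $1 - \sum_k \pi_k$, i.e., the truncation error. The function names are ours.

```python
import numpy as np

def stick_breaking(alpha, num_atoms, rng):
    """First `num_atoms` GEM(alpha) weights via Eqs. (31.9)-(31.10)."""
    betas = rng.beta(1.0, alpha, size=num_atoms)                       # beta_k ~ Beta(1, alpha)
    remaining = np.concatenate([[1.0], np.cumprod(1.0 - betas)[:-1]])  # prod_{l<k} (1 - beta_l)
    return betas * remaining                                           # pi_k

def sample_truncated_dp(alpha, base_sampler, num_atoms, seed=0):
    """Truncated draw from DP(alpha, H): atoms theta_k ~ H with weights pi_k.

    The leftover mass is returned rather than renormalized, so the
    truncation error can be inspected.
    """
    rng = np.random.default_rng(seed)
    pi = stick_breaking(alpha, num_atoms, rng)
    atoms = base_sampler(rng, num_atoms)
    return atoms, pi, 1.0 - pi.sum()

atoms, pi, leftover = sample_truncated_dp(
    alpha=2.0, base_sampler=lambda rng, n: rng.standard_normal(n), num_atoms=100)
```

Since $\mathbb{E}[\prod_{k=1}^{K}(1 - \beta_k)] = (\alpha/(1 + \alpha))^K$, the leftover mass shrinks geometrically in the truncation level $K$; likewise $\mathbb{E}[\pi_k] = \frac{1}{1+\alpha}\left(\frac{\alpha}{1+\alpha}\right)^{k-1}$, which is the decreasing-average property mentioned above.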


31.2.3 The Chinese restaurant process (CRP) 


Consider a single observation $\theta_1$ from a DP-distributed probability measure $G$. The probability that $\theta_1$ lies within a set $T \subseteq \Theta$, marginalizing out the random $G$, is $\mathbb{E}[G(T)] = H(T)$, the equality following from the definition of the DP. This holds for arbitrary $T$, which implies that the first observation $\theta_1$ is distributed as the base measure of the DP:

$$p(\theta_1 = \theta \mid \alpha, H) = H(\theta) \qquad (31.12)$$


Given $N$ observations $\theta_1, \ldots, \theta_N$, the updated distribution over $G$ is still a DP, but now modified as in Equation (31.7). Repeating the same argument, it follows that the $(N + 1)$st observation is distributed as the base measure of the posterior DP, given by

$$p(\theta_{N+1} = \theta \mid \theta_{1:N}, \alpha, H) = \frac{1}{\alpha + N}\left(\alpha H(\theta) + \sum_{k=1}^{K} N_k\, \delta_{\theta_k}(\theta)\right) \qquad (31.13)$$

where $N_k$ is the number of observations equal to $\theta_k$, the $k$th of the $K$ distinct values taken by $\theta_{1:N}$. The previous two equations form the basis of what is called the Pólya urn or Blackwell-MacQueen sampling scheme [BM+73]. This provides a way to exactly produce samples from a DP-distributed random probability measure.
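The Pólya urn scheme can be sketched as follows: we never represent $G$, and each new draw either comes fresh from $H$ or repeats an earlier draw. Choosing one of the $n$ previous draws uniformly at random is equivalent to choosing the distinct value $\theta_k$ with probability $N_k/(\alpha + n)$, as in Equation (31.13). The function name and the `base_sampler` argument are our own conventions.

```python
import numpy as np

def polya_urn(num_samples, alpha, base_sampler, seed=0):
    """Draw theta_1, ..., theta_N from a DP-distributed G, with G marginalized out."""
    rng = np.random.default_rng(seed)
    thetas = []
    for n in range(num_samples):
        # n previous draws: fresh value with probability alpha / (alpha + n).
        if rng.random() < alpha / (alpha + n):
            thetas.append(base_sampler(rng))          # new value from H
        else:
            thetas.append(thetas[rng.integers(n)])    # copy a uniformly chosen old draw
    return np.array(thetas)

theta = polya_urn(50, alpha=1.0, base_sampler=lambda rng: rng.standard_normal())
```

With a continuous base measure, the number of distinct values after $N$ draws has expectation $\sum_{n=0}^{N-1} \alpha/(\alpha + n)$, since draw $n + 1$ is fresh with probability $\alpha/(\alpha + n)$.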

It is often more convenient to work with discrete variables $(z_1, \ldots, z_N)$, with $z_i$ specifying which of the distinct values the $i$th sample takes; that is, for the $i$th observation, $\theta_i = \theta_{z_i}$. This allows us to decouple the cluster or partition structure of the dataset (controlled by $\alpha$) and the cluster parameters (controlled by $H$). Let us assign the first observation to cluster 1, i.e., $z_1 = 1$. The second observation can either belong to the same cluster as observation 1, or belong to a new cluster, which we call cluster 2. In the former event, $z_2 = 1$, after which $z_3$ can equal 1 or 2. In the latter event, $z_2 = 2$, and $z_3$ can equal 1, 2 or 3. Based on Equation (31.13), we have

$$p(z_{N+1} = z \mid z_{1:N}, \alpha) = \frac{1}{\alpha + N}\left[\alpha\, \mathbb{I}(z = K + 1) + \sum_{k=1}^{K} N_k\, \mathbb{I}(z = k)\right] \qquad (31.14)$$


assuming the first $N$ observations have been assigned to $K$ clusters. This is called the Chinese restaurant process or CRP, based on the following analogy: observations are customers in a restaurant with an infinite number of tables, each corresponding to a different cluster. Each table has a dish, corresponding to the parameter $\theta$ of that cluster. When a customer enters the restaurant, they may choose to join an existing table with probability proportional to the number of people already sitting at this table (i.e., they join table $k$ with probability proportional to $N_k$); otherwise, with probability proportional to $\alpha$, they choose to sit at a new table, ordering a new dish by sampling from the base measure $H$.
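The seating rule of Equation (31.14) translates directly into a sequential sampler. In this sketch, tables are 0-indexed (so the first customer sits at table 0 rather than table 1), and only the table counts $N_k$ are stored; dishes are omitted, since they would simply be drawn from $H$ whenever a new table opens. The function name is ours.

```python
import numpy as np

def sample_crp(num_customers, alpha, seed=0):
    """Sample table assignments z_1, ..., z_N from the CRP (0-indexed tables)."""
    rng = np.random.default_rng(seed)
    z = np.zeros(num_customers, dtype=int)
    counts = []                                    # counts[k] = N_k
    for n in range(num_customers):
        # Unnormalized probabilities: N_k for existing tables, alpha for a new one.
        weights = np.array(counts + [alpha], dtype=float)
        k = rng.choice(len(weights), p=weights / weights.sum())
        if k == len(counts):
            counts.append(1)                       # open a new table
        else:
            counts[k] += 1                         # rich get richer
        z[n] = k
    return z, np.array(counts)

z, counts = sample_crp(200, alpha=2.0)
```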

The sequence $z = (z_1, \ldots, z_N)$ of cluster assignments defines a partition of the integers 1 to $N$, and the CRP is a distribution over such partitions. The probability of creating a new table diminishes as the number of observations increases, but is always non-zero, and one can show that the number of occupied tables $K$ grows as $\alpha \log N$; more precisely, $K / \log N \to \alpha$ almost surely as $N \to \infty$. The fact that currently occupied tables are more likely to get new customers is sometimes called a rich get richer phenomenon. It is important to recognize that despite being defined as a sequential process, the CRP is an exchangeable process, with partition probabilities that are independent of the observation indices. Indeed, it is easy to show that the probability of a partition of $N$ integers into $K$ clusters with sizes $N_1, \ldots, N_K$ is


$$p(N_1, \ldots, N_K) = \frac{\alpha^{K-1}}{[\alpha + 1]^{N-1}} \prod_{k=1}^{K} (N_k - 1)! \qquad (31.15)$$

Here, $[\alpha + 1]^{N-1} = \prod_{i=0}^{N-2} (\alpha + 1 + i)$ is the rising factorial. Equation (31.15) depends only on the cluster sizes, and is called the Ewens sampling formula [Ewe72]. Exchangeability implies that the probability that the first two customers sit at the same table is the same as the probability that the first and last sit at the same table. Similarly, all customers have the same probability of ending up in a cluster of size $S$. The fact that the first customer can only belong to cluster 1 (i.e., that $z_1 = 1$) does not contradict exchangeability and reflects the fact that the cluster indices are chosen arbitrarily. This disappears if we index clusters by their associated parameter $\theta_k$.
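Exchangeability can be checked exactly on a small example. The sketch below computes the probability of an assignment sequence by multiplying the sequential CRP factors, and separately evaluates the closed form $\frac{\alpha^{K-1}}{[\alpha+1]^{N-1}} \prod_k (N_k - 1)!$ with $[\alpha+1]^{N-1} = \prod_{i=0}^{N-2}(\alpha + 1 + i)$. Every reordering of the customers gives the same value. The function names are ours.

```python
import math
from itertools import permutations

def crp_sequence_prob(z, alpha):
    """Probability of an assignment sequence z under the sequential CRP rule.

    Any label not seen before is treated as a newly opened table.
    """
    counts, prob = {}, 1.0
    for n, k in enumerate(z):          # n customers are already seated
        if k in counts:
            prob *= counts[k] / (alpha + n)
            counts[k] += 1
        else:
            prob *= alpha / (alpha + n)
            counts[k] = 1
    return prob

def ewens_prob(sizes, alpha):
    """Closed form alpha^(K-1) / [alpha+1]^(N-1) * prod_k (N_k - 1)!."""
    N, K = sum(sizes), len(sizes)
    rising = math.prod(alpha + 1 + i for i in range(N - 1))   # [alpha+1]^{N-1}
    return alpha ** (K - 1) / rising * math.prod(math.factorial(s - 1) for s in sizes)

z = [0, 0, 1, 0, 2, 1]   # cluster sizes (3, 2, 1)
print(crp_sequence_prob(z, 1.5), ewens_prob([3, 2, 1], 1.5))
```

For this $z$, both computations give $2\alpha^2 / \prod_{i=1}^{5}(\alpha + i)$.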

31.3 Dirichlet process mixture models

Real-world datasets are often best modeled by continuous probability densities. By contrast, a sample $G$ from a DP is discrete with probability one, and sampling observations from $G$ will result in repeated values, making it inappropriate for many applications. However, the discrete structure of $G$ is useful in clustering applications, as a prior for the latent variables underlying the observed datapoints. In particular, $z_i$ and $\theta_i$ can represent the cluster assignment and cluster parameter of the $i$th datapoint, whose value $x_i$ is a draw from some parametric distribution $F(x|\theta)$ indexed by $\theta$, where the DP over $\theta$ has base measure $H$. The resulting model follows along the lines of a standard mixture model, but is now an infinite mixture model, consisting of an infinite number of components or clusters, one for each atom in $G$.

A very common setting when $x_i \in \mathbb{R}^D$ is to set $F$ to be the multivariate normal distribution, $\theta = (\mu, \Sigma)$, and $H$ to be the Normal-Inverse-Wishart distribution. Then each of the infinite clusters has an associated mean and covariance matrix, and to generate a new observation, one picks cluster $k$ with probability $\pi_k$, and simulates from a normal with mean $\mu_k$ and covariance $\Sigma_k$. See Figure 31.4 for some samples from this model.

We discuss DP mixture models (DPMM) in more detail below.

Figure 31.4: Some samples from a Dirichlet process mixture model of 2D Gaussians, with concentration parameter $\alpha = 1$ (left column) and $\alpha = 2$ (right column). From top to bottom, we show $N = 50$, $N = 500$ and $N = 1000$ samples. Generated by dp_mixgauss_sample.ipynb.
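31.3.1 Model definition

We define the DPMM model as follows:

$$\pi \sim \mathrm{GEM}(\alpha), \quad \theta_k \sim H, \quad k = 1, 2, \ldots \qquad (31.16)$$

$$z_i \sim \pi, \quad x_i \sim F(\theta_{z_i}), \quad i = 1, \ldots, N \qquad (31.17)$$

Equivalently, we can write this as

$$G \sim \mathrm{DP}(\alpha, H) \qquad (31.18)$$

$$\theta_i \sim G, \quad x_i \sim F(\theta_i), \quad i = 1, \ldots, N \qquad (31.19)$$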


$G$ and $F$ together define the infinite mixture model: $G_F(x) = \sum_{k=1}^{\infty} \pi_k F(x \mid \theta_k)$. If $F(x \mid \theta)$ is continuous, then so is $G_F(x)$, and the Dirichlet process mixture model serves as a nonparametric prior over continuous distributions or probability densities.

Figure 31.5 illustrates two graphical models that summarize this, corresponding to the two sets of equations above. The first generates the set of weights $(\pi_1, \pi_2, \ldots)$ from the GEM distribution, along with an infinite collection of cluster parameters $(\theta_1, \theta_2, \ldots)$. It then generates observations by first sampling a cluster indicator $z_i$ from $\pi$, indexing the associated cluster parameter $\theta_{z_i}$, and then simulating the observation $x_i$ from $F(\theta_{z_i})$. The second graphical model simulates a random measure $G$ from the DP. It generates observations by directly simulating a parameter $\theta_i$ from $G$, and then simulating $x_i$ from $F(\theta_i)$. The infinite mixture model can be viewed as the limit of a $K$-component finite mixture model with a $\mathrm{Dir}(\alpha/K, \ldots, \alpha/K)$ prior on the mixture weights $(\pi_1, \ldots, \pi_K)$ and with mixture parameters $\theta_1, \ldots, \theta_K$, as $K \to \infty$ [Ras00; Nea00]. Producing exact samples $(x_1, \ldots, x_N)$ from this model involves one additional step to the Chinese restaurant process: after selecting a table (cluster) $z_i$ with its associated dish (parameter) $\theta_{z_i}$, the $i$th customer now samples an observation from the distribution $F(\theta_{z_i})$.
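The following sketch produces exact samples from a DP mixture by running the CRP and then drawing each $x_i$ from its table's component. To keep it short we depart from the NIW setting used for Figure 31.4: we assume 2d Gaussian components with a fixed isotropic covariance $\sigma^2 I$ and a Gaussian base measure $H = \mathcal{N}(0, \tau^2 I)$ over the cluster means; the function name and the defaults for `tau` and `sigma` are illustrative.

```python
import numpy as np

def sample_dpmm_crp(num_points, alpha, tau=4.0, sigma=0.5, dim=2, seed=0):
    """Exact samples x_1..x_N from a DP mixture of Gaussians via the CRP (sketch)."""
    rng = np.random.default_rng(seed)
    counts, means = [], []
    z = np.zeros(num_points, dtype=int)
    X = np.zeros((num_points, dim))
    for i in range(num_points):
        # CRP step: existing table k with weight N_k, new table with weight alpha.
        weights = np.array(counts + [alpha], dtype=float)
        k = rng.choice(len(weights), p=weights / weights.sum())
        if k == len(counts):                       # new table: order a new dish from H
            counts.append(0)
            means.append(tau * rng.standard_normal(dim))
        counts[k] += 1
        z[i] = k
        # Extra step for the mixture: x_i ~ F(theta_{z_i}) = N(mean_k, sigma^2 I).
        X[i] = means[k] + sigma * rng.standard_normal(dim)
    return X, z, np.array(means)

X, z, means = sample_dpmm_crp(1000, alpha=1.0)
```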


31.3.2 Fitting using collapsed Gibbs sampling 


Given a dataset of observations, one is interested in the posterior distribution $p(G, z_1, \ldots, z_N \mid x_1, \ldots, x_N, \alpha, H)$, or equivalently, $p(\pi, \theta_1, \theta_2, \ldots, z_1, \ldots, z_N \mid x_1, \ldots, x_N, \alpha, H)$. The most common way to fit a DPMM is via Markov chain Monte Carlo (MCMC), producing samples by constructing a Markov chain that targets this posterior distribution.

Figure 31.5: Two views of $N$ observations sampled from a DP mixture model. Left: representation where cluster indicators are sampled from the GEM-distributed distribution $\pi$. Right: representation where parameters are sampled from the DP-distributed random measure $G$. The rightmost picture illustrates the case where $N = 2$, $\theta$ is real-valued with a Gaussian prior $H(\cdot)$, and $F(x \mid \theta)$ is a Gaussian with mean $\theta$ and variance $\sigma^2$. We generate two parameters, $\theta_1$ and $\theta_2$, from $G$, one per data point. Finally, we generate two data points, $x_1$ and $x_2$, from $\mathcal{N}(\theta_1, \sigma^2)$ and $\mathcal{N}(\theta_2, \sigma^2)$. From Figure 2.24 of [Sud06]. Used with kind permission of Erik Sudderth.

We describe a collapsed Gibbs sampler based on the Chinese restaurant process that marginalizes out the infinite-dimensional random measure $G$, and that targets the distribution $p(z_1, \ldots, z_N \mid x_1, \ldots, x_N, \alpha, H)$ summarizing all clustering information. It produces samples from this distribution by cycling through each observation $x_i$, and updating its assigned cluster $z_i$, conditioned on all other variables. Write $x_{-i}$ for all observations other than the $i$th observation, and $z_{-i}$ for their cluster assignments. Then we have


$$p(z_i = k \mid z_{-i}, x, \alpha, H) \propto p(z_i = k \mid z_{-i}, \alpha)\, p(x_i \mid x_{-i}, z_i = k, z_{-i}, H) \qquad (31.20)$$


By exchangeability, each observation can be treated as the last customer to enter the restaurant. 
Hence the first term is given by 


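$$p(z_i \mid z_{-i}, \alpha) = \frac{1}{\alpha + N - 1}\left[\alpha\, \mathbb{I}(z_i = K_{-i} + 1) + \sum_{k=1}^{K_{-i}} N_{k,-i}\, \mathbb{I}(z_i = k)\right] \qquad (31.21)$$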


where $N_{k,-i}$ is the number of observations in cluster $k$, and $K_{-i}$ is the number of clusters used by $x_{-i}$, both obtained after removing observation $i$, eliminating empty clusters, and renumbering clusters.

To compute the second term, $p(x_i \mid x_{-i}, z_i = k, z_{-i}, H)$, let us partition the data $x_{-i}$ into clusters based on $z_{-i}$. Let $x_{-i,c} = \{x_j : z_j = c, j \neq i\}$ be the data points assigned to cluster $c$. If $z_i = k$, then $x_i$ is conditionally independent of all the data points except those assigned to cluster $k$. Hence,


$$p(x_i \mid x_{-i}, z_{-i}, z_i = k) = p(x_i \mid x_{-i,k}) = \frac{p(x_i, x_{-i,k})}{p(x_{-i,k})} \tag{31.22}$$

where


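$$p(x_i, x_{-i,k}) = \int p(x_i \mid \theta_k) \left[\prod_{j \neq i : z_j = k} p(x_j \mid \theta_k)\right] H(\theta_k)\, d\theta_k \tag{31.23}$$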


is the marginal likelihood of all the data assigned to cluster $k$, including $i$, and $p(x_{-i,k})$ is an analogous expression excluding $i$. Thus we see that the term $p(x_i \mid x_{-i}, z_{-i}, z_i = k)$ is the posterior predictive distribution for cluster $k$ evaluated at $x_i$.

If $z_i = k^*$, corresponding to a new cluster, we have

$$p(x_i \mid x_{-i}, z_{-i}, z_i = k^*) = p(x_i) = \int p(x_i \mid \theta)\, H(\theta)\, d\theta \tag{31.24}$$

which is just the prior predictive distribution for a new cluster evaluated at $x_i$.

The overall sampler is sometimes called “Algorithm 3” (from [Nea00]). Algorithm 43 provides the pseudocode. The algorithm is very similar to collapsed Gibbs for finite mixtures except that we have to consider the case $z_i = k^*$. Note that in order to evaluate the integrals in Equation (31.23) and Equation (31.24), we require the base measure $H$ to be conjugate to the likelihood $F$. For example, if we use an NIW prior for the Gaussian likelihood, we can use the results from Section 3.3.3.6 to compute the predictive distributions. Extensions to the case of non-conjugate priors are discussed in [Nea00].
Algorithm 43: Collapsed Gibbs sampler for DP mixture model

1   foreach i = 1 : N in random order do
2       Remove $x_i$'s sufficient statistics from old cluster $z_i$
3       foreach k = 1 : K do
4           Compute $p_k(x_i) = p(x_i \mid x_{-i}(k))$
5           Set $N_{k,-i} = \dim(x_{-i}(k))$
6           Compute $p(z_i = k \mid z_{-i}, \mathcal{D}) = \frac{N_{k,-i}}{\alpha + N - 1}\, p_k(x_i)$
7       Compute $p(z_i = * \mid z_{-i}, \mathcal{D}) = \frac{\alpha}{\alpha + N - 1}\, p(x_i)$
8       Normalize $p(z_i \mid \cdot)$
9       Sample $z_i \sim p(z_i \mid \cdot)$
10      Add $x_i$'s sufficient statistics to new cluster $z_i$
11      If any cluster is empty, remove it and decrease K
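To make Algorithm 43 concrete, here is a minimal Python sketch for one-dimensional data under a conjugate model that the text leaves open, so treat it as an assumption: $F(x \mid \theta) = \mathcal{N}(x \mid \theta, \sigma^2)$ with known $\sigma^2$ and base measure $H = \mathcal{N}(0, \tau^2)$. Each cluster is summarized by its count and sum, which suffice to evaluate the posterior predictive of Equation (31.22) and, with zero points, the prior predictive of Equation (31.24). The names `collapsed_gibbs_dpmm`, `log_predictive`, and `tau2` are illustrative, not from the book's notebooks.

```python
import math
import random


def log_normal_pdf(x, mean, var):
    """Log density of N(x | mean, var)."""
    return -0.5 * (math.log(2.0 * math.pi * var) + (x - mean) ** 2 / var)


def log_predictive(x, n, s, sigma2, tau2):
    """log p(x | x_{-i}(k)) for a cluster holding n points whose sum is s.

    Assumed conjugate model: x ~ N(theta, sigma2), theta ~ H = N(0, tau2).
    With n = 0 this is the prior predictive p(x) of Eq. (31.24).
    """
    prec = 1.0 / tau2 + n / sigma2       # posterior precision of theta
    mean = (s / sigma2) / prec           # posterior mean (prior mean is 0)
    return log_normal_pdf(x, mean, 1.0 / prec + sigma2)


def collapsed_gibbs_dpmm(x, alpha=1.0, sigma2=1.0, tau2=100.0, n_sweeps=20, seed=0):
    rng = random.Random(seed)
    N = len(x)
    z = [0] * N                          # start with all points in one cluster
    counts, sums = {0: N}, {0: sum(x)}   # sufficient statistics per cluster
    next_label = 1
    for _ in range(n_sweeps):
        order = list(range(N))
        rng.shuffle(order)               # "foreach i in random order"
        for i in order:
            # Remove x_i's sufficient statistics from its old cluster.
            k_old = z[i]
            counts[k_old] -= 1
            sums[k_old] -= x[i]
            if counts[k_old] == 0:       # eliminate empty clusters
                del counts[k_old], sums[k_old]
            # Unnormalized log p(z_i = k | ...); the 1/(alpha + N - 1) factor cancels.
            labels = list(counts)
            logp = [math.log(counts[k]) + log_predictive(x[i], counts[k], sums[k], sigma2, tau2)
                    for k in labels]
            labels.append(None)          # None stands for a new cluster k*
            logp.append(math.log(alpha) + log_predictive(x[i], 0, 0.0, sigma2, tau2))
            # Normalize stably (shift by the max log weight) and sample z_i.
            m = max(logp)
            k_new = rng.choices(labels, weights=[math.exp(v - m) for v in logp])[0]
            if k_new is None:
                k_new, next_label = next_label, next_label + 1
                counts[k_new], sums[k_new] = 0, 0.0
            # Add x_i's sufficient statistics to its new cluster.
            z[i] = k_new
            counts[k_new] += 1
            sums[k_new] += x[i]
    return z


data_rng = random.Random(1)
x = [data_rng.gauss(-10.0, 1.0) for _ in range(20)] + [data_rng.gauss(10.0, 1.0) for _ in range(20)]
z = collapsed_gibbs_dpmm(x)
```

Cluster labels here are arbitrary integers rather than the renumbered $1, \ldots, K$ of Algorithm 43. On two well-separated groups like these, one would expect no sampled cluster to mix points from both groups, although a few small extra clusters may appear.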


Figure 31.6: Output of the DP mixture model fit using Gibbs sampling to two different datasets. Left column: dataset with 4 clear clusters. Right column: dataset with an unclear number of clusters. Top row: single sample from the Markov chain. Bottom row: empirical fraction of times a given number of clusters is used, computed across all samples from the chain. Generated by dp_mixgauss_sample.ipynb.

31.3.3 Fitting using variational Bayes

This section was written by Xinglong Li.

In this section, we discuss how to fit a DP mixture model using mean field variational Bayes (Section 10.2.3), as described in [BJ+06].

Given samples $x_1, \ldots, x_N$ from a DP mixture, the mean field variational inference (MFVI) algorithm is based on the stick-breaking representation of the DP mixture. The target of the inference is the joint posterior distribution of the beta random variables $\beta = \{\beta_1, \beta_2, \ldots\}$ in the stick-breaking construction of the DP, the locations $\theta = \{\theta_1, \theta_2, \ldots\}$ of the atoms, and the cluster assignments $z = \{z_1, \ldots, z_N\}$:


$$w = \{\beta, \theta, z\} \tag{31.25}$$


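The hyperparameters are the concentration parameter of the DP and the parameter of the conjugate base distribution of $\theta$:

$$\lambda = \{\alpha, \eta\} \tag{31.26}$$

The variational inference algorithm minimizes the KL divergence between $q_\psi(w)$ and $p(w \mid x, \lambda)$:

$$D_{\mathrm{KL}}\big(q_\psi(w) \,\|\, p(w \mid x, \lambda)\big) = \mathbb{E}_q[\log q_\psi(w)] - \mathbb{E}_q[\log p(w, x \mid \lambda)] + \log p(x \mid \lambda) \tag{31.27}$$

Minimizing the KL divergence is equivalent to maximizing the evidence lower bound (ELBO):

$$\begin{align}
\mathcal{L} &= \mathbb{E}_q[\log p(w, x \mid \lambda)] - \mathbb{E}_q[\log q_\psi(w)] \tag{31.28}\\
&= \mathbb{E}_q[\log p(\beta \mid \alpha)] + \mathbb{E}_q[\log p(\theta \mid \eta)] + \sum_{n=1}^{N} \big(\mathbb{E}_q[\log p(z_n \mid \beta)] + \mathbb{E}_q[\log p(x_n \mid z_n, \theta)]\big) \tag{31.29}\\
&\quad - \mathbb{E}_q[\log q_\psi(\beta, \theta, z)] \tag{31.30}
\end{align}$$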
To deal with the infinite parameters in $\beta$ and $\theta$, the variational inference algorithm truncates the DP by fixing a value $T$ and setting $q(\beta_T = 1) = 1$, which implies that $\pi_t = 0$ for $t > T$. Therefore, $q_\psi(w)$ in the MFVI for DP mixture models factorizes into

$$q_\psi(\beta, \theta, z) = \prod_{t=1}^{T-1} q_{\gamma_t}(\beta_t) \prod_{t=1}^{T} q_{\tau_t}(\theta_t) \prod_{n=1}^{N} q_{\phi_n}(z_n) \tag{31.31}$$

where $q_{\gamma_t}(\beta_t)$ is the beta distribution with parameters $\{\gamma_{t,1}, \gamma_{t,2}\}$, $q_{\tau_t}(\theta_t)$ is the exponential family distribution with natural parameters $\tau_t$, and $q_{\phi_n}(z_n)$ is the categorical distribution of the cluster assignment of observation $x_n$, with $q(z_n = t) = \phi_{n,t}$. The free variational parameters are

$$\psi = \{\gamma_1, \ldots, \gamma_{T-1}, \tau_1, \ldots, \tau_T, \phi_1, \ldots, \phi_N\} \tag{31.32}$$


Notice that only $q_\psi(w)$ is truncated; the true posterior $p(w \mid x, \lambda)$ from the model need not be truncated when minimizing the KL divergence.

The MFVI objective can be optimized by coordinate ascent, and each step has a closed-form update when the base measure of the DP is conjugate to the likelihood of the observations. In particular, suppose that the conditional distribution of $x_n$ given $z_n$ and $\theta$ is an exponential family distribution (Section 2.3):

$$p(x_n \mid z_n, \theta_1, \theta_2, \ldots) = h(x_n) \exp\{\theta_{z_n}^\top x_n - a(\theta_{z_n})\} \tag{31.33}$$

where $x_n$ is the sufficient statistic for the natural parameter $\theta_{z_n}$. Therefore, the conjugate base distribution is

$$p(\theta \mid \eta) = h(\theta) \exp\big(\eta_1^\top \theta + \eta_2(-a(\theta)) - a(\eta)\big) \tag{31.34}$$

where $\eta_1$ contains the first $\dim(\theta)$ components and $\eta_2$ is a scalar. See Algorithm 44 for the resulting pseudocode. Extensions of this method to infer the hyperparameters, $\lambda = \{\alpha, \eta\}$, can be found in the appendix of [BJ+06].

31.3.4 Other fitting algorithms


While collapsed Gibbs sampling is the simplest approach to posterior inference for DPMMs, a variety of other methods have been proposed as well. One popular class of MCMC samplers works with the stick-breaking representation of the DP instead of the CRP, instantiating the random measure $G$ [IJ01]. The sampler then proceeds by sampling the cluster assignments $z$ given $G$, and then $G$ given $z$. An advantage of this is that the cluster assignments can be updated in parallel, unlike the CRP, where they are updated sequentially. To be feasible, however, these methods require truncating $G$ to a finite number of atoms, though the resulting approximation error can be quite small. The posterior approximation error can be eliminated altogether by slice-sampling methods [KGW11] that work with random truncation levels.

Alternatives to MCMC also exist. [Dau07] shows how one can use A* search and beam search to quickly find an approximate MAP estimate. [Man+07] discusses how to fit a DPMM online using particle filtering, which is like a stochastic version of beam search. This can be more efficient than Gibbs sampling, particularly for large datasets.
31.4. GENERALIZATIONS OF THE DIRICHLET PROCESS


Algorithm 44: Variational inference for DP mixture model

1   Initialize the variational parameters:
2       $\phi_{n,t}$ is the membership probability of $x_n$ in cluster $t$;
3       $\tau_t$ are the natural parameters for cluster $t$;
4       $\gamma_t$ are the parameters for the stick-breaking distribution.
5   while not converged do
6       foreach $\gamma_t$ do
7           Update the beta distribution $q_{\gamma_t}(\beta_t)$;
8           $\gamma_{t,1} = 1 + \sum_n \phi_{n,t}$
9           $\gamma_{t,2} = \alpha + \sum_n \sum_{j=t+1}^{T} \phi_{n,j}$
10      foreach $\tau_t$ do
11          Update the exponential family distribution $q_{\tau_t}(\theta_t)$ given sufficient statistics $\{x_n\}$;
12          $\tau_{t,1} = \eta_1 + \sum_n \phi_{n,t} x_n$
13          $\tau_{t,2} = \eta_2 + \sum_n \phi_{n,t}$
14      foreach $\phi_n$ do
15          Update the categorical distribution $q_{\phi_n}(z_n)$ for each observation;
16          $\phi_{n,t} \propto \exp(S_t)$
17          $S_t = \mathbb{E}_q[\log \beta_t] + \sum_{i=1}^{t-1} \mathbb{E}_q[\log(1 - \beta_i)] + \mathbb{E}_q[\theta_t]^\top x_n - \mathbb{E}_q[a(\theta_t)]$
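The updates of Algorithm 44 can be sketched in Python for one concrete case, which is our assumption rather than something the text fixes: each component is $\mathcal{N}(x \mid \theta_t, 1)$, so in the form of Equation (31.33) the sufficient statistic is $x_n$ and $a(\theta) = \theta^2/2$. The conjugate base distribution of Equation (31.34) is then Gaussian with precision $\eta_2$ and mean $\eta_1/\eta_2$, and each $q_{\tau_t}(\theta_t)$ is Gaussian with mean $\tau_{t,1}/\tau_{t,2}$ and variance $1/\tau_{t,2}$, which gives $\mathbb{E}_q[\theta_t]$ and $\mathbb{E}_q[a(\theta_t)]$ in closed form. The quantile-based initialization, the fixed iteration count in place of a convergence test, the hand-written `digamma` (to stay within the standard library), and all function names are illustrative choices, not part of [BJ+06].

```python
import math
import random


def digamma(v):
    """Digamma for v > 0: shift upward via psi(v) = psi(v + 1) - 1/v, then use the asymptotic series."""
    result = 0.0
    while v < 6.0:
        result -= 1.0 / v
        v += 1.0
    inv2 = 1.0 / (v * v)
    return (result + math.log(v) - 0.5 / v
            - inv2 * (1.0 / 12 - inv2 * (1.0 / 120 - inv2 / 252)))


def normalize_exp(logits):
    """Softmax with a max shift for numerical stability."""
    m = max(logits)
    w = [math.exp(v - m) for v in logits]
    s = sum(w)
    return [wi / s for wi in w]


def dp_mixture_cavi(x, T=10, alpha=1.0, eta1=0.0, eta2=0.1, n_iters=100):
    """Truncated mean-field VB for a DP mixture of N(theta_t, 1) components with
    conjugate prior p(theta | eta) proportional to exp(eta1 * theta - eta2 * theta**2 / 2)."""
    N = len(x)
    # Hypothetical initialization: soft-assign points to T centers at evenly spaced quantiles.
    xs = sorted(x)
    centers = [xs[int((t + 0.5) * N / T)] for t in range(T)]
    phi = [normalize_exp([-0.5 * (xn - c) ** 2 for c in centers]) for xn in x]
    for _ in range(n_iters):
        Nt = [sum(phi[n][t] for n in range(N)) for t in range(T)]
        # Lines 6-9: beta factors q(beta_t) for t < T; q(beta_T = 1) = 1 by truncation.
        g1 = [1.0 + Nt[t] for t in range(T - 1)]
        g2 = [alpha + sum(Nt[t + 1:]) for t in range(T - 1)]
        # Lines 10-13: natural parameters of the Gaussian factors q(theta_t).
        tau1 = [eta1 + sum(phi[n][t] * x[n] for n in range(N)) for t in range(T)]
        tau2 = [eta2 + Nt[t] for t in range(T)]
        # Expectations entering S_t (line 17).
        e_log_b = [digamma(g1[t]) - digamma(g1[t] + g2[t]) for t in range(T - 1)] + [0.0]
        e_log_1mb = [digamma(g2[t]) - digamma(g1[t] + g2[t]) for t in range(T - 1)]
        e_log_pi = [e_log_b[t] + sum(e_log_1mb[:t]) for t in range(T)]
        e_theta = [tau1[t] / tau2[t] for t in range(T)]                     # E[theta_t]
        e_a = [0.5 * (e_theta[t] ** 2 + 1.0 / tau2[t]) for t in range(T)]  # E[theta_t^2 / 2]
        # Lines 14-17: categorical factors q(z_n).
        phi = [normalize_exp([e_log_pi[t] + e_theta[t] * xn - e_a[t] for t in range(T)])
               for xn in x]
    return phi


data_rng = random.Random(0)
x = [data_rng.gauss(-5.0, 1.0) for _ in range(30)] + [data_rng.gauss(5.0, 1.0) for _ in range(30)]
phi = dp_mixture_cavi(x)
hard = [max(range(len(p)), key=p.__getitem__) for p in phi]  # most probable component
```

Reading off the most probable component of each $q_{\phi_n}$ gives a hard clustering; with the two well-separated groups used here, no component is expected to be shared between them.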


In Section 31.3.3 we discussed an approach based on mean field variational inference. A variety of 
other variational approximation methods have been proposed as well, for example [KWV06; TKW08; 
Zob09; WB12]. 


31.3.5 Choosing the hyper-parameters 


An important issue is how to set the model hyper-parameters. These include the DP concentration parameter $\alpha$, as well as any parameters $\lambda$ of the base measure $H$. For the DP, the value of $\alpha$ does not have much impact on predictive accuracy, but has quite a strong effect on the number of clusters. One approach is to put a Gamma prior on $\alpha$, and then form its posterior, $p(\alpha \mid K, N, a, b)$ [EW95]. Simulating $\alpha$ given the cluster assignments $z$ is quite straightforward, and can be incorporated into the earlier Gibbs sampler. The same is the case with the hyper-parameters $\lambda$ [Ras00]. Alternatively, one can use empirical Bayes [MBJ06] to fit rather than sample these parameters.
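One standard way to simulate $\alpha$ under a $\mathrm{Gamma}(a, b)$ prior (shape $a$, rate $b$) is the auxiliary-variable scheme usually attributed to [EW95]. First draw $u \sim \mathrm{Beta}(\alpha + 1, N)$. Then draw $\alpha$ from a two-component mixture of $\mathrm{Gamma}(a + K, b - \log u)$ and $\mathrm{Gamma}(a + K - 1, b - \log u)$, where the odds of the first component are $(a + K - 1)/(N(b - \log u))$. The Python sketch below is our rendering of that update, not code from the cited work. The function name `resample_alpha` and the defaults $a = b = 1$ are illustrative. Inside the Gibbs sampler, one would call it once per sweep with the current number of occupied clusters $K$.

```python
import math
import random


def resample_alpha(alpha, K, N, a=1.0, b=1.0, rng=random):
    """One auxiliary-variable Gibbs update of the DP concentration parameter,
    given K occupied clusters among N observations and a Gamma(a, b) prior
    (shape a, rate b)."""
    u = rng.betavariate(alpha + 1.0, N)              # auxiliary variable u ~ Beta(alpha + 1, N)
    rate = b - math.log(u)
    odds = (a + K - 1.0) / (N * rate)                # mixture odds pi / (1 - pi)
    shape = a + K if rng.random() < odds / (1.0 + odds) else a + K - 1.0
    return rng.gammavariate(shape, 1.0 / rate)       # gammavariate expects a scale, not a rate


rng = random.Random(0)
alpha, samples = 1.0, []
for _ in range(2000):
    alpha = resample_alpha(alpha, K=5, N=100, rng=rng)
    samples.append(alpha)
```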


31.4 Generalizations of the Dirichlet process 


Dirichlet process mixture models are flexible nonparametric models of continuous probability densities, and if set up with a little care, can possess important frequentist properties like asymptotic consistency: with more and more observations, the posterior distribution concentrates around the ‘true’ data generating distribution, with very little assumed about this distribution. Nevertheless, DPs still represent very strong prior information, especially in clustering applications. We saw that the number of clusters in a dataset of size $N$ a priori is around $\alpha \log N$. As indicated by Equation (31.15), not just the number of clusters, but also the distribution of their sizes is controlled by a single parameter $\alpha$. The resulting clustering model is thus quite inflexible, and in many cases, inappropriate. Two examples from machine learning that highlight its limitations are applications involving text and image data. Here, it has been observed empirically that the number of unique words in a document, the frequency of word usage, the number of objects in an image, or the number of pixels in an object follow power-law distributions. Clusterings sampled from the CRP do not produce this property, and the resulting model mismatch can result in poor predictive performance.
31.4.1 Pitman-Yor process

A popular generalization of the Dirichlet process is the Pitman-Yor process [PY97] (also called the two-parameter Poisson-Dirichlet process). Written as $\mathrm{PYP}(\alpha, d, H)$, the Pitman-Yor process includes an additional discount parameter $d$, which must satisfy $0 \le d < 1$. The concentration parameter can now be negative, but must satisfy $\alpha > -d$. As with the DP, a sample $G$ from a PYP is a random probability measure that is discrete almost surely. It has a stick-breaking representation that generalizes that of the DP: again, we start with a stick of length 1, but now at step $k$, a random $\mathrm{Beta}(1 - d, \alpha + kd)$ fraction of the remaining probability mass is broken off, so that $G$ is written as


$$G = \sum_{k=1}^{\infty} \pi_k \delta_{\theta_k} \tag{31.35}$$

$$\beta_k \sim \mathrm{Beta}(1 - d, \alpha + kd), \quad \theta_k \sim H \tag{31.36}$$

$$\pi_k = \beta_k \prod_{l=1}^{k-1} (1 - \beta_l) \tag{31.37}$$
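Equations (31.35)-(31.37) translate directly into a truncated sampler for the weights. The Python sketch below keeps only the first $K$ atoms (a truncation of our choosing) and also returns the leftover mass $\prod_{l \le K}(1 - \beta_l)$, so the weights plus the remainder sum to one. The function name is hypothetical. Setting $d = 0$ recovers the DP's $\mathrm{Beta}(1, \alpha)$ sticks.

```python
import random


def pitman_yor_sticks(alpha, d, K, rng):
    """First K stick-breaking weights of G ~ PYP(alpha, d, H), Eqs. (31.36)-(31.37):
    beta_k ~ Beta(1 - d, alpha + k d) and pi_k = beta_k * prod_{l<k} (1 - beta_l)."""
    weights, remaining = [], 1.0       # remaining = prod_{l<k} (1 - beta_l)
    for k in range(1, K + 1):
        beta_k = rng.betavariate(1.0 - d, alpha + k * d)
        weights.append(beta_k * remaining)
        remaining *= 1.0 - beta_k
    return weights, remaining          # remaining is the mass beyond atom K


rng = random.Random(0)
w_dp, rest_dp = pitman_yor_sticks(alpha=1.0, d=0.0, K=50, rng=rng)   # d = 0: the DP
w_py, rest_py = pitman_yor_sticks(alpha=1.0, d=0.5, K=50, rng=rng)
```

Because the mean of $\mathrm{Beta}(1 - d, \alpha + kd)$ shrinks as $k$ grows when $d > 0$, later sticks are stochastically smaller than under the DP, so the mass is spread over more atoms.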


Because $G$ is discrete, once again observations $\bar{\theta}_1, \ldots, \bar{\theta}_N$ sampled i.i.d. from $G$ will possess a clustering structure. These can be sampled directly following a sequential process that generalizes the CRP. Now, when a new customer enters the restaurant, they join an existing table with $N_k$ customers with probability proportional to $N_k - d$, and create a new table with probability proportional to $\alpha + Kd$, where $K$ is the number of clusters.

Observe that the Dirichlet process is a special instance of the Pitman-Yor process, corresponding to $d = 0$. Non-zero settings of $d$ counteract the rich-get-richer dynamics of the DP to some extent, increasing the probability of creating new clusters. The more clusters there are present in a dataset, the higher the probability of creating a new cluster (relative to the Dirichlet process). This behavior means that a large number of clusters, as well as a few very large clusters, are more likely under the PYP than under the DP.
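The sequential seating rule can be simulated in a few lines. In this Python sketch (function name hypothetical), the first customer always opens a table. That also keeps the rule valid for negative $\alpha > -d$, since with $K \ge 1$ tables the new-table weight $\alpha + Kd$ is positive. Seating the same number of customers with $d = 0$ and with $d > 0$ illustrates the point above: the discount typically leaves many more tables occupied.

```python
import random


def pitman_yor_crp(N, alpha, d, rng):
    """Table sizes after seating N >= 1 customers under the Pitman-Yor seating rule."""
    counts = [1]                                   # the first customer opens a table
    for _ in range(N - 1):
        K = len(counts)
        # Existing table k has weight N_k - d; a new table has weight alpha + K d.
        weights = [n_k - d for n_k in counts] + [alpha + K * d]
        k = rng.choices(range(K + 1), weights=weights)[0]
        if k == K:
            counts.append(1)
        else:
            counts[k] += 1
    return counts


rng = random.Random(0)
tables_dp = pitman_yor_crp(10000, alpha=1.0, d=0.0, rng=rng)   # d = 0: the ordinary CRP
tables_py = pitman_yor_crp(10000, alpha=1.0, d=0.5, rng=rng)
tables_neg = pitman_yor_crp(100, alpha=-0.25, d=0.5, rng=rng)  # alpha may be negative if alpha > -d
```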

An even more general class of probability measures are the so-called Gibbs-type priors [DB+13]. Under such a prior, given $N$ observations, the probability of any particular partition of them into $K$ clusters, the $k$th having $n_k$ observations, is

$$p(n_1, \ldots, n_K) = V_{N,K} \prod_{k=1}^{K} (1 - \sigma)_{n_k - 1} \tag{31.38}$$

where $(a)_n = a(a+1)\cdots(a+n-1)$ is the rising factorial, $\sigma < 1$, and the non-negative weights $V_{N,K}$ satisfy $V_{N,K} = (N - \sigma K) V_{N+1,K} + V_{N+1,K+1}$ and $V_{1,1} = 1$. This definition ensures that the probability over partitions is consistent, exchangeable and tractable, and includes the DP and PYP as special cases. Besides these two, Gibbs-type priors include settings where the number of components (or the number of atoms in the random measures) is random but bounded. Recall that DP and PYP mixture models are infinite mixture models, with the number of components growing with the number of observations. A sometimes undesirable feature of these models is that if a dataset is generated from a finite number of clusters, these models will not recover the true number of clusters [MH14]. Instead, the estimated number of clusters will increase with the size of the dataset, resulting in redundant clusters that are located very close to each other. Gibbs-type priors with an almost surely bounded number of components can learn the true number of clusters while still remaining reasonably tractable: so long as one can calculate the $V_{N,K}$ terms, MCMC for all these models can be carried out by modifications of the CRP-based sampler described earlier.
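Equation (31.38) gives the probability of one particular partition, so summing it over all partitions of $N$ items must give one. The Python sketch below checks this numerically for the PYP. It uses the closed form $V_{N,K} = \prod_{i=1}^{K-1}(\alpha + i\sigma) / (\alpha + 1)_{N-1}$ with $\sigma = d$, a standard result that the text does not state; for $\sigma = 0$ it reduces to the DP's $\alpha^K/(\alpha)_N$. The sketch also checks the recursion on $V_{N,K}$. Function names are illustrative, and brute-force enumeration is only feasible for small $N$ (there are 52 partitions of 5 items).

```python
from math import prod


def rising(a, n):
    """Rising factorial (a)_n = a (a + 1) ... (a + n - 1), with (a)_0 = 1."""
    return prod(a + i for i in range(n))


def set_partitions(items):
    """Yield every partition of the list `items` into non-empty blocks."""
    if not items:
        yield []
        return
    first, rest = items[0], items[1:]
    for part in set_partitions(rest):
        for i in range(len(part)):                 # add `first` to an existing block
            yield part[:i] + [[first] + part[i]] + part[i + 1:]
        yield [[first]] + part                      # or give it a block of its own


def pyp_V(N, K, alpha, sigma):
    """V_{N,K} for the Pitman-Yor process: prod_{i=1}^{K-1} (alpha + i sigma) / (alpha + 1)_{N-1}."""
    return prod(alpha + i * sigma for i in range(1, K)) / rising(alpha + 1.0, N - 1)


def gibbs_type_prob(block_sizes, V, sigma):
    """Eq. (31.38): V_{N,K} * prod_k (1 - sigma)_{n_k - 1} for one particular partition."""
    N, K = sum(block_sizes), len(block_sizes)
    return V(N, K) * prod(rising(1.0 - sigma, n_k - 1) for n_k in block_sizes)


alpha, sigma, N = 1.5, 0.3, 5


def V(n, k):
    return pyp_V(n, k, alpha, sigma)


partitions = list(set_partitions(list(range(N))))
total = sum(gibbs_type_prob([len(block) for block in p], V, sigma) for p in partitions)
```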


31.4.2 Dependent random probability measures 


Dirichlet processes, and more generally, random probability measures, have also been generalized from settings with a single set of observations to those involving grouped observations, or observations indexed by covariates. Consider $T$ sets of observations $\{X^1, \ldots, X^T\}$, each perhaps corresponding to a different country, a different year, or more abstractly, a different set of observed covariates. There are two immediate (though inadequate) ways to model such data: 1) by pooling all datasets into a single dataset, modeled as drawn from a DP or DPMM-distributed random probability measure $G$, or 2) by treating each group as independent, having its own random probability measure $G^t$. The first approach fails to capture differences between the groups, while the second ignores similarities between the groups (e.g. shared clusters). Dependent random probability measures seek a compromise between these extremes, allowing statistical information to be shared across different groups.

A seminal paper in the machine learning literature that addresses this problem is the hierarchical Dirichlet process (HDP) [Teh+06b]. The basic idea here is simple: each group has its own random measure drawn from a Dirichlet process $\mathrm{DP}(\alpha, H)$. The twist now is that the base measure $H$ itself is random; in fact, it is itself drawn from a Dirichlet process. Thus, the overall generative process is

$$\begin{align}
H &\sim \mathrm{DP}(\alpha_0, H_0), \tag{31.39}\\
G^t &\sim \mathrm{DP}(\alpha, H), \quad t \in 1, \ldots, T, \tag{31.40}\\
\theta_i^t &\sim G^t, \quad i \in 1, \ldots, N_t, \quad t \in 1, \ldots, T. \tag{31.41}
\end{align}$$

The $\theta_i^t$'s might be the observations themselves, or the latent parameters underlying each observation, with $x_i^t \sim F(\theta_i^t)$.

Recall that if a probability measure $G^t$ is drawn from $\mathrm{DP}(\alpha, H)$, its atoms are drawn independently from the base measure $H$. In the HDP, $H$, which is a draw from a DP, is discrete, so that some atoms of $G^t$ will sit on top of each other, becoming a single atom. More importantly, all measures $G^t$ will share the same infinite set of locations: each atom of $H$ will eventually be sampled to form a location of an atom of each $G^t$. This allows the same clusters to appear in all groups, though they will have different weights. Moreover, a big cluster in one group is a priori likely to be a big cluster in another group, as underlying both is a large atom in $H$. Since the common probability measure $H$ itself is random, its components (both weights and locations) will be learned from data.

Despite its apparent complexity, it is fairly easy to develop an analogue of the Chinese restaurant process for the HDP, allowing us to sample observations directly without having to instantiate any of the infinite-dimensional measures. This is called the Chinese restaurant franchise, and essentially amounts to each group having its own Chinese restaurant with the following modification: whenever a customer sits at a new table and orders a dish, that dish itself is sampled from an upper CRP common to all restaurants. It is also possible to develop stick-breaking representations of the HDP.
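The franchise metaphor can be simulated directly. In the Python sketch below (names hypothetical), each restaurant $t$ runs an ordinary CRP over tables with concentration $\alpha$. Each new table orders its dish from an upper CRP with concentration $\alpha_0$, whose counts $m_k$ are the number of tables serving dish $k$ across the whole franchise. The dish index plays the role of a shared atom of $H$, so the same dish can appear in several restaurants with different frequencies.

```python
import random


def chinese_restaurant_franchise(group_sizes, alpha, alpha0, rng):
    """Dish (shared cluster) label for every customer in every restaurant.

    Within a restaurant, a customer joins table j with probability proportional to
    its occupancy, or opens a new table with probability proportional to alpha.
    A new table orders dish k with probability proportional to m_k (tables serving
    k in all restaurants), or a brand-new dish with probability proportional to alpha0.
    """
    m = []                                   # m_k: tables serving dish k, franchise-wide
    labels = []
    for n_t in group_sizes:
        table_counts, table_dish, dishes_t = [], [], []
        for _ in range(n_t):
            J = len(table_counts)
            j = rng.choices(range(J + 1), weights=table_counts + [alpha])[0]
            if j == J:                       # new table: order a dish from the upper CRP
                K = len(m)
                k = rng.choices(range(K + 1), weights=m + [alpha0])[0]
                if k == K:
                    m.append(0)              # brand-new dish, available to all restaurants
                m[k] += 1
                table_counts.append(0)
                table_dish.append(k)
            table_counts[j] += 1
            dishes_t.append(table_dish[j])
        labels.append(dishes_t)
    return labels, m


labels, m = chinese_restaurant_franchise([50, 50, 50], alpha=2.0, alpha0=3.0,
                                         rng=random.Random(0))
```

The dish labels of different restaurants will typically overlap, which is how clusters are shared across groups.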
The HDP has found wide application in the machine learning literature. A common application is document modeling, where the location of each atom is a topic, corresponding to some distribution over words. Rather than bounding the number of topics, there are an infinite number of topics, with document $d$ having its own distribution over topics (represented by a measure $G^d$). By tying all the $G^d$'s together through an HDP, different documents can share the same topics while emphasizing different topics.

Another application involves infinite-state hidden Markov models, also called HDP-HMMs. Recall from Section 29.2 that a Markov chain is parametrized by a transition matrix, with row $r$ giving the distribution over the next state if the current state is $r$. For an infinite-state HMM, this is an infinite-by-infinite matrix, with row $r$ corresponding to a distribution $G^r$ with an infinite number of atoms. The different $G^r$'s can be tied together by modeling them with an HDP, so that atoms from each correspond to the same states [Fox+08].
Hierarchical nonparametric models of this kind can be constructed with other measures, such as the Pitman-Yor process. For certain parameter settings, the PYP possesses convenient marginalization properties that the DP does not [Woo+09b]. In particular, simulating a random probability measure (RPM) in the following two steps

$$\begin{align}
G_0 \mid H_0 &\sim \mathrm{PYP}(0, d_0, H_0), \tag{31.42}\\
G_1 \mid G_0 &\sim \mathrm{PYP}(0, d_1, G_0) \tag{31.43}
\end{align}$$

is equivalent to directly simulating $G_1$ without instantiating $G_0$ as below:

$$G_1 \mid H_0 \sim \mathrm{PYP}(0, d_0 d_1, H_0). \tag{31.44}$$

This marginalization property (also called coagulation) allows deep hierarchies of dependent RPMs, with only a smaller, dataset-dependent subset of the $G$'s having to be instantiated. In the sequence memoizer of [Woo+11], the authors model sequential data (e.g. text) with hierarchies that are infinitely deep, but with only a finite number of levels ever having to be instantiated. If needed, intermediate random measures can be instantiated by a dual fragmentation operator.

Deeper hierarchies like those described above allow more refined modeling of similarity between different groups. Under the original HDP, the groups themselves are exchangeable, with no subset of groups a priori more similar to each other than to others. For instance, suppose each of the measures $G^1, \ldots, G^T$ corresponds to a distribution over topics in a different scientific journal. Modeling the $G^t$'s with an HDP allows statistical sharing across journals (through shared clusters and similar cluster probabilities), but does not regard some journals as a priori more similar than others. If one had further information, e.g. that some are physics journals and the rest are biology journals, then one might add another level to the hierarchy. Now rather than each journal having a probability measure $G^t$ drawn from a $\mathrm{DP}(\alpha, H)$, physics and biology journals have their own base measures $H_p$ and $H_b$ that allow statistical sharing among physics and biology journals respectively. Like the HDP, these are draws from a DP with base measure $H$. To allow sharing across disciplines, $H$ is again random, drawn from a DP with base measure $H_0$. Overall, if there are $D$ disciplines, $1, 2, \ldots, D$, the overall
31.5. THE INDIAN BUFFET PROCESS AND THE BETA PROCESS


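model is

$$\begin{align}
H &\sim \mathrm{DP}(\alpha_0, H_0), \tag{31.45}\\
H_d &\sim \mathrm{DP}(\alpha_1, H), \quad d \in 1, \ldots, D, \tag{31.46}\\
G^{t,d} &\sim \mathrm{DP}(\alpha, H_d), \quad t \in 1, \ldots, T_d, \tag{31.47}\\
\theta_i^{t,d} &\sim G^{t,d}, \quad d \in 1, \ldots, D, \; t \in 1, \ldots, T_d, \; i \in 1, \ldots, N_{t,d}. \tag{31.48}
\end{align}$$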


One might add further levels to the hierarchy given more information (e.g. if disciplines are grouped 
into physical sciences, social sciences and humanities). 

Dependent random probability measures can also be indexed by covariates in some continuous space, whether time, space, or some Euclidean space or manifold. This space is typically endowed with some distance or similarity function, and one expects that measures with similar covariates are a priori more similar to each other. Thus, $G^t$ might represent a distribution over topics in year $t$, and one might wish to model $G^t$ evolving gradually over time. There is a rich history of dependent random probability measures in the statistics literature, starting from [Mac99a]. A common requirement in such models is that at any fixed time $t$, the marginal distribution of $G^t$ follows some well-specified distribution, e.g. a Dirichlet process or a PYP. Approaches exploit the stick-breaking representation, the CRP representation, or something else like the Poisson structure underlying such models (see Section 31.7).

As a simple and early example, recall the stick-breaking construction from Section 31.2.2, where a random probability measure is represented by an infinite collection of pairs $(\beta_k, \theta_k)$. To construct a family of RPMs indexed by some covariate $t$, we need a family of such sets $(\beta_k^t, \theta_k^t)$ for each of the possibly infinite values of $t$. To ensure each $G^t$ is marginally DP distributed with concentration parameter $\alpha$ and base measure $H$, we need that for each $t$, the $\beta_k^t$'s are marginally i.i.d. draws from a $\mathrm{Beta}(1, \alpha)$, and the $\theta_k^t$'s are i.i.d. draws from $H$. Further, we do not want independence across $t$; rather, for two times $t_1$ and $t_2$, we want similarity to increase as $|t_1 - t_2|$ decreases. To achieve this, define an infinite sequence of independent Gaussian processes on $\mathcal{T}$, $f_k(t)$, $k = 1, 2, \ldots$, with mean 0 and some covariance function. At any time $t$, $f_k(t)$ for all $k$ are i.i.d. draws from a normal distribution. By transforming these Gaussian processes through the cdf of a Gaussian density (to produce i.i.d. uniform random variables at each $t$), and then through the inverse cdf of a $\mathrm{Beta}(1, \alpha)$, one has an infinite collection of i.i.d. $\mathrm{Beta}(1, \alpha)$ random variables at any time $t$. The GP construction means that each $\beta_k^t$ varies smoothly with $t$. A similar approach can be used to construct smoothly varying $\theta_k^t$'s that are marginally $H$-distributed, and together, these form a family of gradually varying RPMs $G^t$, with

$$G^t = \sum_{k=1}^{\infty} \pi_k^t \delta_{\theta_k^t}, \quad \text{where } \pi_k^t = \beta_k^t \prod_{j=1}^{k-1} (1 - \beta_j^t) \tag{31.49}$$

Of course, such a model comes with formidable computational challenges; however, the marginal properties allow standard MCMC methods from the DP and other RPMs to be adapted to such settings.
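The transformation chain above (Gaussian process, Gaussian cdf, inverse Beta cdf, stick-breaking) can be sketched in Python for the weights alone. Two simplifications are our own assumptions. First, time is an evenly spaced grid, and each $f_k$ is a stationary AR(1) sequence, which on such a grid is exactly a Gaussian process with exponential (Ornstein-Uhlenbeck) covariance $\rho^{|t - t'|}$ and unit marginal variance. Second, the atom locations $\theta_k^t$ are not simulated. The inverse cdf of $\mathrm{Beta}(1, \alpha)$ has the closed form $1 - (1 - u)^{1/\alpha}$, and the stick-breaking is truncated at $K$ atoms. Function names are hypothetical.

```python
import math
import random


def normal_cdf(f):
    """Standard normal cdf, written with math.erf."""
    return 0.5 * (1.0 + math.erf(f / math.sqrt(2.0)))


def dependent_stick_weights(n_times, K, alpha, rho, rng):
    """Truncated weights pi_k^t (t = 0..n_times-1, k = 1..K) of a smoothly varying
    family of stick-breaking measures. Each f_k is a stationary AR(1) sequence with
    N(0, 1) marginals, i.e. a GP on the grid with covariance rho**|t - t'|."""
    f = [[rng.gauss(0.0, 1.0)] for _ in range(K)]
    for k in range(K):
        for _ in range(n_times - 1):
            f[k].append(rho * f[k][-1] + math.sqrt(1.0 - rho ** 2) * rng.gauss(0.0, 1.0))
    weights = []
    for t in range(n_times):
        w_t, remaining = [], 1.0
        for k in range(K):
            u = normal_cdf(f[k][t])                    # Uniform(0, 1) marginally
            beta = 1.0 - (1.0 - u) ** (1.0 / alpha)    # inverse cdf of Beta(1, alpha)
            w_t.append(beta * remaining)               # stick-breaking, Eq. (31.49)
            remaining *= 1.0 - beta
        weights.append(w_t)
    return weights


weights = dependent_stick_weights(n_times=200, K=30, alpha=2.0, rho=0.99,
                                  rng=random.Random(0))
```

With $\rho$ close to one, consecutive values of $f_k$ are strongly correlated, so each weight $\pi_k^t$ drifts gradually in $t$. At any single $t$, the weights are distributed as truncated $\mathrm{DP}$ stick-breaking weights with concentration $\alpha$.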
Figure 31.7: (a) A realization of the cluster matrix $Z$ from the Chinese restaurant process (CRP). (b) A realization of the feature matrix $Z$ from the Indian buffet process (IBP). Rows are customers and columns are dishes (for the CRP each table has its own dish). Both produce binary matrices, with the CRP assigning a customer to a single table (cluster), and the IBP assigning a set of dishes (features) to each customer.
31.5 The Indian buffet process and the Beta process

The Dirichlet process is a Bayesian nonparametric model useful for clustering applications, where the number of clusters is allowed to grow with the size of the dataset. Under this generative model, each observation is assigned to a cluster through the variable $z_i$. Equivalently, $z_i$ can be written as a one-hot binary vector, and the entire clustering structure can be written as a binary matrix $Z$ consisting of $N$ rows, each with exactly one non-zero element (Figure 31.7(a)).

Clustering models that limit each observation to a single cluster can be overly restrictive, failing to capture the complexity of real datasets. For instance, in computer vision applications, rather than assign an image to a single cluster, it might be more appropriate to assign it a binary vector of features, indicating whether different object types are or are not present in the image. Similarly, in movie recommendation systems, rather than assign a movie to a single genre (‘comedy’ or ‘romance’ etc.), it is more realistic to assign it to multiple genres (‘comedy’ AND ‘romance’). Now, in contrast to all-or-nothing clustering (which would require a new genre ‘romantic comedy’), different movies can have different but overlapping sets of features, allowing a partial sharing of statistical information.

Latent feature models generalize clustering models, allowing each observation to have multiple features. Nonparametric latent feature models allow the number of available features to be infinite (rather than fixed a priori to some finite number), with the number of active features in a dataset growing with the dataset size. Such models associate the dataset with an infinite binary matrix consisting of $N$ rows, but now each row can have multiple elements set to 1, corresponding to the active features. This is shown in Figure 31.7(b). As with the clustering models, each column is associated with a parameter drawn i.i.d. from some base measure $H$.

Figure 31.8: (a) A realization of a Beta process on the real line. The atoms of the random measure are shown in red, while the Beta subordinator (whose value at time $t$ sums all atoms up to $t$) is shown in black. (b) Samples from the Beta process.

The Indian buffet process (IBP) is a Bayesian nonparametric analogue of the CRP for latent feature models. As with the CRP, the IBP is specified by a concentration parameter $\alpha$ and a base measure $H$ on some space $\Theta$. The former controls the distribution over the binary feature matrix, whereas feature parameters $\theta_d$ are drawn i.i.d. from the latter. Under the IBP, individuals enter sequentially into a restaurant, now picking a set of dishes (instead of a single table). The first customer samples a $\mathrm{Poisson}(\alpha)$-distributed random number of dishes, and assigns each of them values drawn from $H$. When the $i$th customer enters the restaurant, they first make a pass through all dishes already chosen. Suppose $N_d$ of the earlier customers have chosen dish $d$: then customer $i$ selects this dish with probability $N_d/i$. This results in a rich-get-richer phenomenon, where popular dishes (common features) are more likely to be selected in the future. Additionally, the $i$th customer samples a $\mathrm{Poisson}(\alpha/i)$ number of new dishes. This results in a non-zero probability of new dishes that nonetheless decreases with $i$.
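The sequential description of the IBP can be simulated as follows. This Python sketch returns the binary matrix $Z$ as a list of rows, padding earlier rows with zeros for dishes introduced later. It draws only the matrix, not the dish parameters from $H$. Python's standard library has no Poisson sampler, so `poisson` below is a small helper using Knuth's multiplication method, which is adequate for the small rates $\alpha/i$ involved. It and `sample_ibp` are hypothetical names.

```python
import math
import random


def poisson(lam, rng):
    """Poisson(lam) draw by Knuth's multiplication method (suitable for small lam)."""
    threshold, k, p = math.exp(-lam), 0, 1.0
    while True:
        p *= rng.random()
        if p <= threshold:
            return k
        k += 1


def sample_ibp(N, alpha, rng):
    """Binary feature matrix Z (a list of N rows) drawn sequentially from the IBP."""
    dish_counts = []                          # N_d: customers so far who took dish d
    rows = []
    for i in range(1, N + 1):
        # Customer i takes each existing dish d with probability N_d / i ...
        row = [1 if rng.random() < n_d / i else 0 for n_d in dish_counts]
        # ... and then tries Poisson(alpha / i) brand-new dishes.
        n_new = poisson(alpha / i, rng)
        dish_counts = [n_d + z for n_d, z in zip(dish_counts, row)] + [1] * n_new
        rows.append(row + [1] * n_new)
    K = len(dish_counts)
    return [r + [0] * (K - len(r)) for r in rows]   # pad earlier rows with zeros


Z = sample_ibp(N=50, alpha=3.0, rng=random.Random(0))
```

Because customer $i$ adds only $\mathrm{Poisson}(\alpha/i)$ new dishes, the number of columns grows slowly with the number of rows.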

A key property of the Indian buffet process is that, like the CRP, it is exchangeable. In other words, its statistical properties do not change if its rows are permuted. For instance, one can show that the number of dishes picked by any customer is marginally $\mathrm{Poisson}(\alpha)$ distributed. Similarly, the distribution over the number of features shared by the first two customers is the same as that for the first and last customers. We mention that, like the CRP, the ordering of the dishes (or columns) is arbitrary and might appear to violate exchangeability. For instance, the first customer cannot pick the first and third dishes and not the second dish, while this is possible for the second customer. These artefacts disappear if we index columns by their associated parameters. Equivalently, after reordering columns, we can transform the feature matrix to be left-ordered (essentially, all new dishes selected by a customer must be adjacent; see [GG11]), and we can view the IBP as a prior on such left-ordered matrices.

The exchangeability of the rows of the IBP implies, via de Finetti’s theorem, that there exists an underlying random measure $G$, conditioned on which the rows are i.i.d. draws. Just as the Chinese restaurant process represents observations drawn from a Dirichlet process-distributed random probability measure, the dishes of each customer in the IBP represent observations drawn from a Beta process. Like the DP, the Beta process is an atomic measure taking the form

$$G(\theta) = \sum_{k=1}^{\infty} \pi_k \delta_{\theta_k}(\theta). \tag{31.50}$$

Each of the $\pi_k$'s lies between 0 and 1, but unlike the DP, where they add up to 1, the $\pi_k$'s sum up to a finite random number. The Beta process is thus a random measure, rather than a random probability measure. One can imagine the $k$th atom as a coin located at $\theta_k$, with probability of success equal to $\pi_k$. To simulate a row of the IBP, one flips each of the infinite coins, selecting the feature at $\theta_k$ if the $k$th coin comes up heads. One can show that if the $\pi_k$'s sum up to a finite number, the number of active features will be finite. Of the infinite atoms in $G$, a few will dominate the rest; these will be revealed through the rich-get-richer dynamics of the IBP as features common to a large proportion of the observations.

The Beta process has a construction similar to the stick-breaking representation of the Dirichlet process. As with the DP, the locations of the atoms are independent draws from the base measure, while the sequence of weights $\pi_1, \pi_2, \ldots$ is constructed from an infinite sequence of Beta variables: now these are $\mathrm{Beta}(\alpha, 1)$ distributed, rather than $\mathrm{Beta}(1, \alpha)$ distributed. The overall representation of the Beta process is then

$$\beta_k \sim \mathrm{Beta}(\alpha, 1), \quad \theta_k \sim H, \tag{31.51}$$

$$\pi_k = \beta_k \pi_{k-1} = \prod_{i=1}^{k} \beta_i. \tag{31.52}$$
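Equations (31.51)-(31.52) give a simple way to simulate the largest atom weights and, by flipping one coin per atom, a row of the IBP. The Python sketch below truncates at $K$ atoms, which is our assumption. Since $\mathbb{E}[\beta_i] = \alpha/(\alpha + 1)$, the expected weights decay geometrically and $\sum_{k \ge 1} \mathbb{E}[\pi_k] = \alpha$. This is consistent with each customer taking $\mathrm{Poisson}(\alpha)$ dishes on average. Function names are hypothetical.

```python
import random


def beta_process_sticks(alpha, K, rng):
    """First K atom weights of the stick-breaking construction, Eqs. (31.51)-(31.52):
    beta_i ~ Beta(alpha, 1) and pi_k = prod_{i<=k} beta_i (so the weights decrease)."""
    weights, pi = [], 1.0
    for _ in range(K):
        pi *= rng.betavariate(alpha, 1.0)
        weights.append(pi)
    return weights


def sample_row(weights, rng):
    """One row of the feature matrix: flip coin k with success probability pi_k."""
    return [1 if rng.random() < p else 0 for p in weights]


rng = random.Random(0)
w = beta_process_sticks(alpha=2.0, K=100, rng=rng)
row = sample_row(w, rng)
# Monte Carlo check of sum_k E[pi_k] = alpha (the truncation error is negligible for K = 100).
totals = [sum(beta_process_sticks(2.0, 100, rng)) for _ in range(2000)]
```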
It is not hard to see that under the IBP, the total number of dishes underlying a dataset of size $N$ follows a $\mathrm{Poisson}(\alpha H_N)$ distribution, where $H_N = \sum_{i=1}^{N} 1/i$ is the $N$th harmonic number. The Beta process and the IBP have been generalized to three-parameter versions allowing power-law behavior in the total number of dishes (features) in a dataset of size $N$, as well as in the number of customers trying each dish [TG09]. These models have found application in tasks ranging from genetics to collaborative filtering, as well as in models for graph and graphical model structures. Just as with the DP, posterior inference can proceed via MCMC (exploiting either the IBP or the stick-breaking representation), particle filtering, or variational methods.
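The step behind the $\mathrm{Poisson}(\alpha H_N)$ claim is the additivity of independent Poisson variables. In the sequential description of the IBP, customer $i$ introduces a $\mathrm{Poisson}(\alpha/i)$ number of new dishes, independently of everything else, and every dish is introduced by exactly one customer. Writing $K_N$ for the total number of dishes and $K_i^{\mathrm{new}}$ for the dishes introduced by customer $i$,

$$K_N = \sum_{i=1}^{N} K_i^{\mathrm{new}}, \quad K_i^{\mathrm{new}} \sim \mathrm{Poisson}(\alpha / i) \text{ independently} \;\Longrightarrow\; K_N \sim \mathrm{Poisson}\Big(\alpha \sum_{i=1}^{N} \frac{1}{i}\Big) = \mathrm{Poisson}(\alpha H_N).$$

Since $H_N \approx \log N + 0.577$, the expected number of features grows logarithmically in $N$, mirroring the $\alpha \log N$ growth in the number of DP clusters.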


26 

2731.6 Small-variance asymptotics 

28 

Nonparametric Bayesian methods can serve as a basis to develop new and efficient discrete optimization algorithms that have the flavor of Lloyd's algorithm for the k-means objective function. The starting point for this line of work is the view of the k-means algorithm as the small-variance asymptotic limit of the EM algorithm for a mixture of Gaussians. Specifically, consider the EM algorithm to estimate the unknown cluster means μ = (μ_1, ..., μ_k) of a mixture of k Gaussians, all of which have the same known covariance σ²I. In the limit as σ² → 0, the E-step, which computes the cluster assignment probabilities of each observation given the current parameters μ^(t), simply assigns each observation to the nearest cluster. The M-step recomputes a new set of parameters μ^(t+1) given the cluster assignment probabilities; when each observation is hard-assigned to a single cluster, the cluster means are just the means of the assigned observations. Repeatedly assigning observations to the nearest cluster and recomputing cluster locations by averaging the assigned observations is exactly Lloyd's algorithm for k-means clustering.
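To make the small-variance argument concrete, the NumPy sketch below computes E-step responsibilities for a mixture of isotropic Gaussians with shared variance σ² (equal mixing weights are an assumption of the sketch), and separately runs Lloyd's algorithm. For small σ² the responsibilities become one-hot on the nearest mean, which is exactly the assignment step inside `lloyd`. The function names are our own.

```python
import numpy as np

def responsibilities(X, mu, sigma2):
    """E-step for an equal-weight isotropic Gaussian mixture, shared variance sigma2."""
    d2 = ((X[:, None, :] - mu[None, :, :]) ** 2).sum(-1)   # squared distances, (N, K)
    logits = -d2 / (2.0 * sigma2)
    logits -= logits.max(axis=1, keepdims=True)            # numerical stability
    r = np.exp(logits)
    return r / r.sum(axis=1, keepdims=True)

def lloyd(X, mu, n_iter=20):
    """Hard-assignment (sigma^2 -> 0) limit of EM: Lloyd's algorithm for k-means."""
    mu = mu.copy()
    for _ in range(n_iter):
        d2 = ((X[:, None, :] - mu[None, :, :]) ** 2).sum(-1)
        z = d2.argmin(axis=1)                 # assign each point to the nearest mean
        for k in range(len(mu)):
            if np.any(z == k):                # leave an empty cluster's mean unchanged
                mu[k] = X[z == k].mean(axis=0)
    return mu, z
```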

To avoid having to specify the number of clusters k, [KJ12b] considered a Dirichlet process mixture of Gaussians, with all (infinitely many) components again having the same known variance σ². The base measure H from which the component means are drawn was set to a zero-mean Gaussian with variance ρ², and α denotes the concentration parameter. The authors then considered a Gibbs sampler for this model, very closely related to the sampler in Algorithm 43, except that instead of collapsing (marginalizing out) the cluster parameters, these are instantiated. Thus, an observation x_i is assigned
31.6. SMALL-VARIANCE ASYMPTOTICS


to a cluster c with mean μ_c with probability proportional to

$$N_{c,-i}\,\frac{1}{(2\pi\sigma^2)^{d/2}}\exp\left(-\frac{1}{2\sigma^2}\|x_i - \mu_c\|^2\right),$$

where N_{c,−i} is the number of observations other than x_i currently assigned to cluster c and d is the dimension of the data, while it is assigned to a new cluster with probability proportional to

$$\alpha\,\frac{1}{(2\pi(\rho^2 + \sigma^2))^{d/2}}\exp\left(-\frac{1}{2(\rho^2 + \sigma^2)}\|x_i\|^2\right).$$

After cycling through all observations, one then resamples the parameters of each cluster. Given the Gaussian base measure and Gaussian likelihood, each cluster mean again has a Gaussian conditional distribution, whose mean is a convex combination of the prior mean 0 and a data-driven term, specifically the average of the assigned observations. The weight of the prior term is proportional to the inverse of the prior variance, 1/ρ², while the weight of the likelihood term is proportional to N_c/σ², where N_c is the number of observations assigned to the cluster.

To derive a small-variance limit of the sampler above, we cannot just let σ² go to 0, as that would result in each observation being assigned to its own cluster. To prevent the likelihood from dominating the prior in this way, one must also send the concentration parameter α to 0, so that the DP prior enforces an increasingly strong penalty on a large number of clusters (recall that, a priori, the expected number of clusters underlying N observations grows as α log N). [KJ12b] showed that with α scaled as

$$\alpha = \left(1 + \frac{\rho^2}{\sigma^2}\right)^{d/2}\exp\left(-\frac{\lambda}{2\sigma^2}\right)$$

for some parameter λ, the normalizing constants of the two assignment probabilities match, so that as σ² → 0 the choice between an existing cluster c and a new cluster reduces to comparing ‖x_i − μ_c‖² with λ. Taking the limit σ² → 0 with ρ fixed, we get the following modification of the k-means algorithm:

1. Assign each observation to the nearest cluster, unless the squared distance to the nearest cluster exceeds λ, in which case assign the observation to its own new cluster.


2. Set the cluster means equal to the average of the assigned observations. 


Algorithm 45: DP-means hard clustering algorithm for data D = {x_1, ..., x_N}
1  Initialize the number of clusters K = 1, with cluster parameter μ_1 equal to the global mean: μ_1 = (1/N) Σ_{i=1}^{N} x_i
2  Initialize all cluster assignment indicators z_i = 1
3  while the z_i's have not converged do
4      for each i = 1 : N do
5          Compute the distance d_ik = ‖x_i − μ_k‖² to cluster k for k = 1, ..., K
6          if min_k d_ik > λ then
7              Increase the number of clusters K by 1: K = K + 1
8              Assign observation i to this cluster: z_i = K and μ_K = x_i
9          else
10             Set z_i = argmin_k d_ik
11     for each k = 1 : K do
12         Let D_k be the observations assigned to cluster k. Compute the cluster mean μ_k = (1/|D_k|) Σ_{x ∈ D_k} x
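A direct NumPy transcription of Algorithm 45 might look as follows. It is a sketch rather than a reference implementation: λ (`lam`) is compared against squared distances as in lines 5 and 6 of the algorithm, and dropping clusters that become empty (with relabeling) is an addition of ours that the algorithm does not specify.

```python
import numpy as np

def dp_means(X, lam, max_iter=100):
    """DP-means hard clustering following Algorithm 45.

    lam is the penalty compared against squared Euclidean distances.
    Removing clusters that end up empty is an addition of this sketch.
    """
    N = X.shape[0]
    mu = [X.mean(axis=0)]            # K = 1, initialized at the global mean
    z = np.zeros(N, dtype=int)       # every z_i starts in cluster 0
    for _ in range(max_iter):
        z_old = z.copy()
        for i in range(N):
            d = np.array([np.sum((X[i] - m) ** 2) for m in mu])
            if d.min() > lam:
                mu.append(X[i].copy())   # open a new cluster centred at x_i
                z[i] = len(mu) - 1
            else:
                z[i] = int(np.argmin(d))
        # Recompute means of non-empty clusters and relabel them contiguously.
        keep = [k for k in range(len(mu)) if np.any(z == k)]
        relabel = {k: j for j, k in enumerate(keep)}
        mu = [X[z == k].mean(axis=0) for k in keep]
        z = np.array([relabel[k] for k in z])
        if np.array_equal(z, z_old):     # assignments have converged
            break
    return np.array(mu), z
```

With two well-separated groups of points and a small λ, the sketch should open one cluster per group; with a very large λ it never leaves the single global-mean cluster.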


Like k-means, this is a hard-clustering algorithm, except that instead of having to specify the number of clusters k, one just specifies a penalty λ for introducing a new cluster. The actual number of clusters is determined by the data; [KJ12b] refer to this algorithm as DP-means. One can show that the iterates above monotonically decrease the following objective function, converging to a local minimum:
$$\sum_{k=1}^{K}\sum_{x \in D_k}\|x - \mu_k\|^2 + \lambda K, \quad \text{where } \mu_k = \frac{1}{|D_k|}\sum_{x \in D_k} x. \tag{31.53}$$


The first term in the expression above is exactly the objective function of the k-means algorithm with K clusters. The second term is a penalty that introduces a cost λ for each additional cluster. Interestingly, this penalty has the same form as the one in the Akaike Information Criterion (AIC), a well-studied approach to penalizing model complexity.
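Equation (31.53) is easy to evaluate for a given hard assignment. The helper below (a hypothetical name, which counts only non-empty clusters in K) shows how λ trades off fit against the number of clusters on four points forming two tight pairs.

```python
import numpy as np

def dp_means_objective(X, z, lam):
    """Eq. (31.53): within-cluster squared error plus lam per non-empty cluster."""
    labels = np.unique(z)
    cost = 0.0
    for k in labels:
        Xk = X[z == k]
        cost += np.sum((Xk - Xk.mean(axis=0)) ** 2)
    return cost + lam * len(labels)

X = np.array([[0.0, 0.0], [0.0, 2.0], [10.0, 0.0], [10.0, 2.0]])
z_one = np.array([0, 0, 0, 0])
z_two = np.array([0, 0, 1, 1])
cost_one = dp_means_objective(X, z_one, lam=1.0)   # 104 + 1 = 105.0
cost_two = dp_means_objective(X, z_two, lam=1.0)   # 4 + 2 = 6.0
```

Splitting wins here because it removes 100 units of squared error at a cost of one extra λ; once λ exceeds 100, the single cluster becomes preferable.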
It is also possible to derive hard-clustering algorithms for simultaneously clustering multiple datasets, while allowing these to share clusters. This is possible through the small-variance limit of a Gibbs sampler for the hierarchical Dirichlet process (HDP), and results in a clustering algorithm that now has two threshold parameters, a local one λ_l and a global one λ_g. The algorithm proceeds by maintaining a set of global clusters, with the local clusters of each dataset assigned to a subset of the global clusters. It then repeats the following steps until convergence:
Assign observations to local clusters. For x_ij, the ith data point in dataset j, compute the distance to all global clusters. For those global clusters not currently present in dataset j, add λ_l to their distance; this reflects the cost of introducing a new cluster into dataset j. Now, assign z_ij to the cluster with the smallest distance, unless the smallest distance exceeds λ_g + λ_l, in which case create a new global cluster and assign x_ij to it. Observe that in the latter case, the distance of x_ij to the new cluster is 0, with λ_g + λ_l reflecting the cost of introducing a new global and then local cluster.

Assign local clusters to global clusters. For each local cluster l, compute the sum of the distances of all its assigned observations to the cluster mean; call this d_l. Also compute the sum of the distances of the assigned observations to each global cluster; for global cluster p, call this d_{l,p}. Then assign local cluster l to the global cluster with the smallest d_{l,p}, unless min_p d_{l,p} > d_l + λ_g, in which case we create a new global cluster.


Recompute global cluster means. Set the global cluster means equal to the average of the assigned observations across all datasets.


In [JKJ12], these ideas were extended from DP mixtures of Gaussians to DP mixtures of more general exponential family distributions. Briefly, the hard-clustering algorithms maintain the same structure, the only difference being that the distance from clusters is measured using a Bregman divergence specific to that exponential family distribution. For Gaussians, the Bregman divergence reduces to the usual squared Euclidean distance.
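The only ingredient that changes in the exponential-family version is the distance. A Bregman divergence generated by a strictly convex function φ is D_φ(x, y) = φ(x) − φ(y) − ⟨∇φ(y), x − y⟩. The sketch below (NumPy; helper names are ours) shows that φ(x) = ‖x‖² recovers the squared Euclidean distance, while φ(x) = Σ_j x_j log x_j gives the generalized KL divergence, the divergence commonly associated with Poisson-distributed data.

```python
import numpy as np

def bregman(phi, grad_phi, x, y):
    """Bregman divergence D_phi(x, y) = phi(x) - phi(y) - <grad phi(y), x - y>."""
    return phi(x) - phi(y) - np.dot(grad_phi(y), x - y)

def sq_norm(v):
    return np.dot(v, v)              # phi(x) = ||x||^2 (Gaussian case)

def sq_norm_grad(v):
    return 2.0 * v

def neg_entropy(v):
    return np.sum(v * np.log(v))     # phi(x) = sum x log x; requires v > 0

def neg_entropy_grad(v):
    return np.log(v) + 1.0
```

In a DP-means-style loop, swapping `bregman(...)` in for the squared distance is the only change; the cluster "means" remain plain averages of the assigned points, which is a known property of Bregman divergences.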

In [BKJ13], the authors showed how such small-variance algorithms could be derived directly from a probabilistic model, independent of any specific computational algorithm such as EM or Gibbs sampling. Their approach involves computing the MAP solution of the parameters of the model, and then taking the small-variance limit to obtain an objective function. This approach, called MAP-based Asymptotic Derivations from Bayes or MAD-Bayes, allowed them to derive, among other things, an analog of the DP-means algorithm for feature-based models. They called this the BP-means algorithm, after the Beta process underlying the Indian buffet process.

Write X for an N × D matrix of N D-dimensional observations, Z for an N × K matrix of binary feature assignments, and A for a K × D matrix of K D-dimensional features, with one seeking a pair
31.7. COMPLETELY RANDOM MEASURES




Figure 31.9: Poisson process construction of a completely random measure G = Σ_i w_i δ_{θ_i}. The set of pairs {(w_1, θ_1), (w_2, θ_2), ...} is a realization from a Poisson process with intensity λ(θ, w) = H(θ)γ(w).


(A, Z) such that ZA approximates X as well as possible. [BKJ13] showed that in the small-variance limit, finding the MAP solution for the IBP is equivalent to the following problem:


$$\operatorname*{argmin}_{K, Z, A}\; \mathrm{tr}\left[(X - ZA)^\top (X - ZA)\right] + \lambda K, \tag{31.54}$$


where again λ is a parameter of the algorithm, governing how the concentration parameter scales with the variance, and specifying a penalty on introducing new features. This objective function is very intuitive: the first term corresponds to the approximation error from K features, while the second term penalizes the model complexity resulting from a large number of features K. This objective function can be optimized greedily by repeating three steps for each observation i until convergence:


1. Given A and K, compute the optimal value of the binary feature assignment vector z_i of observation i and update Z.


2. Given Z and A, introduce an additional feature vector equal to the residual for the ith observation, namely, x_i − z_i A. Call A′ the updated feature matrix, and Z′ the updated feature assignment matrix, where only observation i has been assigned this feature. If the configuration (K + 1, Z′, A′) results in a lower value of the objective function than (K, Z, A), set the former as the new configuration. In other words, if the benefit of introducing a new feature outweighs the penalty λ, then do so.


3. Update the feature vectors A given the feature assignment vectors Z. 
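The three steps above can be sketched in NumPy as follows. This is an illustration under our own simplifying choices, not the implementation of [BKJ13]: step 1 enumerates all 2^K binary vectors (only feasible for small K), step 2 accepts the new feature when the squared residual exceeds λ (the exact objective comparison, since the new feature drives observation i's residual to zero), and step 3 uses the least-squares update A = Z⁺X via the pseudoinverse. The returned `history` records objective (31.54) after each sweep, which should be non-increasing.

```python
import itertools
import numpy as np

def bp_objective(X, Z, A, lam):
    """Objective (31.54): squared reconstruction error plus lam per feature."""
    R = X - Z @ A
    return np.sum(R * R) + lam * A.shape[0]

def bp_means(X, lam, n_sweeps=10):
    """Greedy BP-means sketch; Z is N x K binary, A is K x D."""
    N, D = X.shape
    Z = np.zeros((N, 0))
    A = np.zeros((0, D))
    history = []
    for _ in range(n_sweeps):
        for i in range(N):
            K = A.shape[0]
            # Step 1: best binary z_i by brute-force search over {0,1}^K.
            best = min(itertools.product([0, 1], repeat=K),
                       key=lambda z: np.sum((X[i] - np.array(z, dtype=float) @ A) ** 2))
            Z[i] = np.array(best, dtype=float)
            # Step 2: try a new feature equal to observation i's residual.
            r = X[i] - Z[i] @ A
            if np.sum(r * r) > lam:          # error saved outweighs the penalty
                A = np.vstack([A, r])
                Z = np.hstack([Z, np.zeros((N, 1))])
                Z[i, -1] = 1.0
            # Step 3: least-squares update of the features given Z.
            if A.shape[0] > 0:
                A = np.linalg.pinv(Z) @ X
        history.append(bp_objective(X, Z, A, lam))
    return Z, A, history
```

For data built from two "parts" and their sum, such as rows [5, 0], [0, 5], [5, 5], the sketch should discover two features and explain the third row as their combination.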


[BKJ13] showed that this algorithm monotonically decreases the objective in Equation (31.54), 
converging eventually to a local optimum. Subsequent works have considered the small-variance 
asymptotics of more structured models, such as topic models, hidden Markov models and even 
continuous-time stochastic process models. 


31.7 Completely random measures 


The Dirichlet process is an example of a random probability measure (RPM), a class of random measures that always integrate to 1. The Beta process, while not an RPM, belongs to another class of random measures: completely random measures (CRMs). A completely random measure G satisfies the following property: for any two disjoint subsets T_1 and T_2 of Θ, the values G(T_1) and G(T_2) are independent random variables:


$$G(T_1) \perp G(T_2) \quad \forall\, T_1, T_2 \subseteq \Theta \text{ s.t. } T_1 \cap T_2 = \emptyset. \tag{31.55}$$


Note that the Dirichlet process is not a CRM, since for a set T and its complement T^c = Θ∖T, we have G(T) = 1 − G(T^c), which is as far from independent as can be. More generally, under the DP, for disjoint sets T_1 and T_2, the measures G(T_1) and G(T_2) are negatively correlated. We will see later, though, that the Dirichlet process is closely related to another CRM, the Gamma process. As a side point, beyond the sum-to-one constraint, G(T) does not tell us anything about how the probability mass is distributed within T or within its complement, making the DP what is known as a neutral process.

The simplest example of a CRM is the Poisson process (see Section 31.9). A Poisson process with intensity λ(θ) is a point process producing points or events in Θ, with the number of points in any set T following a Poisson(∫_T λ(θ)dθ) distribution, and with the counts in any two disjoint sets independent of each other. While it is common to think of this as a point process, one can also think of a realization from a Poisson process as an integer-valued random measure, where the measure of any set is the number of points falling within that set. It is clear, then, that the Poisson process is an example of a CRM.

It turns out that the Poisson process underlies all completely random measures in a fundamental way. For some space Θ, and W the positive real line, simulate a Poisson process on the product space W × Θ with intensity λ(θ, w). Figure 31.9 shows a realization from this Poisson process; write it as M = {(θ_1, w_1), ..., (θ_{|M|}, w_{|M|})}, where |M| is the number of events. This can be used to construct an atomic measure G = Σ_{j=1}^{|M|} w_j δ_{θ_j}, as illustrated in Figure 31.9. From its Poisson construction, this is a completely random measure on Θ, with a set T ⊆ Θ having measure Σ_j w_j I(θ_j ∈ T). Different settings of λ(θ, w) give rise to different CRMs, and in fact, other than CRMs with atoms at some fixed locations in Θ, this construction characterizes all CRMs.
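The Poisson construction can be simulated directly when the weight measure has finite total mass. The sketch below deliberately substitutes such a finite weight measure, c·e^{−w}, with H uniform on [0, 1], so that the number of atoms is Poisson(c). This is a compound-Poisson stand-in chosen for illustration (the function names and parameters are ours), not one of the infinite-activity CRMs discussed next. The test checks the defining CRM property empirically: the masses of two disjoint sets are uncorrelated.

```python
import numpy as np

def sample_crm_finite(c, rng):
    """Atomic random measure from a Poisson process on W x Theta.

    Intensity lambda(theta, w) = H(theta) * c * exp(-w), with H = Uniform[0, 1].
    The weight measure has finite total mass c, so there are Poisson(c) atoms.
    """
    n_atoms = rng.poisson(c)
    thetas = rng.uniform(0.0, 1.0, n_atoms)   # atom locations ~ H
    weights = rng.exponential(1.0, n_atoms)   # atom weights ~ normalized weight measure
    return thetas, weights

def measure(thetas, weights, lo, hi):
    """G(T) = sum_j w_j * I(theta_j in [lo, hi))."""
    return np.sum(weights[(thetas >= lo) & (thetas < hi)])
```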

For a CRM, the Poisson intensity is typically chosen to factor as λ(θ, w) = H(θ)γ(w), with ∫_Θ H(θ)dθ = 1. Then H(θ) is the base measure controlling the locations of the atoms in the CRM, while the measure γ(w) controls the number of atoms and the distribution of their weights. Setting γ(w) = w^{−1}(1 − w)^{α−1} on (0, 1] gives the Beta process with base measure H(θ). Other choices include the Gamma process (γ(w) = α w^{−1} exp(−w)), the stable process (γ(w) = α/Γ(1 − σ) · w^{−1−σ}, with 0 < σ < 1), and the generalized Gamma process (γ(w) = α/Γ(1 − σ) · w^{−1−σ} exp(−τw)).
For all the processes described above, γ(w) integrates to infinity, so that ∫_Θ ∫_W λ(θ, w) dθ dw = ∞. Consequently, the number of Poisson events, and thus the number of atoms in the CRM, is infinite with probability one. At the same time, the mass of γ(w) is mostly concentrated around 0 (see Figure 31.9), and for any ε > 0, ∫_ε^∞ γ(w) dw is finite. It is easy to show that the sum of the w's is then finite almost surely (this also uses ∫_0^ε w γ(w) dw < ∞, which holds for these processes). Call this sum W; then the first condition ensures that W is greater than 0, while the second ensures that it is finite. These two conditions make it sensible to divide a realization of a CRM by its sum, resulting in a random measure that integrates to 1: a random probability measure. Such RPMs are called normalized completely random measures, or sometimes just normalized random measures (NRMs). The Dirichlet process we saw earlier is an example of an NRM: it is a normalized Gamma process. This result mirrors the situation with the finite Dirichlet distribution: one can simulate from a d-dimensional Dir(α_1, ..., α_d) distribution by first simulating d independent Gamma variables g_i ~ Ga(α_i, 1), i = 1, ..., d, and then defining the probability vector (g_1, ..., g_d)/Σ_{j=1}^{d} g_j. The Pitman-Yor process is not an NRM except for special settings of its
31.8. LEVY PROCESSES
Figure 31.10: (a) A realization of a Brownian motion. (b) A realization of a Cauchy process (an α-stable process with α = 1).


parameters: as we saw, d = 0 gives a Dirichlet process, or normalized Gamma process, while α = 0 gives a normalized stable process. The normalized generalized Gamma process is an NRM that includes the DP and the normalized stable process as special cases.
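The Gamma-normalization recipe for the finite Dirichlet distribution translates directly into code. The function below is a minimal NumPy sketch (the name is ours); the same normalize-the-sum idea, applied to the infinitely many atoms of a Gamma process, is what yields the Dirichlet process.

```python
import numpy as np

def dirichlet_via_gamma(alphas, rng, size=1):
    """Sample Dir(alphas) by normalizing independent Ga(alpha_i, 1) draws."""
    alphas = np.asarray(alphas, dtype=float)
    g = rng.gamma(shape=alphas, scale=1.0, size=(size, len(alphas)))
    return g / g.sum(axis=1, keepdims=True)   # each row sums to 1

P = dirichlet_via_gamma([1.0, 2.0, 7.0], np.random.default_rng(0), size=5)
```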


31.8 Lévy processes 


Completely random measures are also closely related to Lévy processes and Lévy subordinators [Ber96]. A Lévy process is a continuous-time stochastic process {L_t}_{t≥0} taking values in some space (e.g., ℝ^d) that satisfies two properties¹:


stationary increments: for t, Δ > 0, the distribution of the random variable L_{t+Δ} − L_t does not depend on t;
independent increments: for t, Δ > 0, L_{t+Δ} − L_t is independent of the values of the process before t.


A Lévy subordinator is a real-valued, nondecreasing Lévy process. If we drop the stationarity 
condition, it should be clear that the increments of a Lévy subordinator are exactly the atoms 
of a completely random measure (see Figure 31.8). Common examples of Lévy subordinators are 
Beta subordinators (or Beta processes), Gamma subordinators, stable subordinators and generalized 
Gamma subordinators. CRMs generalize Lévy subordinators by allowing them to be indexed by some general space Θ (rather than by the nonnegative reals), and by relaxing the stationarity condition to allow the atoms to follow some general base measure H.

Unlike Lévy subordinators, a general Lévy process can have negative changes as well, with the change L_{t+Δ} − L_t following an infinitely divisible distribution that depends on Δ (its characteristic function is that of L_1 raised to the power Δ). A random variable X follows an infinitely divisible distribution if, for any positive integer N, there exists some probability distribution such that the sum of N i.i.d. samples from that distribution has the same distribution as X. Examples of infinitely divisible distributions include the Poisson distribution, the Gamma distribution, the Cauchy distribution, the α-stable distribution, and the inverse-Gaussian distribution (among many others), though by far the most well-known example is the Gaussian distribution. It is easy to see why the change in a Lévy process L_{t+Δ} − L_t over some time interval


1. There is also a technical continuity condition that we do not discuss here. 
Δ follows an infinitely divisible distribution: just divide the interval into N equal-length segments. From the properties of the Lévy process, the changes over these segments are independent and identically distributed, and their sum equals L_{t+Δ} − L_t. The Lévy-Khintchine formula shows that the converse is also true: any infinitely divisible distribution represents the change of an associated Lévy process. The Lévy process corresponding to the Gaussian distribution is the celebrated Brownian motion (or Wiener process). Brownian motion is a fundamental and widely applied stochastic process, whose mathematics was first studied by Louis Bachelier to model stock markets, and later, famously, by Albert Einstein to argue for the existence of atoms. For a Brownian motion, the increment L_{t+Δ} − L_t follows a normal N(μΔ, σ²Δ) distribution, where μ and σ are the drift and diffusion coefficients. Setting μ and σ to 0 and 1 gives standard Brownian motion. Paths sampled from the Wiener process are continuous with probability one, whereas all other Lévy processes are jump processes. Figure 31.10 shows realizations from Brownian motion and from a Cauchy process. The jump processes have a Poisson construction related to the Lévy subordinators and completely random measures, and the Lévy-Itô decomposition shows that every Lévy process can be decomposed into a drift, a Brownian component, and Poisson-driven jump components. Lévy processes have been widely applied in mathematical finance to model stock and other asset prices, as well as insurance claims.
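Realizations like those in Figure 31.10 can be approximated on a time grid, because the Lévy property makes the increments over equal-length steps i.i.d. The sketch below (function names and the grid discretization are our own) uses N(μ dt, σ² dt) increments for Brownian motion, and Cauchy increments with scale dt for the Cauchy process, whose increments scale linearly in time because it is α-stable with α = 1. Summing the i.i.d. increments mirrors the infinite-divisibility argument above.

```python
import numpy as np

def brownian_path(T, n_steps, rng, mu=0.0, sigma=1.0):
    """Brownian motion on [0, T]: increments over dt are N(mu*dt, sigma^2*dt)."""
    dt = T / n_steps
    incs = rng.normal(mu * dt, sigma * np.sqrt(dt), size=n_steps)
    return np.concatenate([[0.0], np.cumsum(incs)])

def cauchy_path(T, n_steps, rng):
    """Cauchy process on [0, T]: increments over dt are Cauchy(0, scale=dt)."""
    dt = T / n_steps
    incs = dt * rng.standard_cauchy(size=n_steps)
    return np.concatenate([[0.0], np.cumsum(incs)])

rng = np.random.default_rng(0)
bm = brownian_path(1.0, 1000, rng)
cp = cauchy_path(1.0, 1000, rng)
```

The Brownian path looks continuous at fine resolution, while the Cauchy path is dominated by occasional large jumps, as in panel (b).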

31.9 Point processes with repulsion and reinforcement


In this section, we look more closely at the Poisson process, as well as other, more general point processes that allow inter-event interactions.


31.9.1 Poisson process


We have already briefly seen the Poisson process: this is a point process on some space Θ that is parametrized by an intensity function λ(θ) ≥ 0, and that produces a Poisson(∫_T λ(θ)dθ)-distributed number of points or events in any set T ⊆ Θ, with the counts in any two disjoint sets independent of each other. Recall that if N_i ~ Poisson(λ_i) are independent, then a well-known property of the Poisson distribution is that N_1 + N_2 ~ Poisson(λ_1 + λ_2). This relates to the infinite divisibility of the Poisson distribution. It is clear that the average number of points is large in areas where λ(θ) is high, and small where λ(θ) is small. When the intensity λ(θ) is some constant λ, we have a homogeneous Poisson process; otherwise we have an inhomogeneous Poisson process.

Depending on whether ∫_Θ λ(θ)dθ is infinite or finite, a Poisson process will either produce an infinite number of points (as with the Poisson process underlying the CRMs of Section 31.7) or a finite number of points. The latter is more common in applications, such as modeling phone calls or financial shocks (Θ is some finite time interval), the locations of trees or forest fires (Θ is some subset of the Euclidean plane), the locations of cells or galaxies (Θ is a 3-dimensional space), or events in higher-dimensional spaces (for example, spatio-temporal activity). One way to simulate a finite Poisson process is to first simulate the total number of points N, which follows a Poisson(∫_Θ λ(θ)dθ) distribution. One can then simulate the locations of these points by sampling N times independently from the probability density λ(θ)/∫_Θ λ(θ′)dθ′. For a homogeneous Poisson process, the locations are uniformly distributed over Θ. One can also easily simulate a rate-λ homogeneous Poisson process on the real line by exploiting the fact that inter-event times follow an exponential distribution with mean 1/λ. If the integral ∫_Θ λ(θ)dθ is difficult to evaluate, one can also simulate a Poisson process using the
31.9. POINT PROCESSES WITH REPULSION AND REINFORCEMENT




Figure 31.11: A realization of (a) a homogeneous Poisson process; (b) an underdispersed point process (Swedish pine sapling locations [Rip05]); (c) an overdispersed point process (California redwood tree locations [Rip77]).


thinning theorem [LS79]. Here, one needs to find a function γ(θ) such that γ(θ) ≥ λ(θ) for all θ, and such that it is easy to simulate from a rate-γ(θ) Poisson process. Suppose the result is Y = {y_1, ..., y_N}. Since γ(θ) ≥ λ(θ), Y is going to contain more events, and one thins Y by keeping each element y_i in Y with probability λ(y_i)/γ(y_i) (otherwise one discards it). One can show that the set of surviving points is then a realization of a Poisson process with rate λ(θ).
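The three simulation strategies above can be sketched for a Poisson process on a time interval [0, T]: the two-stage method, exponential inter-event times, and Lewis and Shedler's thinning. In this sketch the dominating rate γ(θ) is taken to be a constant upper bound `lam_max`, which is the simplest valid choice; the function names are ours.

```python
import numpy as np

def two_stage_pp(rate, T, rng):
    """Homogeneous process on [0, T]: N ~ Poisson(rate*T), then uniform locations."""
    n = rng.poisson(rate * T)
    return np.sort(rng.uniform(0.0, T, size=n))

def homogeneous_pp(rate, T, rng):
    """Homogeneous process on [0, T] via exponential inter-event times (mean 1/rate)."""
    times = []
    t = rng.exponential(1.0 / rate)
    while t < T:
        times.append(t)
        t += rng.exponential(1.0 / rate)
    return np.array(times)

def thinned_pp(lam_fn, lam_max, T, rng):
    """Inhomogeneous process on [0, T] by thinning.

    Assumes lam_fn(t) <= lam_max on [0, T]. Candidates come from a rate-lam_max
    homogeneous process; each candidate y is kept with probability lam_fn(y)/lam_max.
    """
    cand = homogeneous_pp(lam_max, T, rng)
    keep = rng.uniform(size=cand.shape) < lam_fn(cand) / lam_max
    return cand[keep]
```

For example, with λ(t) = 5(1 + sin t) on [0, 2π] and `lam_max = 10`, about half the candidates survive, and the expected number of events is ∫λ = 10π.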

The defining feature of the Poisson process is the assumption of independence among events. In many settings, this is inappropriate and unrealistic, and the knowledge of an event at some location θ might suggest a reduced or elevated probability of events in neighboring areas. Point processes satisfying the former property are called underdispersed (Figure 31.11(b)), while point processes satisfying the latter property are called overdispersed (Figure 31.11(c)). Examples of underdispersed point processes include the locations of trees or train stations, which tend to be more spread out than Poisson because of limited resources. An example of an overdispersed point process is earthquake locations, where aftershocks tend to occur in the vicinity of the main shock.

A simple approach to modeling overdispersed point processes is through hierarchical extensions of the Poisson process that allow the Poisson intensity function λ(·) to be a random variable. Such models are called doubly-stochastic Poisson processes or Cox processes [CI80], and a common approach is to model the intensity λ(·) via a Gaussian process. Note, though, that the Poisson process intensity function must be nonnegative, so λ is often a transformed Gaussian process:


$$\lambda(\theta) = g(f(\theta)), \qquad f(\cdot) \sim \mathrm{GP}. \tag{31.56}$$


Common examples for g include exponentiation, a sigmoid transformation, or simply thresholding. Because of the smoothness of the unknown λ(·), observing an event at some location suggests that events are likely in the neighborhood. Such models still do not capture direct interactions between point process events; rather, these are mediated through the unknown intensity function, making them inappropriate for many applications. For instance, in neuroscience a neuron's spiking can be driven directly by its own past activity and by the activity of other neurons in the network, rather than just through some shared stimulus λ(t). Similarly, social media activity has a strong reciprocal component where, for instance, emails sent out by a user might be in response to past activity, or to the activity of other users. The next few subsections show how one might explicitly model such interactions.
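A log-Gaussian Cox process (taking g = exp in Equation (31.56)) can be simulated approximately by drawing the GP on a grid and then thinning. In the sketch below, the squared-exponential kernel, the grid discretization (the intensity is held constant within each cell), the small diagonal jitter, and all parameter names are our own assumptions; the final step reuses thinning with a constant dominating rate.

```python
import numpy as np

def sample_lgcp(T, n_grid, length_scale, amp, mean, rng):
    """Approximate log-Gaussian Cox process on [0, T].

    f ~ GP(mean, amp^2 * exp(-d^2 / (2 * length_scale^2))) is drawn at the grid
    cell midpoints, lambda = exp(f) is held constant within each cell, and events
    are obtained by thinning a homogeneous process with rate max(lambda).
    """
    edges = np.linspace(0.0, T, n_grid + 1)
    mids = 0.5 * (edges[:-1] + edges[1:])
    d = mids[:, None] - mids[None, :]
    K = amp ** 2 * np.exp(-0.5 * (d / length_scale) ** 2)
    f = rng.multivariate_normal(np.full(n_grid, mean), K + 1e-8 * np.eye(n_grid))
    lam = np.exp(f)                        # g = exp keeps the intensity positive
    lam_max = lam.max()
    n = rng.poisson(lam_max * T)           # dominating homogeneous process
    cand = np.sort(rng.uniform(0.0, T, size=n))
    cell = np.minimum((cand / T * n_grid).astype(int), n_grid - 1)
    keep = rng.uniform(size=n) < lam[cell] / lam_max
    return cand[keep], mids, lam
```

Because the random intensity is smooth, events from one draw tend to bunch where f happens to be high, which is the overdispersion described above.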


Figure 31.12: (a) A self-exciting Hawkes process. Each event causes a rise in the event rate, which can lead to burstiness. (b) A pair of mutually exciting Hawkes processes; events from each cause a rise in the event rate of the other.

31.9.2 Renewal process


Renewal processes are one class of models of repulsion and reinforcement for point processes defined on the real line (typically regarded as time). Recall that for a homogeneous Poisson process, inter-event times follow an exponential distribution. The exponential distribution has the property of memorylessness: the time until the next event is independent of how long ago the last event occurred. That is, if τ follows an exponential distribution, then for any δ, Δ > 0, we have p(τ > Δ + δ | τ > δ) = p(τ > Δ). Renewal processes incorporate memory by allowing the inter-event times to follow some general distribution on the positive reals. Examples include Gamma renewal processes and Weibull renewal processes, where inter-event times follow the Gamma and Weibull distributions, respectively. Both these processes include parameter settings that recover the Poisson process, but also allow burstiness and refractoriness. Burstiness refers to the phenomenon where, just after an event has occurred, one is more likely to see more events than a longer time afterwards. This is useful for modeling email activity, for instance. Refractoriness refers to the opposite situation, where an event occurring makes new events temporarily less likely. This is useful to model neural spiking activity, for instance, where after spiking, a neuron is depleted of resources and requires a recovery period before it can fire again.
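A Gamma renewal process is simple to simulate by accumulating i.i.d. Gamma gaps. In this sketch (our own parameterization by the mean gap), a shape of 1 recovers exponential gaps and hence a Poisson process. Shapes below 1 give a coefficient of variation of the gaps, 1/√shape, above 1 (bursty), and shapes above 1 give a value below 1 (refractory, more regular than Poisson). A Weibull renewal process would swap `rng.weibull` in for the gap distribution.

```python
import numpy as np

def gamma_renewal(shape, mean_gap, T, rng):
    """Event times on [0, T] with i.i.d. Gamma(shape, mean_gap/shape) gaps.

    shape = 1: exponential gaps (Poisson process); shape < 1: bursty;
    shape > 1: refractory.
    """
    scale = mean_gap / shape
    times, t = [], rng.gamma(shape, scale)
    while t < T:
        times.append(t)
        t += rng.gamma(shape, scale)
    return np.array(times)
```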

31.9.3 Hawkes process


Hawkes processes [Haw71] are another class of reinforcing point processes that have attracted much recent attention. Hawkes processes provide an intuitive framework for modeling reinforcement in point processes, through self-excitation when a single point process is involved, and through mutual excitation (sometimes called reciprocity) when collections of point processes are under study. The former, shown in Figure 31.12(a), is relevant when modeling bursty phenomena like visits to a hospital by an individual, or purchases and sales of a particular stock. The latter, displayed in Figure 31.12(b), is useful to characterize activity on social or biological networks, such as email communications or neuronal spiking activity. In both examples, each event serves as a trigger for subsequent bursts of activity. This is achieved by letting λ(t), the event rate at time t, be a function
of past activity or event history. Write H(t) = {t_k : t_k < t} for the set of event times up to time t. Then the rate at time t, called the conditional intensity function at that time, is given by


$$\lambda(t \mid H(t)) = \nu + \sum_{t_k \in H(t)} \phi(t - t_k), \tag{31.57}$$


where ν is called the base rate and the function φ(·) is called the triggering kernel. The latter characterizes the excitatory effect of a past event on the current event rate. Figure 31.12 shows the common situation where φ(·) is an exponential kernel φ(Δ) = β e^{−Δ/τ}, Δ > 0. Here, a new event causes a jump of magnitude β in the intensity λ(t), with its excitatory influence decaying exponentially back to ν, with τ the time-scale of the decay. For a multivariate Hawkes process, there are m point processes (N_1(t), N_2(t), ..., N_m(t)) associated with m users or nodes. Write H_i(t) for the event history of the ith process at time t. Its conditional intensity can depend on all m event histories, taking the form


$$\lambda_i(t \mid \{H_j(t)\}_{j=1}^{m}) = \nu_i + \sum_{j=1}^{m}\sum_{t_k \in H_j(t)} \phi_{ij}(t - t_k). \tag{31.58}$$


More typically, the conditional intensity of user i depends only on the event histories of those nodes 'connected' to it, with connectivity specified by a graph structure, and with φ_ij(·) = 0 if there is no edge linking i and j. Alternatively, events can have marks indicating whom they are sent to (e.g., recipients of an email), with each event only updating the conditional intensities of its recipients.

Simulating from a Hawkes process is a fairly straightforward extension of simulating from an inhomogeneous Poisson process. Consider the univariate Hawkes process, and suppose the last event occurred at time t_i. Then, until the next event occurs, the history H(t) = {t_1, ..., t_i} stays fixed for any time t > t_i, so that the conditional intensity λ(t | H(t)) is a deterministic function of t. The next event time, t_{i+1}, is then just the time of the first event of a Poisson process with this intensity on the interval (t_i, ∞). With most choices of the kernel φ, t_{i+1} can easily be simulated, either by integrating the intensity or by Poisson thinning. At time t_{i+1}, the history is updated to incorporate the new event, and the conditional intensity experiences a jump. The updated intensity is used to simulate the next event at some time t_{i+2} > t_{i+1} from the corresponding Poisson process, and the process is repeated until the end of the observation interval. For multivariate Hawkes processes, one has a collection of competing intensities λ_i(t | {H_j(t)}_{j=1}^{m}). The next event is the first event among all events produced by these intensities, after which the intensities are updated and the process is repeated. A realization Y = {t_1, ..., t_n} from a Hawkes process on an observation interval [0, T] has log-likelihood


$$\ell(Y) = \sum_{i=1}^{n} \log \lambda(t_i \mid H(t_i)) - \int_0^T \lambda(t \mid H(t))\,dt. \tag{31.59}$$


This can typically be evaluated quite easily, so that maximum likelihood estimates of parameters like 
the base-rate as well as parameters of the excitation kernel can be obtained straightforwardly. 
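For the exponential kernel φ(Δ) = β e^{−Δ/τ}, both simulation and the log-likelihood (31.59) have simple forms, because the intensity only decays between events and its integral is available in closed form: ∫_0^T λ(t | H(t)) dt = νT + βτ Σ_i (1 − e^{−(T − t_i)/τ}). The NumPy sketch below (function names are ours; the simulator is an Ogata-style thinning scheme, one standard choice) uses the current intensity as an upper bound until the next accepted event. For stationarity the branching ratio βτ must be below 1, in which case the long-run event rate is ν/(1 − βτ).

```python
import numpy as np

def hawkes_intensity(t, events, nu, beta, tau):
    """lambda(t | H(t)) for the kernel phi(d) = beta * exp(-d / tau)."""
    past = events[events < t]
    return nu + beta * np.sum(np.exp(-(t - past) / tau))

def simulate_hawkes(nu, beta, tau, T, rng):
    """Univariate Hawkes process on [0, T] by thinning."""
    events = []
    t = 0.0
    while True:
        ev = np.array(events)
        # The intensity just after t (including any event at t) bounds it until
        # the next event, since the exponential kernel only decays in between.
        lam_bar = nu + beta * np.sum(np.exp(-(t - ev) / tau))
        t += rng.exponential(1.0 / lam_bar)
        if t >= T:
            break
        if rng.uniform() * lam_bar <= hawkes_intensity(t, ev, nu, beta, tau):
            events.append(t)          # accept: the intensity jumps by beta
    return np.array(events)

def hawkes_loglik(events, nu, beta, tau, T):
    """Log-likelihood (31.59) with the closed-form compensator."""
    lam = np.array([hawkes_intensity(t, events, nu, beta, tau) for t in events])
    compensator = nu * T + beta * tau * np.sum(1.0 - np.exp(-(T - events) / tau))
    return np.sum(np.log(lam)) - compensator
```

Maximizing `hawkes_loglik` over (ν, β, τ) with any numerical optimizer gives the maximum likelihood estimates mentioned above.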

The Hawkes process as described is a fairly simple model, and there have been a number of 
extensions enriching its structure. An early example is [BBH12], where the authors considered the 
multivariate Hawkes process, now with an underlying clustering structure. Instead of each individual 
point process having its own conditional intensity function, each cluster has an intensity function 
shared by all point processes assigned to it. The interaction kernels are also defined at the cluster level, with an event in process i causing a jump φ_{c_i c}(·) in the intensity function of cluster c (where c_i is the cluster that point process i belongs to). The cluster structure was modeled through a Dirichlet process, allowing the authors to learn the underlying clustering, as well as the inter-cluster interaction kernels, from interaction data. In [Tan+16], the authors considered marked point processes, where event i at time t_i has some associated content y_i (for example, each event is a social media post, and y_i is the text associated with the post at time t_i). The authors allowed the jump in the conditional intensity to depend on the associated mark, with the Hawkes kernel taking the form φ(Δ) = f(y_i) exp(−Δ/τ). In that work, the authors modeled the function f with a Gaussian process, though other approaches, such as ones based on neural networks, are possible.

The neural Hawkes process [ME17] is a more fleshed-out approach to modeling point processes using neural networks. It models event intensities through the state of a continuous-time LSTM, a modification of the more standard discrete-time LSTMs from Section 16.3.4. Central to an LSTM is a memory cell c_i that stores long-term memory, summarizing the past until time step i. Continuous-time LSTMs include two long-term memory cells, c_i and c̄_i, summarizing the history until the ith Hawkes event. The first cell represents the starting value to which the memory jumps after the ith event, and the second represents a baseline value after the ith event. These are both updated after each event, with c(t), the memory at any intermediate time (from which the instantaneous rate is computed), determined by c_i decaying exponentially towards the baseline c̄_i. This mechanism allows the intensity at any time to be influenced not just by the number of events in the past, but also by the waiting times between them. It can be extended to marked point processes, where each event is also associated with a mark y: now both a learned embedding of the mark and the time since the last event are used to update the state and long-term memory. For more details, see also Du et al. [Du+16].
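The continuous-time memory dynamics can be sketched in a few lines. Following the description above, c(t) decays from c_i towards c̄_i. The per-unit decay rate δ_i, the read-out h(t) = o_i ⊙ tanh(c(t)), and the plain softplus link (which keeps the intensity positive) are illustrative assumptions in the spirit of [ME17], not a faithful reimplementation; all names here are hypothetical, and a real model would compute these quantities with learned weights.

```python
import numpy as np

def ct_lstm_cell(t, t_i, c_i, c_bar_i, delta_i):
    """Memory c(t) for t > t_i: exponential decay from c_i towards c_bar_i."""
    return c_bar_i + (c_i - c_bar_i) * np.exp(-delta_i * (t - t_i))

def intensity_from_cell(c_t, o_i, w):
    """Hypothetical read-out: h(t) = o_i * tanh(c(t)), lambda(t) = softplus(w . h(t))."""
    h = o_i * np.tanh(c_t)
    return np.log1p(np.exp(w @ h))

c_i, c_bar = np.array([2.0, -1.0]), np.array([0.5, 0.5])
delta = np.array([1.0, 3.0])
lam = intensity_from_cell(ct_lstm_cell(0.3, 0.0, c_i, c_bar, delta),
                          np.array([0.9, 0.4]), np.array([1.0, -2.0]))
```

Because each unit of c(t) decays at its own rate, the intensity can rise or fall between events depending on the elapsed time, which is how waiting times influence the rate.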

31.9.4 Gibbs point process

Gibbs point processes [MW07] from the statistical physics and spatial statistics literature provide a general framework for modeling interacting point processes on higher-dimensional spaces. Such spaces are more challenging than the real line, since there is now no ordering of points, and thus no natural notion of history affecting future activity. Instead, Gibbs point processes use an energy function E to quantify deviations from a Poisson process with rate 1. Specifically, under a Gibbs process, the probability density of any configuration Y with respect to a rate-1 Poisson process takes the form

$$P_\beta(Y) = \frac{1}{Z_\beta}\exp(-\beta E(Y)), \tag{31.60}$$

where β is the inverse-temperature parameter, and Z_β = E_π[exp(−βE(Y))] is the normalization constant (the expectation is with respect to π, the unit-rate Poisson process). Under some conditions on the energy function E, the expectation Z_β is finite, and P_β is a well-defined density, whose integral with respect to π is 1. While the above equation resembles a Markov random field, the domain of Y is much more complicated, and evaluating Z_β now involves solving an infinite-dimensional integral. Equation (31.60) states that configurations Y for which E(Y) is small are more likely than under a Poisson process, with β controlling how peaked this is. The most common energy functions are pairwise potentials, taking the form


46 E= D> gls- s'l), (31.61) 
46 (s,s/)EU 
47 
31.9. POINT PROCESSES WITH REPULSION AND REINFORCEMENT


where the summation is over all pairs of events in $\Psi$, and $\phi: \mathbb{R}^+ \to \mathbb{R} \cup \{\infty\}$. The Strauss process is a specific example, with energy function specified by positive parameters $\alpha$ and $R$ as


$$E(\Psi) = \sum_{s, s' \in \Psi} \alpha\, \mathbb{I}(\|s - s'\| < R) \tag{31.62}$$


This is a repulsive process that penalizes configurations with events separated by a distance less than $R$, and as $\alpha \to \infty$ it becomes a hardcore repulsive process, forbidding configurations with two points separated by less than $R$. More generally, the energy function can be piecewise-constant, parametrized by a collection of pairs $(\alpha_1, R_1), \ldots, (\alpha_n, R_n)$:


$$E(\Psi) = \sum_{s, s' \in \Psi} \sum_{i=1}^{n} \alpha_i\, \mathbb{I}(\|s - s'\| < R_i) \tag{31.63}$$
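These pairwise energies are straightforward to evaluate for a finite configuration. The sketch below assumes points are stored as an $(N, d)$ array and that each unordered pair $\{s, s'\}$ is counted once (summing over ordered pairs instead would simply double the energy, which is equivalent to doubling every $\alpha_i$). The function names are hypothetical.

```python
import numpy as np
from itertools import combinations

def piecewise_pair_energy(points, alphas, radii):
    """E(Psi) = sum over unordered pairs of sum_i alpha_i * 1(||s - s'|| < R_i)."""
    points = np.asarray(points, dtype=float)
    alphas = np.asarray(alphas, dtype=float)
    radii = np.asarray(radii, dtype=float)
    energy = 0.0
    for a, b in combinations(range(len(points)), 2):  # each pair counted once (assumption)
        dist = np.linalg.norm(points[a] - points[b])
        energy += np.sum(alphas * (dist < radii))     # add alpha_i for every R_i > dist
    return energy

def strauss_energy(points, alpha, R):
    # The Strauss process is the single-step special case (n = 1).
    return piecewise_pair_energy(points, [alpha], [R])

def unnormalized_log_density(points, energy_fn, beta=1.0):
    # log P_beta(Psi) up to the intractable additive constant -log Z_beta.
    return -beta * energy_fn(points)
```

For example, with points at $(0, 0)$, $(0.5, 0)$ and $(3, 0)$, only the first pair is closer than $R = 1$, so the Strauss energy with $\alpha = 2$ is $2$.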


Another natural option is to use smooth functions like a squared exponential kernel. Gibbs point 
processes can also involve higher-order interactions, examples being Geyer’s triplet point process 
(which penalizes occurrences of 3 events that are all within some distance R), or area-interaction 
point processes (that center disks of radius R on each event, calculate the area of the union of these 
disks, and define E as some function of this area). 

While Gibbs point processes are flexible and interpretable point process models, the intractable normalization constant $Z_\beta$ makes estimating parameters like $\alpha$ a formidable challenge. In practice, these models have to be fit using approximate approaches such as maximizing a pseudo-likelihood function (instead of Equation (31.60)).
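One standard choice, Besag's pseudo-likelihood, replaces the density with the Papangelou conditional intensity $\lambda(u \mid \Psi) = P_\beta(\Psi \cup \{u\}) / P_\beta(\Psi)$, in which $Z_\beta$ cancels. For the Strauss energy with each unordered pair counted once, $\log \lambda(u \mid \Psi) = -\beta\alpha \cdot \#\{s \in \Psi : \|u - s\| < R\}$, and the log pseudo-likelihood is $\sum_{s \in \Psi} \log \lambda(s \mid \Psi \setminus \{s\}) - \int_W \lambda(u \mid \Psi)\, du$ over an observation window $W$. The sketch below is a rough illustration under stated assumptions: $W$ is the unit square, the integral uses a midpoint rule on a grid, and $\beta$ and $\alpha$ are merged into one parameter $\theta = \beta\alpha$, since only their product enters the Strauss model. The function names are hypothetical.

```python
import numpy as np

def strauss_log_cond_intensity(u, others, theta, R):
    # log lambda(u | others) = -theta * #{s in others : ||u - s|| < R}
    if len(others) == 0:
        return 0.0
    dists = np.linalg.norm(np.asarray(others) - u, axis=1)
    return -theta * np.sum(dists < R)

def strauss_log_pseudolikelihood(points, theta, R, grid_size=50):
    """Log pseudo-likelihood on the window [0, 1]^2 (assumed for illustration)."""
    points = np.asarray(points, dtype=float).reshape(-1, 2)
    # Data term: each observed point, conditioned on all the others.
    data_term = sum(
        strauss_log_cond_intensity(points[i], np.delete(points, i, axis=0), theta, R)
        for i in range(len(points))
    )
    # Midpoint-rule approximation of the integral of lambda(u | Psi) over the window.
    centers = (np.arange(grid_size) + 0.5) / grid_size
    gx, gy = np.meshgrid(centers, centers)
    grid = np.column_stack([gx.ravel(), gy.ravel()])
    lam = np.exp([strauss_log_cond_intensity(u, points, theta, R) for u in grid])
    integral = lam.mean()  # unit-area window: area * mean over grid cells
    return data_term - integral
```

Maximizing this quantity over a grid of $\theta$ values, for a fixed interaction radius $R$, gives a simple pseudo-likelihood estimate without ever computing $Z_\beta$.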


31.9.5 Determinantal point process 


Determinantal point processes or DPPs are another approach to modeling repulsion, and have seen considerable popularity in the machine learning literature (see e.g., [KT+12; Mac75; Bor; LMR15]). Like any point process, a DPP is a probability distribution over subsets of a fixed set
S, and in the DPP literature this is often called the ground set. Point process applications typically 
have S with uncountably infinite cardinality (for example, S could be the real line), although machine 
learning applications of DPPs often focus on S with a finite number of elements. For instance, S 
could be a database of images, a collection of news articles, or a group of individuals. A sample 
from a DPP is a random subset of S, produced, for instance, in response to a search query. The 
repulsive nature of DPPs ensures diversity and parsimony in the returned subset. This could be 
useful, for example, to ensure responses to a search query are not minor variations of the same image. 
Another application is clustering, where a DPP serves as a prior over the number of clusters and their 
locations, with the repulsiveness discouraging redundant clusters that are very similar to each other. 

DPPs are parametrized by a similarity kernel $K$, whose element $K_{ij}$ gives the similarity between elements $i$ and $j$ of the ground set. For finite ground sets, $K$ is just a similarity matrix. DPPs require $K$ to be positive semidefinite, with largest eigenvalue at most 1, that is $0 \preceq K \preceq I$. For simplicity, $K$ is often assumed symmetric, though this is not necessary. Given a kernel $K$, the associated DPP is defined as follows: if $Y$ is the random subset of $S$ drawn from the DPP, then the probability that $Y$ contains any subset $A$ of $S$ is given by

$$p(A \subseteq Y) = \det(K_A) \tag{31.64}$$
Figure 31.13: The determinant of a matrix V is the volume of the parallelogram spanned by the columns of V and the origin.
Here $K_A$ is the submatrix obtained by restricting $K$ to rows and columns in $A$, and we define $\det(K_\emptyset) = 1$ for the empty set $\emptyset$. Observe that this probability is specified exactly, and not just up to a normalization constant. We immediately see that $Y$ contains the empty set with probability one (this is trivially true since the empty set is a subset of every set). We also see that the probability that element $i$ is selected in $Y$ is the $i$th diagonal element of $K$:

$$p(i \in Y) = K_{ii} \tag{31.65}$$

so that the diagonal of $K$ gives the inclusion probabilities of the individual elements in $S$. More interestingly, the probability that a pair of elements $\{i, j\}$ are both contained in $Y$ is given by

$$p(\{i, j\} \subseteq Y) = \det\begin{pmatrix} K_{ii} & K_{ij} \\ K_{ji} & K_{jj} \end{pmatrix} = K_{ii}K_{jj} - K_{ij}K_{ji} \tag{31.66}$$


The first term $K_{ii}K_{jj}$ is the probability of including $i$ and $j$ if they were independent, and this probability is adjusted by subtracting $K_{ij}K_{ji}$ (which equals $K_{ij}^2$ for symmetric $K$), a measure of similarity between $i$ and $j$. This is the source of repulsiveness or diversity in DPPs. More generally, the determinant of a matrix formed from a set of vectors is the volume of the parallelogram spanned by them and the origin (see Figure 31.13), and by making the probability of inclusion proportional to the volume, DPPs encourage diversity. We mention for completeness that for uncountable ground sets like a Euclidean space, the determinant $\det(K_A)$ now gives the product density function of a realization $A$. This generalizes the intensity of a Poisson process to account for interactions between point process events, and we refer the interested reader to Lavancier, Møller, and Rubak [LMR15] for more details.
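The marginal formula (31.64) and its singleton and pair special cases can be checked numerically. The sketch below builds a hypothetical symmetric kernel with eigenvalues in $[0, 1]$ from a random orthogonal basis, evaluates $p(A \subseteq Y) = \det(K_A)$, and confirms that a pair probability never exceeds the product of the corresponding singleton probabilities, which is the repulsion described above. The function names are hypothetical.

```python
import numpy as np

def dpp_marginal_prob(K, A):
    """p(A subset of Y) = det(K_A) for a DPP with marginal kernel K."""
    A = list(A)
    if not A:
        return 1.0  # det of the empty matrix is defined as 1
    return float(np.linalg.det(K[np.ix_(A, A)]))

def random_marginal_kernel(n, seed=0):
    # Hypothetical symmetric K = Q diag(lambda) Q^T with eigenvalues in [0, 1].
    rng = np.random.default_rng(seed)
    Q, _ = np.linalg.qr(rng.normal(size=(n, n)))
    eigvals = rng.uniform(0.0, 1.0, size=n)
    return (Q * eigvals) @ Q.T
```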
Equation (31.64) characterizes the marginal probabilities of a determinantal point process: what is the probability that a realization $Y$ contains some subset of $S$. More often, one is interested in the probability that the realization $Y$ equals some subset of $S$. The latter is of interest for simulation, inference and parameter learning with DPPs. For this, it is typical to work with a narrower class of DPPs called L-ensembles [Bor; KT+12]. Like general DPPs, such a process is characterized by a positive semidefinite matrix $L$, with the probability that $Y$ equals some configuration $A$ given by:

$$p(Y = A) \propto \det(L_A) \tag{31.67}$$

Note that this probability is specified only up to a normalization constant, and so we do not need to upper-bound the eigenvalues of $L$. In fact, since $\sum_{A \subseteq S} \det(L_A) = \det(L + I)$, the normalizer is $\det(L + I)$, so that

$$p(Y = A) = \frac{\det(L_A)}{\det(L + I)} \tag{31.68}$$


A similar calculation can be used to show that $p(A \subseteq Y) = \det(K_A)$, where $K = L(I + L)^{-1} = I - (I + L)^{-1}$, showing that L-ensembles are indeed a special kind of DPP. Equation (31.68) and Equation (31.64) allow parameters of the DPP to be estimated from realizations of a point process, typically by gradient descent. Without additional structure, naively calculating determinants is cubic in the cardinality of $S$; this represents a substantial saving compared with the exponential number of possible subsets of $S$. When even cubic scaling is too expensive, a number of approximation approaches can be adopted, and these are often closely related to approaches for reducing the cubic cost of Gaussian processes.
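For a small ground set, both the normalizer and the relation between $L$ and $K$ can be verified by brute-force enumeration of all $2^{|S|}$ subsets. The sketch below does this for a hypothetical random positive semidefinite $L$; the enumeration is exponential and serves only as a check on the closed forms.

```python
import numpy as np
from itertools import chain, combinations

def all_subsets(n):
    # Every subset of {0, ..., n-1}, including the empty set (2**n in total).
    return chain.from_iterable(combinations(range(n), r) for r in range(n + 1))

def l_ensemble_prob(L, A):
    """p(Y = A) = det(L_A) / det(L + I) for an L-ensemble with kernel L."""
    A = list(A)
    num = np.linalg.det(L[np.ix_(A, A)]) if A else 1.0  # det of empty matrix is 1
    return num / np.linalg.det(L + np.eye(L.shape[0]))

def marginal_kernel(L):
    """Marginal kernel K = L (I + L)^{-1} = I - (I + L)^{-1}."""
    n = L.shape[0]
    return np.eye(n) - np.linalg.inv(np.eye(n) + L)

# A hypothetical positive semidefinite L-ensemble kernel on 4 items.
rng = np.random.default_rng(0)
B = rng.normal(size=(4, 4))
L = B @ B.T
```

Summing `l_ensemble_prob` over all supersets of a set $A$ recovers $\det(K_A)$, which is the "similar calculation" mentioned above carried out numerically.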

So far, we have only discussed how to calculate the probability of samples from a DPP. Simulating 
from a DPP, while straightforward, is a bit less intuitive, and we refer the reader to [LMR15; KT+12]. 
At a high level, these approaches use the eigenstructure of $K$ to express a DPP as a mixture of determinantal projection point processes. The latter are DPPs whose similarity kernel has binary eigenvalues (either 0 or 1) and are easier to sample from. Observe that any eigenvalue $\lambda_i$ of the similarity kernel $K$ must lie in the interval $[0, 1]$. This allows us to generate a random similarity kernel $\tilde{K}$ with the same eigenvectors as $K$ but with binary eigenvalues as follows: for each $i$, replace eigenvalue $\lambda_i$ of $K$ with a binary variable $\tilde{\lambda}_i$ simulated from a $\text{Bernoulli}(\lambda_i)$ distribution. One can show that a sample from the DPP with kernel $\tilde{K}$, where $\tilde{K}$ is first simulated from $K$ in this fashion, is distributed exactly as a sample from the DPP with kernel $K$. We refer the reader to [LMR15] for details on how to simulate a determinantal projection point process with similarity kernel $\tilde{K}$.
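The two-stage procedure can be sketched in NumPy for a symmetric $K$. Stage one draws the Bernoulli eigenvalue indicators; stage two samples the resulting projection DPP with a standard sequential algorithm of the kind presented in [KT+12], which repeatedly picks an item with probability proportional to its squared row norm in the current eigenvector basis and then restricts the basis to vectors that vanish on that item. Treat it as an illustrative implementation rather than an optimized one; the function name is hypothetical.

```python
import numpy as np

def sample_dpp(K, rng):
    """Draw one sample from a DPP with symmetric marginal kernel K (0 <= K <= I)."""
    n = K.shape[0]
    eigvals, eigvecs = np.linalg.eigh(K)
    eigvals = np.clip(eigvals, 0.0, 1.0)  # guard against round-off outside [0, 1]
    # Stage 1: replace each eigenvalue lambda_i by a Bernoulli(lambda_i) draw.
    keep = rng.random(n) < eigvals
    V = eigvecs[:, keep]
    sample = []
    # Stage 2: sample the projection DPP spanned by the kept eigenvectors.
    while V.shape[1] > 0:
        probs = np.sum(V ** 2, axis=1)
        probs /= probs.sum()
        i = rng.choice(n, p=probs)
        sample.append(int(i))
        # Zero out row i: subtract multiples of a column with V[i, j] != 0,
        # which also zeroes that column, then drop it.
        j = np.argmax(np.abs(V[i]))
        Vj = V[:, j].copy()
        V = V - np.outer(Vj, V[i] / Vj[i])
        V = np.delete(V, j, axis=1)
        if V.shape[1] > 0:
            V, _ = np.linalg.qr(V)  # re-orthonormalize the remaining basis
    return sorted(sample)
```

The sample size equals the number of eigenvectors kept in stage one, so for a projection kernel it is always the rank of $K$.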
32 Representation learning


This chapter was written by Ben Poole and Simon Kornblith. 


32.1 Introduction 


Representation learning is a paradigm for training machine learning models to transform raw 
inputs into a form that makes it easier to solve new tasks. Unlike supervised learning, where the 
task is known at training time, representation learning often assumes that we do not know what task 
we wish to solve ahead of time. Without this knowledge, are there transformations of the input we 
can learn that are useful for a variety of tasks we might care about? 

One point of evidence that representation learning is possible comes from us. Humans can rapidly 
form rich representations of new classes [LST15] that can support diverse behaviors: finding more 
instances of that class, decomposing that instance into parts, and generating new instances from that 
class. However, it is hard to directly specify what representations we would like our machine learning 
systems to learn. We may want a representation to make it easy to solve new tasks with small amounts of data, we
may want to construct novel inputs or answer questions about similarities between inputs, and we 
may want the representation to encode certain information while discarding other information. 

In building methods for representation learning, the goal is to design a task whose solution requires 
learning an improved representation of the input instead of directly specifying what the representation 
should do. These tasks can vary greatly, from building generative models of the dataset to learning 
to cluster datapoints. Different methods often involve different assumptions on the dataset, different 
kinds of data, and induce different biases on the learned representation. In this chapter we first 
discuss methods for evaluating learned representations, then approaches for learning representations 
based on supervised learning, generative modeling, and self-supervised learning, and finally the theory 
behind when representation learning is possible. 


32.2 Evaluating and comparing learned representations 


How can we make sense of representations learned by different neural networks, or of the differences 
between representations learned in different layers of the same network? Although it is tempting 
to imagine representations of neural networks as points in a space, this space is high-dimensional. 
In order to determine the quality of representations and how different representations differ, we 
need ways to summarize these high-dimensional representations or their relationships with a few relevant scalars. Doing so requires making assumptions about what structure in the representations is important.

Figure 32.1: Representation learning transforms input data (left) where data from different classes (color) are mixed together to a representation (right) where attributes like class are more easily distinguished. Generated by vib_demo.ipynb.

32.2.1 Downstream performance


The most common way to evaluate the quality of a representation is to adapt it to one or more downstream tasks thought to be representative of real-world scenarios. In principle, one could choose any strategy to adapt the representation, but a small number of adaptation strategies dominate the literature. We discuss these strategies below.

Clearly, downstream performance can only differ from pretraining performance if the downstream task is different from the pretraining task. Downstream tasks can differ from the pretraining task in their input distributions, target distributions, or both. The downstream tasks used to evaluate unsupervised or self-supervised representation learning often involve the same distribution of inputs as the pretraining task, but require predicting targets that were not provided during pretraining. For example, in self-supervised visual representation learning, representations learned on the ImageNet dataset without using the accompanying labels are evaluated on ImageNet using labels, either by performing linear evaluation with all the data or by fine-tuning using subsets of the data. By contrast, in transfer learning (Section 19.5.1), the input distribution of the downstream task differs from the distribution of the pretraining task. For example, we might pretrain the representation on a large variety of natural images and then ask how the representation performs at distinguishing different species of birds not seen during pretraining.

32.2.1.1 Linear classifiers and linear evaluation


Linear evaluation treats the trained neural network as a fixed feature extractor and trains a linear classifier on top of fixed features extracted from a chosen network layer. In earlier work, this linear classifier was often a support vector machine [Don+14; SR+14; Cha+14], but in more recent work, it is typically an $L^2$-regularized multinomial logistic regression classifier [ZIE17; KSL19; KZB19]. The process of training this classifier is equivalent to attaching a new layer to the chosen
32.2. EVALUATING AND COMPARING LEARNED REPRESENTATIONS


representation layer and training only this new layer, with the rest of the network’s weights frozen 
and any normalization/regularization layers set to “inference mode” (see Figure 32.2). 

Although linear classifiers are conceptually simple compared to deep neural networks, they are 
not necessarily simple to train. Unlike deep neural network training, objectives associated with 
commonly-used linear classifiers are convex and thus it is possible to find global minima, but it can 
be challenging to do so with stochastic gradient methods. When using SGD, it is important to 
tune both the learning rate schedule and weight decay. Even with careful tuning, SGD may still 
require substantially more epochs to converge when training the classifier than when training the 
original neural network [KZB19]. Nonetheless, linear evaluation with SGD remains a commonly used 
approach in the representation learning literature. 

When it is possible to maintain all features in memory simultaneously, it is possible to use full-batch optimization methods with line search, such as L-BFGS, in place of SGD [KSL19; Rad+21]. These
optimization methods ensure that the loss decreases at every iteration of training, and thus do not 
require manual tuning of learning rates. To obtain maximal accuracy, it is still important to tune 
the amount of regularization, but this can be done efficiently by sweeping this hyperparameter and 
using the optimal weights for the previous value of the hyperparameter as a warm start. Using a 
full-batch optimizer typically implies precomputing the features before performing the optimization, 
rather than recomputing features on each minibatch. Precomputing features can save a substantial 
amount of computation, since the forward passes through the frozen model are typically much more 
expensive than computing the gradient of the linear classifier. However, precomputing features also 
limits the number of augmentations of each example that can be considered. 
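A minimal sketch of full-batch linear evaluation on precomputed features follows, assuming SciPy's L-BFGS implementation (`scipy.optimize.minimize` with `method="L-BFGS-B"`) and an $L^2$-regularized multinomial logistic regression whose bias is left unpenalized. The sweep goes from strong to weak regularization and warm-starts each fit from the previous solution. Helper names and the regularization grid are hypothetical, and in practice the regularization strength would be chosen on held-out data rather than on the training set.

```python
import numpy as np
from scipy.optimize import minimize
from scipy.special import logsumexp

def fit_linear_probe(feats, labels, n_classes, l2, w0=None):
    """Full-batch L2-regularized multinomial logistic regression via L-BFGS.

    feats: (n, d) precomputed features from the frozen network; labels: (n,) ints.
    Returns a (d + 1, n_classes) weight matrix whose last row is the bias.
    """
    n, d = feats.shape
    Xb = np.hstack([feats, np.ones((n, 1))])  # append a bias column
    Y = np.eye(n_classes)[labels]             # one-hot targets

    def loss_and_grad(w_flat):
        W = w_flat.reshape(d + 1, n_classes)
        logits = Xb @ W
        log_probs = logits - logsumexp(logits, axis=1, keepdims=True)
        loss = -np.sum(Y * log_probs) / n + 0.5 * l2 * np.sum(W[:-1] ** 2)
        grad = Xb.T @ (np.exp(log_probs) - Y) / n
        grad[:-1] += l2 * W[:-1]              # the bias row is not regularized
        return loss, grad.ravel()

    if w0 is None:
        w0 = np.zeros((d + 1) * n_classes)
    res = minimize(loss_and_grad, w0, jac=True, method="L-BFGS-B")
    return res.x.reshape(d + 1, n_classes)

def sweep_l2(feats, labels, n_classes, l2_grid):
    # Strong to weak regularization, each fit warm-started from the previous one.
    weights, w = {}, None
    for l2 in sorted(l2_grid, reverse=True):
        W = fit_linear_probe(feats, labels, n_classes, l2, w0=w)
        weights[l2] = W
        w = W.ravel()
    return weights

def accuracy(W, feats, labels):
    Xb = np.hstack([feats, np.ones((len(feats), 1))])
    return np.mean(np.argmax(Xb @ W, axis=1) == labels)
```

Because the features are computed once and reused, each L-BFGS iteration costs only a pass through the linear layer, which is what makes the regularization sweep cheap.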

It is important to keep in mind that the accuracy obtainable by training a linear classifier on a finite 
dataset is only a lower bound on the accuracy of the Bayes-optimal linear classifier. The datasets 
used for linear evaluation are often small relative to the number of parameters to be trained, and 
the classifier typically needs to be regularized to obtain maximal accuracy. Thus, linear evaluation 
accuracy depends not only on whether it is possible to linearly separate different classes in the 
representation, but also on how much data is required to find a good decision boundary with a 
given training objective and regularizer. In practice, even an invertible linear transformation of a 
representation can affect linear evaluation accuracy. 


32.2.1.2 Fine-tuning 


It is also possible to adapt all layers from the pretraining task to the downstream task. This process 
is typically referred to as fine-tuning [HS06b; Gir+14]. In its simplest form, fine-tuning, like 
linear evaluation, involves attaching a new layer to a chosen representation layer, but unlike linear 
evaluation, all network parameters, and not simply those of the new layer, are updated according to 
gradients computed on the downstream task. The new layer may be initialized with zeros or using 
the solution obtained by training it with all other parameters frozen. Typically, the best results are 
obtained when the network is fine-tuned at a lower learning rate than was used for pretraining. 
Fine-tuning is substantially more expensive than training a linear classifier on top of fixed feature 
representations, since each training step requires backpropagating through multiple layers. However, 
fine-tuned networks typically outperform linear classifiers, especially when the pretraining and 
downstream tasks are very different [KSL19; AGM14; Cha+14; Azi+15]. Linear classifiers perform 
better only when the number of training examples is very small (~5 per class) [KSL19]. 
Fine-tuning can also involve adding several new network layers. For detection and segmentation 
tasks, which require fine-grained knowledge of spatial position, it is common to add a feature pyramid
network (FPN) [Lin+17b] that incorporates information from different feature maps in the pretrained 
network. Alternatively, new layers can be interspersed between old layers and initialized to preserve 
the network’s output. Net2Net [CGS15] follows this approach to construct a higher-capacity network 
that makes use of representations contained in the pretrained weights of a smaller network, whereas 
adapter modules [Hou+19] incorporate new, parameter-efficient modules into a pretrained network 
and freeze the old ones to reduce the number of parameters that need to be stored when adapting 
models to different tasks. 




32.2.1.3 Disentanglement 


Given knowledge about how a dataset was generated, for example that there are certain factors of 
variation such as position, shape, and color that generated the data, we often wish to estimate how 
well we can recover those factors in our learned representation. This requires using disentangled 
representation learning methods (see Section 21.3.1.1). While there are a variety of metrics for 
disentanglement, most measure to what extent there is a one-to-one correspondence between latent 
factors and dimensions of the learned representation.

32.2.2 Representational similarity


Rather than measure compatibility between a representation and a downstream task, we might seek to directly examine relationships between two fixed representations without reference to a task. In this section, we assume that we have two sets of fixed representations corresponding to the same $n$ examples. These representations could be extracted from different layers of the same network or layers of different neural networks, and need not have the same dimensionality. For notational convenience, we assume that each set of representations has been stacked row-wise to form matrices $X \in \mathbb{R}^{n \times p_1}$ and $Y \in \mathbb{R}^{n \times p_2}$ such that $X_{i,:}$ and $Y_{i,:}$ are two different representations of the same example.


32.2.2.1 Representational similarity analysis and centered kernel alignment


Representational similarity analysis (RSA) is the dominant technique for measuring similarity of representations in neuroscience [KMB08], but has also been applied in machine learning. RSA reduces the problem of measuring similarity between representation matrices to measuring the similarities between representations of individual examples. RSA begins by forming representational similarity matrices (RSMs) from each representation. Given functions $k: \mathcal{X} \times \mathcal{X} \to \mathbb{R}$ and $k': \mathcal{Y} \times \mathcal{Y} \to \mathbb{R}$ that measure the similarity between pairs of representations of individual examples $x, x' \in \mathcal{X}$ and $y, y' \in \mathcal{Y}$, the corresponding representational similarity matrices $K, K' \in \mathbb{R}^{n \times n}$ contain the similarities between the representations of all pairs of examples, $K_{ij} = k(X_{i,:}, X_{j,:})$ and $K'_{ij} = k'(Y_{i,:}, Y_{j,:})$. These representational similarity matrices are transformed into a scalar similarity value by applying a matrix similarity function $s(K, K')$.

The RSA framework can encompass many different forms of similarity through the selection of the similarity functions $k(\cdot, \cdot)$, $k'(\cdot, \cdot)$, and $s(\cdot, \cdot)$. How these functions should be selected is a contentious topic [BS+20; Kri19]. In practice, it is common to choose $k(x, x') = k'(x, x') = \text{corr}[x, x']$, the Pearson correlation coefficient between examples. $s(\cdot, \cdot)$ is often chosen to be the Spearman rank correlation between the representational similarity matrices, which is computed by reshaping $K$ and $K'$ to vectors, computing the rankings of their elements, and measuring the Pearson correlation between these rankings.
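A minimal sketch of the common RSA recipe just described: Pearson-correlation RSMs between examples, compared by the Spearman rank correlation of the flattened matrices. Following the text, the full matrices are flattened; many analyses instead keep only the off-diagonal upper triangle, which avoids the constant diagonal and the duplicated entries. It assumes SciPy's `scipy.stats.spearmanr`, and the function names are hypothetical.

```python
import numpy as np
from scipy.stats import spearmanr

def rsm(reps):
    # reps: (n, p). np.corrcoef treats rows as variables, so this returns the
    # n x n matrix of Pearson correlations between examples.
    return np.corrcoef(reps)

def rsa(X, Y):
    """Spearman correlation between the flattened RSMs of X and Y."""
    rho, _ = spearmanr(rsm(X).ravel(), rsm(Y).ravel())
    return rho
```

Because Pearson correlation ignores per-example shifts and positive rescalings, `rsa(X, 3 * X + 1)` equals 1 up to tie-breaking in the ranks.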

Centered kernel alignment (CKA) is a technique that was first proposed in machine learning 
literature [Cri+02; CMR12] but that can be interpreted as a form of RSA. In centered kernel 
alignment, the per-example similarity functions k and k’ are chosen to be positive semi-definite 
kernels so that K and K’ are kernel matrices. The matrix similarity function s is the cosine similarity 
between centered kernel matrices 

$$s(K, K') = \frac{\langle HKH, HK'H \rangle_F}{\|HKH\|_F \, \|HK'H\|_F} \tag{32.1}$$

where $\langle A, B \rangle_F = \text{vec}(A)^\top \text{vec}(B) = \text{tr}(A^\top B)$ is the Frobenius product, and $H = I - \frac{1}{n}\mathbf{1}\mathbf{1}^\top$ is the centering matrix. As it is applied above, the centering matrix subtracts the means from the rows and columns of the kernel matrices.

A special case of centered kernel alignment arises when $k$ and $k'$ are chosen to be the linear kernel $k(x, x') = k'(x, x') = x^\top x'$. In this case, $K = XX^\top$ and $K' = YY^\top$, allowing for an alternative expression for CKA in terms of the similarities between pairs of features rather than pairs of examples. The representations themselves must first be centered by subtracting the means from their columns, yielding $\tilde{X} = HX$ and $\tilde{Y} = HY$. Then, so-called linear centered kernel alignment is given by

$$\text{CKA}_{\text{linear}}(X, Y) = \frac{\|\tilde{Y}^\top \tilde{X}\|_F^2}{\|\tilde{X}^\top \tilde{X}\|_F \, \|\tilde{Y}^\top \tilde{Y}\|_F} \tag{32.2}$$


Linear centered kernel alignment is equivalent to the RV coefficient [RE76] between centered features, 
as shown in [Kor+19]. 
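The sketch below implements the kernel form (32.1) and the feature-space linear form (32.2) of CKA in NumPy; for the linear kernel the two should agree up to floating-point error. The function names are hypothetical, and `cka_from_kernels` accepts any pair of kernel matrices, for example RBF kernels.

```python
import numpy as np

def centering_matrix(n):
    return np.eye(n) - np.ones((n, n)) / n

def cka_from_kernels(K, K2):
    """Cosine similarity between double-centered kernel matrices (Eq. 32.1)."""
    H = centering_matrix(K.shape[0])
    Kc, K2c = H @ K @ H, H @ K2 @ H
    # np.sum(A * B) is the Frobenius inner product; np.linalg.norm is Frobenius.
    return np.sum(Kc * K2c) / (np.linalg.norm(Kc) * np.linalg.norm(K2c))

def linear_cka(X, Y):
    """Feature-space form of linear CKA (Eq. 32.2)."""
    Xc = X - X.mean(axis=0)  # column centering, i.e. HX
    Yc = Y - Y.mean(axis=0)
    num = np.linalg.norm(Yc.T @ Xc) ** 2
    return num / (np.linalg.norm(Xc.T @ Xc) * np.linalg.norm(Yc.T @ Yc))
```

The feature-space form never builds an $n \times n$ matrix, so it is the cheaper choice when $n$ is much larger than $p_1$ and $p_2$.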


32.2.2.2 Canonical correlation analysis and related methods 


Given two datasets (in this case, the representation matrices $X$ and $Y$), canonical correlation analysis or CCA (Section 28.3.4.3) seeks to map both datasets to a shared latent space such that they are maximally correlated in this space. The $i$th pair of canonical weights $(w_i, w'_i)$ maximizes the correlation between the corresponding canonical vectors, $\rho_i = \text{corr}[\tilde{X}w_i, \tilde{Y}w'_i]$, subject to the constraint that the new canonical vectors are orthogonal to previous canonical vectors:

$$\begin{aligned} \text{maximize} \quad & \text{corr}[\tilde{X}w_i, \tilde{Y}w'_i] \\ \text{subject to} \quad & \forall_{j<i}\ \tilde{X}w_i \perp \tilde{X}w_j \\ & \forall_{j<i}\ \tilde{Y}w'_i \perp \tilde{Y}w'_j \\ & \|\tilde{X}w_i\| = \|\tilde{Y}w'_i\| = 1. \end{aligned} \tag{32.3}$$

The maximum number of non-zero canonical correlations is the minimum of the ranks, $p = \min(\text{rk}(\tilde{X}), \text{rk}(\tilde{Y}))$.

The standard algorithm for computing the canonical weights and correlations [BG73] first decomposes the individual representations as the product of a matrix with orthonormal columns and a second matrix, $\tilde{X} = QR$ with $Q^\top Q = I$, and $\tilde{Y} = Q'R'$ with $Q'^\top Q' = I$. These decompositions can be obtained either by QR factorization or singular value decomposition. A second singular value decomposition, $Q^\top Q' = U\Sigma V^\top$, provides the quantities needed to determine the canonical weights and correlations.
Specifically, the canonical correlations are the singular values $\rho = \text{diag}(\Sigma)$; the canonical vectors are the columns of $\tilde{X}W = QU$ and $\tilde{Y}W' = Q'V$; and the canonical weights are $W = R^{-1}U$ and $W' = R'^{-1}V$.

Two common strategies to turn the resulting vector of canonical correlations into a scalar are to take the mean squared canonical correlation,

$$R^2_{\text{CCA}}(X, Y) = \|\rho\|_2^2 / p = \|Q^\top Q'\|_F^2 / p, \tag{32.4}$$

or the mean canonical correlation,

$$\bar{\rho} = \sum_{i=1}^{p} \rho_i / p = \|Q^\top Q'\|_* / p, \tag{32.5}$$

where $\|\cdot\|_*$ denotes the nuclear norm. The mean squared canonical correlation has several alternative names, including Yanai’s GCD [Yan74; RBS84], Pillai’s trace, or the eigenspace overlap score [May+19].

Although CCA is a powerful tool, it suffers from the curse of dimensionality. If at least one representation has more neurons than examples and each neuron is linearly independent of the others, then all canonical correlations are 1. In practice, because neural network representations are high-dimensional, we can find ourselves in the regime where there are not enough data to accurately estimate the canonical correlations. Moreover, even when we can accurately estimate the canonical correlations, it may be desirable for a similarity measure to place less emphasis on the similarity of low-variance directions.

Singular vector CCA (SVCCA) mitigates these problems by retaining only the largest principal components of $\tilde{X}$ and $\tilde{Y}$ when performing CCA. Given the singular value decompositions of the representations $\tilde{X} = U\Sigma V^\top$ and $\tilde{Y} = U'\Sigma'V'^\top$, SVCCA retains only the first $k$ columns of $U$ corresponding to the largest $k$ singular values $\sigma = \text{diag}(\Sigma)$ (i.e., the $k$ largest principal components) and the first $k'$ columns of $U'$ corresponding to the largest $k'$ singular values $\sigma' = \text{diag}(\Sigma')$. With these singular value decompositions, the canonical correlations, vectors, and weights can then be computed using the algorithm of Björck and Golub [BG73] described above, by setting

$$Q = U_{:,1:k}, \quad R = \Sigma_{1:k,1:k} V_{:,1:k}^\top, \quad Q' = U'_{:,1:k'}, \quad R' = \Sigma'_{1:k',1:k'} V'^\top_{:,1:k'}. \tag{32.6}$$

Raghu et al. [Rag+17] suggest retaining enough components to explain 99% of the variance, i.e., setting $k$ to the minimum value such that $\|\sigma_{1:k}\|^2 / \|\sigma\|^2 \geq 0.99$ and $k'$ to the minimum value such that $\|\sigma'_{1:k'}\|^2 / \|\sigma'\|^2 \geq 0.99$. However, for a fixed value of $\min(k, k')$, CCA-based similarity measures increase with the value of $\max(k, k')$. Representations that require more components to explain 99% of the variance may thus appear “more similar” to all other representations than representations with more rapidly decaying singular value spectra. In practice, it is often better to set $k$ and $k'$ to the same fixed value.

Projection-weighted CCA (PWCCA) [MRB18] instead weights the correlations by a measure of the variability in the original representation that they explain. The resulting similarity measure is

$$\text{PWCCA}(X, Y) = \frac{\sum_{i=1}^{p} \alpha_i \rho_i}{\sum_{i=1}^{p} \alpha_i}, \qquad \alpha_i = \|(\tilde{Y}w'_i)^\top \tilde{Y}\|_1. \tag{32.7}$$
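A compact NumPy sketch of the QR-then-SVD procedure, returning the canonical correlations along with the mean squared canonical correlation (32.4) and the mean canonical correlation (32.5). The optional SVCCA step keeps the top-$k$ principal components of each centered representation before running CCA. The function names are hypothetical, and the inputs are assumed to be generic matrices with full column rank.

```python
import numpy as np

def cca_correlations(X, Y):
    """Canonical correlations: QR of centered representations, then SVD of Q^T Q'."""
    Xc = X - X.mean(axis=0)
    Yc = Y - Y.mean(axis=0)
    Q, _ = np.linalg.qr(Xc)   # Xc = Q R with Q^T Q = I
    Q2, _ = np.linalg.qr(Yc)  # Yc = Q' R' with Q'^T Q' = I
    rho = np.linalg.svd(Q.T @ Q2, compute_uv=False)
    return np.clip(rho, 0.0, 1.0)

def cca_summaries(X, Y):
    rho = cca_correlations(X, Y)
    p = len(rho)  # min(p1, p2) when n exceeds both widths
    return {"R2_CCA": np.sum(rho ** 2) / p, "mean_rho": np.mean(rho)}

def svcca_reduce(X, k):
    # Top-k principal components of the centered representation, U[:, :k] * s[:k];
    # this spans the same space as Q = U[:, :k].
    Xc = X - X.mean(axis=0)
    U, s, _ = np.linalg.svd(Xc, full_matrices=False)
    return U[:, :k] * s[:k]

def svcca(X, Y, k):
    return cca_summaries(svcca_reduce(X, k), svcca_reduce(Y, k))
```

One quick way to see the curse of dimensionality is to call `cca_correlations` on a representation with more columns than rows: every canonical correlation comes out as 1, whatever the other representation is.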

47 


PWCCA enjoys somewhat widespread use in representational similarity literature, but it has potentially undesirable properties. First, it is asymmetric; its value depends on which of the two representations is used for the weighting. Second, it is unclear why weightings should be determined using the $L^1$ norm, which is not invariant to rotations of the representation. It is arguably more intuitive to weight the correlations directly by the amount of variance in the representation that the canonical component explains, which corresponds to replacing the $L^1$ norm with the squared $L^2$ norm. The resulting similarity measure is


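$$R^2_{\text{PR}}(X, Y) = \frac{\sum_{i=1}^{p} \alpha'_i \rho_i^2}{\sum_{i=1}^{p} \alpha'_i}, \qquad \alpha'_i = \|(\tilde{Y}w'_i)^\top \tilde{Y}\|_2^2. \tag{32.8}$$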


As shown by Kornblith et al. [Kor+19], when $p_1 \geq p_2$, this alternative similarity measure is equivalent to the overall variance explained by using linear regression to fit every neuron in $Y$ using the representation $X$,


$$R^2_{\text{LR}}(X, Y) = 1 - \frac{\sum_{i=1}^{p_2} \min_{\beta} \|\tilde{Y}_{:,i} - \tilde{X}\beta\|^2}{\|\tilde{Y}\|_F^2}. \tag{32.9}$$


Finally, there is also a close relationship between CCA and linear CKA. This relationship can be clarified by writing similarity indexes directly in terms of the singular value decompositions $\tilde{X} = U\Sigma V^\top$ and $\tilde{Y} = U'\Sigma'V'^\top$. The left-singular vectors $u_i = U_{:,i}$ correspond to the principal components of $\tilde{X}$ normalized to unit length, and the squared singular values $\lambda_i = \Sigma_{ii}^2$ are the amount of variance that those principal components explain (up to a factor of $1/n$); $u'_j$ and $\lambda'_j$ are defined analogously for $\tilde{Y}$. Given these singular value decompositions, $R^2_{\text{CCA}}$, $R^2_{\text{LR}}$, and linear CKA become:


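$$R^2_{\text{CCA}}(X, Y) = \frac{1}{p_1}\sum_{i=1}^{p_1} \sum_{j=1}^{p_2} (u_i^\top u'_j)^2 \tag{32.10}$$

$$R^2_{\text{LR}}(X, Y) = \frac{\sum_{i=1}^{p_1} \sum_{j=1}^{p_2} \lambda'_j (u_i^\top u'_j)^2}{\sum_{j=1}^{p_2} \lambda'_j} \tag{32.11}$$

$$\text{CKA}_{\text{linear}}(X, Y) = \frac{\sum_{i=1}^{p_1} \sum_{j=1}^{p_2} \lambda_i \lambda'_j (u_i^\top u'_j)^2}{\sqrt{\sum_{i=1}^{p_1} \lambda_i^2}\,\sqrt{\sum_{j=1}^{p_2} \lambda_j'^2}} \tag{32.12}$$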


Thus, these similarity indexes all involve the similarities between all pairs of principal components 
from X and Y, but place different weightings on these similarities according to the fraction of 
variance that these principal components explain. 
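Equations (32.10) to (32.12) can be checked numerically against the direct definitions. The sketch below computes all three indexes from thin SVDs of the centered representations, and a helper computes $R^2_{\text{LR}}$ directly by least squares for comparison. It assumes $n$ exceeds both widths and that both representations have full column rank, so every principal component has a nonzero singular value; the function names are hypothetical.

```python
import numpy as np

def svd_similarities(X, Y):
    """R^2_CCA, R^2_LR and linear CKA from principal components (Eqs. 32.10-32.12)."""
    Xc, Yc = X - X.mean(axis=0), Y - Y.mean(axis=0)
    U, s, _ = np.linalg.svd(Xc, full_matrices=False)
    U2, s2, _ = np.linalg.svd(Yc, full_matrices=False)
    lam, lam2 = s ** 2, s2 ** 2   # variance explained by each component (times n)
    M = (U.T @ U2) ** 2           # M[i, j] = (u_i^T u'_j)^2
    r2_cca = M.sum() / U.shape[1]
    r2_lr = (M * lam2[None, :]).sum() / lam2.sum()
    cka = (lam[:, None] * M * lam2[None, :]).sum() / (
        np.sqrt((lam ** 2).sum()) * np.sqrt((lam2 ** 2).sum()))
    return r2_cca, r2_lr, cka

def r2_lr_direct(X, Y):
    # Fraction of the variance of Y explained by least-squares regression on X.
    Xc, Yc = X - X.mean(axis=0), Y - Y.mean(axis=0)
    B, *_ = np.linalg.lstsq(Xc, Yc, rcond=None)
    return 1.0 - np.linalg.norm(Yc - Xc @ B) ** 2 / np.linalg.norm(Yc) ** 2
```

Written this way, the three indexes differ only in the weights placed on the same matrix `M` of squared principal-component overlaps.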


32.2.2.3 Comparing representational similarity measures 


What properties are desirable in a representational similarity measure is an open question, and this 
question may not have a unique answer. Whereas evaluations of downstream accuracy approximate 
real-world use cases for neural network representations, the goal of representational similarity is 
instead to develop understanding of how representations evolve across neural networks, or how they 
differ between neural networks with different architectures or training settings. 


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 




One way to taxonomize different similarity measures is through the transformations of a representation
that they are invariant to. The minimum form of invariance is invariance to permutation of a
representation’s constituent neurons, which is needed because neurons in neural networks generally
have no canonical ordering: for commonly used initialization strategies, any permutation of a
given initialization is equiprobable, and nearly all architectures and optimizers produce training
trajectories that are equivariant under permutation. On the other hand, invariance to arbitrary
invertible transformations, as provided by mutual information, is clearly undesirable, since many
realistic neural networks are injective functions of the input [Gol+19], and thus there always exists an
invertible transformation between any pair of representations. In practice, most similarity measures in
common use are invariant to rotations (orthogonal transformations) of representations, which implies
invariance to permutation. Similarity measures based solely on CCA correlations, such as $R^2_{\text{CCA}}$
and the mean CCA correlation $\bar{\rho}$, are invariant to all invertible linear transformations of
representations. However, SVCCA and PWCCA are not.
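The invariance claims above can be checked numerically. The sketch below computes $R^2_{\text{CCA}}$ from orthonormal bases of the centered representations, as in Eq. (32.10), and compares it with linear CKA. The helpers are hypothetical, not a library API. The sketch assumes $n > p_1$, so that $X$ has full column rank. Stretching the features of $X$ by an invertible diagonal matrix leaves $R^2_{\text{CCA}}$ at 1 but lowers linear CKA. An orthogonal transformation leaves linear CKA at 1.

```python
import numpy as np


def _centered_basis(a):
    # Orthonormal basis for the column space of the centered representation.
    a = a - a.mean(axis=0, keepdims=True)
    u, _, _ = np.linalg.svd(a, full_matrices=False)
    return u


def r2_cca(x, y):
    # Eq. (32.10): sum_ij (u_Y^j . u_X^i)^2 / p1, i.e. the mean squared CCA correlation.
    ux, uy = _centered_basis(x), _centered_basis(y)
    return np.sum((uy.T @ ux) ** 2) / x.shape[1]


def linear_cka(x, y):
    # Linear CKA in its Frobenius-norm form, on centered representations.
    x = x - x.mean(axis=0, keepdims=True)
    y = y - y.mean(axis=0, keepdims=True)
    num = np.linalg.norm(y.T @ x, "fro") ** 2
    return num / (np.linalg.norm(x.T @ x, "fro") * np.linalg.norm(y.T @ y, "fro"))


rng = np.random.default_rng(1)
x = rng.normal(size=(100, 10))
a = np.diag(np.arange(1.0, 11.0))  # invertible but not orthogonal
q, _ = np.linalg.qr(rng.normal(size=(10, 10)))  # orthogonal
print(r2_cca(x, x @ a), linear_cka(x, x @ a), linear_cka(x, x @ q))
```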
A different way to distinguish similarity measures is to investigate situations where we know
the relationships among representations and to empirically evaluate their ability to recover these
relationships. Kornblith et al. [Kor+19] propose a simple “sanity check”: given two architecturally
identical networks A and B trained from different random initializations, any layer in network A
should be more similar to the architecturally corresponding layer in network B than to any other
layer. They show that, when considering flattened representations of CNNs, similarity measures
based on centered kernel alignment satisfy this sanity check whereas other similarity measures do not.
By contrast, when considering representations of individual tokens in Transformers, all similarity
measures perform reasonably well. However, Maheswaranathan et al. [Mah+19] show that both
CCA and linear CKA are highly sensitive to seemingly innocuous RNN design choices such as the
activation function, even though analysis of the fixed points of the dynamics of different networks
suggests they all operate similarly.
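Below is a minimal sketch of the sanity check of Kornblith et al. [Kor+19]. It assumes the layer-to-layer similarities have already been computed with some measure, such as linear CKA, and stored in a matrix. Entry $(i, j)$ of that matrix compares layer $i$ of network A with layer $j$ of network B. The function name and this matrix layout are illustrative assumptions. The function reports the fraction of layers in A whose most similar layer in B is the corresponding one, so a measure that passes the check scores 1.

```python
import numpy as np


def sanity_check_accuracy(sim):
    # sim[i, j]: similarity between layer i of network A and layer j of network B.
    # Layer i passes if its most similar layer in B is layer i itself.
    sim = np.asarray(sim, dtype=float)
    best = np.argmax(sim, axis=1)
    return float(np.mean(best == np.arange(sim.shape[0])))


# Toy similarity matrix for three layers: layer 3 of A is most similar to
# layer 2 of B, so the check passes for two of the three layers.
sim = np.array([[0.90, 0.50, 0.20],
                [0.40, 0.80, 0.70],
                [0.10, 0.85, 0.60]])
print(sanity_check_accuracy(sim))
```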

32.3 Approaches for learning representations

The goal of representation learning is to learn a transformation of the inputs that makes it easier to
solve future tasks. Typically, the tasks for which we want the representation to be useful are not known
when learning the representation, so we cannot directly train to improve performance on those tasks.
Learning such generic representations requires collecting large-scale unlabeled or weakly labeled
datasets, and identifying tasks or priors for the representations that one can solve without direct access
to the downstream tasks. Most methods focus on learning a parametric mapping $z = f_\theta(x)$ that takes
an input $x$ and transforms it into a representation $z$ using a neural network with parameters $\theta$.

The main challenge in representation learning is coming up with a task that requires learning
a good representation to solve. If the task is too easy, then it can be solved without learning an
interesting transformation of the inputs, or by learning a shortcut. If a task is too different from
the downstream task that the representation will be evaluated on, then the representation may also
not be useful. For example, if the downstream task is object detection, then the representation
needs to encode both the identity and location of objects in the image. However, if we only care
about classification, then the representation can discard information about position. This leads
to approaches for learning representations that are often not generic: different training tasks may
perform better for different downstream tasks.

32.3. APPROACHES FOR LEARNING REPRESENTATIONS


[Figure 32.2 panels: Supervised; Generative; Self-Supervised. Panel labels include autoencoders, image completion, rotation prediction, and contrastive learning (Schroff et al. 2016).]


Figure 32.2: Approaches for representation learning from images. An input image is encoded through a deep 
neural network (green) to produce a representation (blue). An additional shallow or deep neural network 
(yellow) is often used to train the representation, but is thrown out after the representation is learned when 
solving downstream tasks. In the supervised case, the mapping from the representation to logits is typically 
linear, while for autoencoders the mapping from representation to images can be highly complex and stochastic. 
Unlike supervised or generative approaches, contrastive methods rely on other data points in the form of 
positive pairs (often created through data augmentation) and negative pairs (typically other datapoints) to 
learn a representation. 


In Figure 32.2, we outline the three approaches to representation learning that we will discuss.
Supervised approaches train on large-scale supervised or weakly supervised data using standard
supervised losses. Generative approaches aim to learn generative models of the dataset or parts of a
dataset. Self-supervised approaches are based on transformation prediction or multiview learning,
where we design a task for which labels can be easily synthesized without needing human input.


32.3.1 Supervised Representation Learning and Transfer 


The first major successes in visual representation learning with deep learning came from networks 
trained on large labeled datasets. Following the discovery that supervised deep neural networks could 
outperform classical computer vision models for natural image classification [KSH12b; CMS12], it 
became clear that the representations learned by these networks could outperform handcrafted features 
used across a wide variety of tasks [Don+14; SR+14; Oqu+14; Gir+14]. Although unsupervised 
visual representation learning has recently achieved competitive results on many domains, supervised 
representation learning remains the dominant approach. 

Larger networks trained on larger datasets generally achieve better performance on both pretrain- 
ing and downstream tasks. When other design choices are held fixed, architectures that achieve 
higher accuracy during pretraining on natural image datasets such as ImageNet also learn better 
representations for downstream natural image tasks, as measured by both linear evaluation and 
fine-tuned accuracy [KSL19; TL19; Zha+19a; Zha+21a; Abn+21]. However, when the domain shift
from the pretraining task to the downstream task becomes larger (e.g., from ImageNet to medical
imaging), the correlation between pretraining and downstream accuracy can be much lower [Rag+19;
Ke+21; Abn+21]. Studies that vary pretraining dataset size generally find that larger pretraining
datasets yield better representations for downstream tasks [HAE16; Mah+18; Kol+20; Zha+21a;
Abn+21], although there is an interaction between model size and dataset size. When training small
models with the intention of transferring to a specific downstream task, it is sometimes preferable to
pretrain on a smaller dataset that is more closely related to that task rather than a larger dataset
that is less closely related [Cui+18; Mah+18; Ngi+18; Kol+20], but larger models seem to derive
greater benefit from larger, more diverse datasets [Mah+18; Kol+20].

Whereas scaling the architecture and dataset size generally improves both pretraining and downstream
accuracy, other design choices can improve pretraining accuracy at the expense of transfer, or
vice versa. Regularizers such as penultimate layer dropout and label smoothing improve accuracy
on pretraining tasks but produce worse representations for downstream tasks [KSL19; Kor+21].
Although most convolutional neural networks are trained with batch normalization, Kolesnikov
et al. [Kol+20] find that the combination of group normalization and weight standardization leads
to networks that perform similarly on pretraining tasks but substantially better on transfer tasks.
Adversarial training produces networks that perform worse on pretraining tasks as compared to
standard training, but these representations transfer better to other tasks [Sal+20]. For certain
combinations of pretraining and downstream datasets, increasing the amount of weight decay on
the network’s final layer can improve transferability at the cost of pretraining accuracy [Zha+21a; …]

The challenge of collecting ever-larger pretraining datasets has led to the emergence of weakly
supervised representation learning, which eschews the expensive human annotations of datasets
such as ImageNet and instead relies on data that can be readily collected from the Internet, but which
may have greater label noise. Supervision sources include hashtags accompanying images on websites
such as Instagram and Flickr [CG15; Iza+15; Jou+16; Mah+18], image labels obtained automatically
using proprietary algorithms involving user feedback signals [Sun+17; Kol+20], or image captions/alt
text [Li+17a; SPL20; DJ21; Rad+21; Jia+21]. Hashtags and automatic labeling give rise to image
classification problems that closely resemble their more strongly supervised counterparts. The primary
difference versus standard supervised representation learning is that the data are noisier, but in
practice, the benefits of more data often outweigh the detrimental effects of the noise.

Image-text supervision has provided more fertile ground for innovation, as there are many
different ways of jointly processing text and images. The simplest approach is again to convert the
data into an image classification problem, where the network is trained to predict which words or
n-grams appear in the text accompanying a given image [Li+17a]. More sophisticated approaches
train image-conditional language models [DJ21] or masked language models [SPL20], which can make
better use of the structure of the text. Recently, there has been a surge in interest in contrastive
image/text pretraining models such as CLIP [Rad+21] and ALIGN [Jia+21], details of which we
discuss in Section 32.3.4. These models process images and text independently using two separate
“towers,” and learn an embedding space where embeddings of images lie close to the embeddings of
the corresponding text. As shown by Radford et al. [Rad+21], contrastive image/text pretraining
learns high-quality representations faster than alternative approaches.
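The CLIP paper [Rad+21] describes its training objective as a symmetric cross-entropy over the matrix of image-text similarities within a batch. The NumPy sketch below follows that structure under two stated assumptions. First, the two towers are replaced by precomputed embedding arrays whose row $i$ forms a matched image/text pair. Second, the temperature is fixed at an illustrative value, whereas CLIP learns it. The function names are hypothetical.

```python
import numpy as np


def log_softmax(logits, axis):
    # Numerically stable log-softmax along the given axis.
    shifted = logits - logits.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


def clip_style_loss(img_emb, txt_emb, temperature=0.07):
    # L2-normalize so that dot products are cosine similarities.
    img = img_emb / np.linalg.norm(img_emb, axis=1, keepdims=True)
    txt = txt_emb / np.linalg.norm(txt_emb, axis=1, keepdims=True)
    logits = img @ txt.T / temperature  # logits[i, j]: image i vs. text j
    diag = np.arange(logits.shape[0])
    # Matched pairs lie on the diagonal. Pick the right text for each image
    # (softmax over rows) and the right image for each text (softmax over
    # columns), then average the two cross-entropies.
    loss_img_to_txt = -log_softmax(logits, axis=1)[diag, diag].mean()
    loss_txt_to_img = -log_softmax(logits, axis=0)[diag, diag].mean()
    return 0.5 * (loss_img_to_txt + loss_txt_to_img)


rng = np.random.default_rng(0)
img_emb = rng.normal(size=(8, 16))
txt_emb = img_emb + 0.01 * rng.normal(size=(8, 16))  # nearly matched pairs
print(clip_style_loss(img_emb, txt_emb))        # small: pairs are aligned
print(clip_style_loss(img_emb, txt_emb[::-1]))  # larger: pairs are mismatched
```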

Beyond simply learning good visual representations, pretrained models that embed image and text
in a common space enable zero-shot transfer of learned representations. In zero-shot transfer, an
image classifier is constructed using only textual descriptions of the classes of interest, without any
images from the downstream task. Early co-embedding models relied on pretrained image models and
word embeddings that were then adapted to a common space [Fro+13], but contrastive image/text
pretraining provides a means to learn co-embedding models end-to-end. Compared to linear classifiers
trained using image embeddings, zero-shot classifiers typically perform worse, but zero-shot classifiers
are far more robust to distribution shift [Rad+21].
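The sketch below shows how a zero-shot classifier can be assembled from a co-embedding model. Each class is represented by the embedding of a textual description, for instance a prompt such as “a photo of a dog” (the prompt wording is an assumption). An image is assigned to the class with the most similar text embedding. The arrays are stand-ins for the outputs of the image and text towers of a model such as CLIP, and `zero_shot_classify` is a hypothetical helper.

```python
import numpy as np


def zero_shot_classify(image_embs, class_text_embs):
    # Normalize so that dot products are cosine similarities.
    img = image_embs / np.linalg.norm(image_embs, axis=1, keepdims=True)
    txt = class_text_embs / np.linalg.norm(class_text_embs, axis=1, keepdims=True)
    # The class text embeddings act as the weights of a linear classifier
    # that was built without any images from the downstream task.
    scores = img @ txt.T
    return np.argmax(scores, axis=1)


# Stand-in embeddings: three class descriptions and three images.
class_text_embs = np.eye(3)
image_embs = np.array([[0.9, 0.1, 0.0],
                       [0.0, 0.2, 3.0],
                       [0.1, 2.0, 0.3]])
print(zero_shot_classify(image_embs, class_text_embs))
```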


32.3.2 Generative Representation Learning 


Supervised representation learning often fails to learn representations for tasks that differ significantly 
from the task the representation was trained on. How can we learn representations when the task we 
wish to solve differs a lot from tasks where we have large labeled datasets? 

Generative representation learning aims to model the entire distribution of a dataset $q(x)$
with a parametric model $p_\theta(x)$. The hope of generative representation learning is that if we can build
models that can create all the data that we have seen, then we may implicitly learn a representation
that can be used to answer any question about the data, not just the questions that are related to a
supervised task for which we have labels. For example, in the case of digit classification, it is hard
to collect labels for the style of a handwritten digit, but if the model has to produce all possible
handwritten digits in our dataset, it needs to learn to produce digits with different styles. On the
other hand, supervised learning to classify digits aims to learn a representation that is invariant to
style.

There are two main approaches for learning representations with generative models: (1) latent- 
variable models that aim to capture the underlying factors of variation in data with latent variables z 
that act as the representation (see the chapter on VAEs, Chapter 21), and (2) fully-observed models 
where a neural architecture is trained with a tractable generative objective (see the chapters on AR 
models, Chapter 22, and flow models, Chapter 23), and then a representation is extracted from the 
learned architecture. 


32.3.2.1 Latent-variable models 


One criterion for learning a good representation of the world is that it is useful for synthesizing 
observed data. If we can build a model that can create new observations, and has a simple set of 
latent variables, then hopefully this model will learn variables that are related to the underlying 
physical process that created the observations. For example, if we are trying to model a dataset of 
2d images of shapes, knowing the position, size, and type of the shape would enable easy synthesis 
of the image. This approach to learning is known as analysis-by-synthesis, and is a theory of 
perception that aims at identifying a set of underlying latent factors (analysis) that could be used to 
synthesize observations [Rob63; Bau74; LM03]. Our goal is to learn a generative model $p_\theta(x, z)$ over
the observations $x$ and latents $z$, with parameters $\theta$. Given an observation $x$, performing the analysis
step to extract a representation requires running inference to sample or compute the posterior mean
of $p_\theta(z \mid x)$. Different choices for the model $p_\theta(x, z)$ and inference procedure for $p_\theta(z \mid x)$ represent
different ways of learning representations from a dataset.

Early work on deep latent-variable generative models aimed to learn stacks of features, often based
on training simple energy-based models or directed sparse coding models, each of which could explain
the previous set of latent factors, and which learned increasingly abstract representations [HOT06b;
Lee+09; Ran+06]. Bengio, Courville, and Vincent [BCV13] provide an overview of several methods
based on stacking latent-variable generative modeling approaches to learn increasingly abstract
representations. However, greedy approaches to generative representation learning have failed to scale
to larger natural datasets.

If the generative process that created the data is simple and can be described, then encoding 
that structure into a generative model is a tremendously powerful way of learning useful and robust 
representations. Lake, Salakhutdinov, and Tenenbaum [LST15] and George et al. [Geo+17] use 
knowledge of how characters are composed of strokes to build hierarchical generative models with 
representations that excel at several downstream tasks. However, for many real-world datasets the
generative structure is not known, and the generative model must also be learned. There is often a
tradeoff between imposing structure in the generative process (such as sparsity) vs. learning that
structure from data.

Directed latent-variable generative models have proven easier to train and scale to natural datasets.
Variational autoencoders (Chapter 21) train a directed latent-variable generative model with variational
inference, and learn a prior $p_\theta(z)$, decoder $p_\theta(x \mid z)$, and an amortized inference network $q_\phi(z \mid x)$
that can be used to extract a representation on new datapoints. Higgins et al. [Hig+17b] show that
β-VAEs (Section 21.3.1) are capable of learning latent variables that correspond to factors of variation
on simple synthetic datasets. Kingma et al. [Kin+14b] and Rasmus et al. [Ras+15] demonstrate
improved performance on semi-supervised learning with VAEs. While there have been several recent
advances to scale up VAEs to natural datasets [VIK20b; Chi21b], none of these methods have yet led
to representations that are competitive for downstream tasks such as classification or segmentation.

Adversarial methods for training directed latent-variable models have also proven useful for
representation learning. In particular, GANs (Chapter 26) trained with encoders such as BiGAN
[DKD17], ALI [Dum+17], and [Che+16] were able to learn representations on small-scale datasets
that performed well at object classification. The discriminators from GANs have also proven useful
for learning representations [RMC16b]. More recently, these methods were scaled up to ImageNet
in BigBiGAN [DS19], with learned representations that performed strongly on classification and
segmentation tasks.

32.3.2.2 Fully observed models

The neural network architectures used in fully observed generative models can also learn useful
representations without the presence of latent variables. ImageGPT [Che+20a] demonstrates that
an autoregressive model trained on pixels can learn internal representations that excel at image
classification. Unlike latent-variable models, where the representation is often thought of as the
latent variables, ImageGPT extracted representations from the deterministic layers of the transformer
architecture used to compute future tokens. Similar approaches have shown progress for learning
features in language modeling [Raf+20b]; however, alternative objectives based on masked training
(as in BERT [Dev+19]) often lead to better performance.


32.3.2.3 Autoencoders
A related set of methods for representation learning is based on learning a representation from
which the original data can be reconstructed. These methods are often called autoencoders (see
Section 16.3.3), as the data is encoded in a way such that the input data itself can be recreated.
However, unlike generative models, they cannot typically be used to synthesize observations from
scratch or assign likelihoods to observations. Autoencoders learn an encoder that outputs a
representation $z = f_\theta(x)$, and a decoder $g_\phi(z)$ that takes the representation $z$ and tries to recreate the
input data, $x$. The quality of the approximate reconstruction $\hat{x} = g_\phi(z)$ is often measured using a
domain-specific loss, for example mean-squared error for images:


$$\mathcal{L}(\theta, \phi) = \sum_{x \in \mathcal{D}} \left\| x - g_\phi(f_\theta(x)) \right\|_2^2 \tag{32.13}$$


If there are no constraints on the encoder or decoder, and the dimensionality of the representation
$z$ matches the dimensionality of the input $x$, then there exists a trivial solution to minimize the
autoencoding objective: set both $f_\theta$ and $g_\phi$ to identity functions. In this case the representation has
not learned anything interesting, and thus in practice an additional regularizer is often placed on the
learned representation.

Reducing the dimensionality of the representation z is one effective mechanism to avoid trivial 
solutions to the autoencoding objective. If both the encoder and decoder networks are linear, and 
the loss is mean-squared-error, then the resulting linear autoencoder model can learn the principal 
components of a dataset [Pla18]. 
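The connection between linear autoencoders and PCA can be checked numerically. In the sketch below, the encoder is $f_\theta(x) = W^\top x$ and the decoder is $g_\phi(z) = W z$. Instead of training $W$ by gradient descent, the sketch sets it in closed form to the top-$k$ principal directions. The helper names are our own, and we assume centered data. The resulting value of Eq. (32.13) equals the sum of the discarded squared singular values, and no other rank-$k$ projection, including a random one, does better. A trained linear autoencoder spans the same principal subspace, although its weights need not equal the individual principal directions.

```python
import numpy as np


def reconstruction_loss(x, encode, decode):
    # Eq. (32.13): sum over the dataset of squared reconstruction errors.
    return float(np.sum((x - decode(encode(x))) ** 2))


def pca_autoencoder(x, k):
    # Linear encoder/decoder built from the top-k right singular vectors of the
    # centered data matrix, i.e. the top-k principal directions.
    _, _, vt = np.linalg.svd(x, full_matrices=False)
    w = vt[:k].T  # shape (d, k), orthonormal columns
    return (lambda a: a @ w), (lambda z: z @ w.T)


rng = np.random.default_rng(0)
x = rng.normal(size=(200, 5)) * np.array([5.0, 3.0, 1.0, 0.5, 0.1])
x = x - x.mean(axis=0)  # center the data

enc, dec = pca_autoencoder(x, k=2)
pca_loss = reconstruction_loss(x, enc, dec)

# A random rank-2 linear autoencoder (orthonormal W) for comparison.
w_rand, _ = np.linalg.qr(rng.normal(size=(5, 2)))
rand_loss = reconstruction_loss(x, lambda a: a @ w_rand, lambda z: z @ w_rand.T)

# Without a bottleneck, the identity map is the trivial zero-loss solution.
identity_loss = reconstruction_loss(x, lambda a: a, lambda z: z)
print(pca_loss, rand_loss, identity_loss)
```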

Other methods maintain higher-dimensional representations by adding sparsity penalties (for example,
penalties on $\|z\|_1$ in Ng et al. [Ng+11]) or smoothness regularizers [Rif+11], or by adding noise to the
input [Vin+08] or intermediate layers of the network [Sri+14b; PSDG14]. These added regularizers
aim to bias the encoder and decoder to learn representations that are not just the identity function,
but instead are nonlinear transformations of the input that may be useful for downstream tasks. See
Bengio, Courville, and Vincent [BCV13] for a more detailed discussion of regularized autoencoders
and their applications. A recent re-evaluation of several algorithms based on iteratively learning
features with stacked regularized autoencoders showed that this approach degrades performance relative
to training end-to-end from scratch [Pai+14]. However, we will see in Section 32.3.3.1 that denoising
autoencoders have shown promise for representation learning in discrete domains and when applied
with more complex noise and masking patterns.


32.3.2.4 Challenges in generative representation learning 


Despite several successes in generative representation learning, generative approaches have empirically
fallen behind other approaches. Generative methods for representation learning have to learn to match
complex, high-dimensional, and diverse training datasets, which requires modeling all axes of variation
of the inputs, regardless of whether they are semantically relevant for downstream tasks. For example,
the exact pattern of blades of grass in an image matters for generation quality, but is unlikely to be
useful for many of the semantic evaluations that are typically used. How to bias generative models to
focus on the semantic features and ignore “noise” in the input is an open area of research.


32.3.3 Self-Supervised Representation Learning 


When given large amounts of labeled data, standard supervised learning is a powerful mechanism
for training deep neural networks. When only presented with unlabeled data, building generative
models requires modeling all variations in a dataset, and is often not explicit about which aspects of
the data are the signal that we aim to capture in a representation and which are noise. The methods
and architectures for building these generative models also differ substantially from those of supervised
learning, where largely feedforward architectures are used to predict low-dimensional representations.
Instead of trying to model all aspects of variation, self-supervised learning aims to design tasks whose
labels can be generated cheaply and that help to encode the structure of what we may care about for
other downstream tasks. Self-supervised learning methods allow us to apply the tools and techniques of
supervised learning to unlabeled data by designing a task for which we can cheaply produce labels.

In the image domain, several self-supervised tasks, also known as pretext tasks, have proven
effective for learning representations. Models are trained to perform these tasks in a supervised
fashion using data generated by the pretext task, and then the learned representation is transferred
to a target task of interest (such as object recognition) by training a linear classifier or fine-tuning
the model in a supervised fashion.

32.3.3.1 Denoising and masked prediction


Generative representation learning is challenging because generative models must learn to produce
the entire data distribution. A simpler option is denoising, in which some variety of noise is added to
the input and the model is trained to reconstruct the noiseless input. A particularly successful variant
of denoising is masked prediction, in which input patches or tokens are replaced with uninformative
masks and the network is trained to predict only these missing patches or tokens.

The denoising autoencoder [Vin+08; Vin+10a] was the first deep model to exploit denoising for
representation learning. A denoising autoencoder resembles a standard autoencoder architecturally,
but it is trained to perform a different task. Whereas a standard autoencoder attempts to reconstruct
its input exactly, a denoising autoencoder attempts to produce a noiseless output from a noisy input.
Vincent et al. [Vin+08] argue that the network must learn the structure of the data manifold in order
to solve the denoising task.
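To illustrate the difference from a standard autoencoder, the sketch below forms a denoising training signal. The model receives a corrupted copy of $x$, but its reconstruction is compared against the clean $x$. The corruption is masking noise, which zeroes a random fraction of input components. This is one corruption type used in this line of work. The corruption fraction, the mean-squared-error loss, and the stand-in `model` callable are illustrative assumptions. Note that the identity map, a trivial solution for a standard autoencoder, no longer achieves zero loss.

```python
import numpy as np


def mask_corrupt(x, frac, rng):
    # Masking noise: zero out each input component independently with probability frac.
    keep = rng.random(x.shape) >= frac
    return x * keep


def denoising_loss(model, x, frac, rng):
    # The model sees the corrupted input but is scored against the clean input.
    x_noisy = mask_corrupt(x, frac, rng)
    return float(np.mean((model(x_noisy) - x) ** 2))


rng = np.random.default_rng(0)
x = rng.random((100, 20)) + 0.5  # strictly positive, so zeros come only from masking
noisy = mask_corrupt(x, 0.25, rng)
print(np.mean(noisy == 0))  # roughly 0.25
# The identity map copies the corruption, so its denoising loss is nonzero.
print(denoising_loss(lambda a: a, x, 0.25, rng))
```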

Newer approaches retain the conceptual approach of the denoising autoencoder, but adjust the
masking strategy and objective. BERT [Dev+18] introduced the masked language modeling
task, where 15% of the input tokens are selected for masking and the network is trained to predict
them. 80% of the time, these tokens are replaced with an uninformative [MASK] token. However, the
[MASK] token does not appear at fine-tuning time, producing some domain shift between pretraining
and fine-tuning. Thus, 10% of the time, tokens are replaced with random tokens, and 10% of the
time, they are left intact. BERT and the masked language modeling task have been extremely
influential for representation learning in natural language processing, inspiring substantial follow-up
work [Liu+19c; Jos+20].
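The 80/10/10 corruption rule can be sketched as follows. This is a simplified illustration, not BERT's reference implementation. Each position is selected independently with probability 0.15, and `mask_id` and `vocab_size` are hypothetical parameters standing in for a real tokenizer's values. The returned boolean array marks every position whose original token the network is trained to predict, including the positions left intact.

```python
import numpy as np


def bert_mask(tokens, mask_id, vocab_size, rng, select_prob=0.15):
    # Returns (corrupted inputs, boolean array of positions to predict).
    tokens = np.asarray(tokens)
    selected = rng.random(tokens.shape) < select_prob
    r = rng.random(tokens.shape)
    inputs = tokens.copy()
    # 80% of the selected positions are replaced with the [MASK] token.
    inputs[selected & (r < 0.8)] = mask_id
    # 10% are replaced with a random token from the vocabulary.
    rand_pos = selected & (r >= 0.8) & (r < 0.9)
    inputs[rand_pos] = rng.integers(0, vocab_size, size=int(rand_pos.sum()))
    # The remaining 10% are left intact but are still prediction targets.
    return inputs, selected


rng = np.random.default_rng(0)
MASK_ID = 0  # hypothetical id reserved for [MASK]
tokens = rng.integers(5, 1000, size=20000)  # hypothetical token ids
inputs, targets = bert_mask(tokens, MASK_ID, 1000, rng)
print(targets.mean(), np.mean(inputs[targets] == MASK_ID))  # about 0.15 and 0.8
```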

Although denoising-based approaches to representation learning were first employed for computer
vision, they received little attention for the decade that followed. Vincent et al. [Vin+08] greedily
trained stacks of up to three denoising autoencoders that were then fine-tuned end-to-end to perform
digit classification, but greedy unsupervised pretraining was abandoned as it was shown that it
was possible to attain good performance using CNNs and other architectures trained end-to-end.
Context encoders [Pat+16] mask contiguous image regions and train models to perform inpainting,
achieving transfer learning performance competitive with other contemporary unsupervised visual
representation learning methods. The use of image colorization as a pretext task [ZIE16; ZIE17] is also
related to denoising in that colorization involves reconstructing the original image from a corrupted
input, although generally color is dropped in a deterministic fashion rather than stochastically.

[Figure 32.3 diagram: encoder → decoder]

Figure 32.3: Masked autoencoders learn a representation of images by randomly masking out input patches
and trying to predict them (from He et al. [He+21]).

Recently, the success of BERT in NLP has inspired new approaches to visual representation learning
based on masked prediction. Image GPT [Che+20a] trained a Transformer directly upon pixels to
perform a BERT-style masked image modeling task. While the resulting model achieves very high
accuracy when fine-tuned on CIFAR-10, the cost of self-attention is quadratic in the number of pixels,
limiting applicability to larger image sizes. BEIT [Bao+22] addresses this challenge by combining the 
idea of masked image modeling with the patch-based architecture of Vision Transformers [Dos+21]. 
BEIT splits images into 16 x 16 pixel image patches and then discretizes these patches using a discrete 
VAE [Ram+21b]. At training time, 40% of tokens are masked. The network receives continuous 
patches as input and is trained to predict the discretized missing tokens using a softmax over all 
possible tokens. 

The masked autoencoder or MAE [He+22] further simplifies the masked image modeling task
(see Figure 32.3). The MAE eliminates the need to discretize patches and instead predicts the
constituent pixels of each patch directly using a shallow decoder trained with $L_2$ loss. Because the
MAE encoder operates only on the unmasked tokens, it can be trained efficiently even while masking
most (75%) of the tokens. Models pretrained using masked prediction and then fine-tuned with
labels currently hold the top positions on the ImageNet leaderboard among models trained without
additional data [He+22; Don+21].
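The MAE input pipeline can be sketched in a few lines of NumPy. The sketch assumes a 224 × 224 RGB image, 16 × 16 patches, and a 75% mask ratio. Under these assumptions the image yields 196 patches, and only 49 of them are passed to the encoder. The helper names are hypothetical, and the encoder, decoder, and loss are omitted. In MAE the pixel reconstruction loss is computed on the masked patches, which are identified here by `mask_idx`.

```python
import numpy as np


def patchify(img, p):
    # Split an (H, W, C) image into non-overlapping p x p patches, each flattened.
    h, w, c = img.shape
    assert h % p == 0 and w % p == 0, "image size must be divisible by patch size"
    x = img.reshape(h // p, p, w // p, p, c).transpose(0, 2, 1, 3, 4)
    return x.reshape(-1, p * p * c)  # (num_patches, p * p * C)


def random_masking(patches, mask_ratio, rng):
    # Keep a random subset of patches; only these are fed to the encoder.
    n = patches.shape[0]
    n_keep = int(round(n * (1 - mask_ratio)))
    perm = rng.permutation(n)
    keep_idx, mask_idx = np.sort(perm[:n_keep]), np.sort(perm[n_keep:])
    return patches[keep_idx], keep_idx, mask_idx


rng = np.random.default_rng(0)
img = rng.random((224, 224, 3))
patches = patchify(img, 16)
visible, keep_idx, mask_idx = random_masking(patches, 0.75, rng)
print(patches.shape, visible.shape, len(mask_idx))
```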


32.3.3.2 Transformation prediction 


An even simpler approach to representation learning involves applying a transformation to the 
input image and then predicting the transformation that was applied (see Figure 32.4). This 
prediction task is usually formulated as a classification problem. For visual representation learning, 
transformation prediction is appealing because it allows reusing exactly the same training pipelines 
as standard supervised image classification. However, it is not clear that networks trained to perform 
transformation prediction tasks learn rich visual representations. Transformation prediction tasks 
are potentially susceptible to “shortcut” solutions, where networks learn trivial features that are
nonetheless sufficient to solve the task with high accuracy. For many years, self-supervised learning
methods based on transformation prediction were among the top-performing methods, but they have
since been displaced by newer methods based on contrastive learning and masked prediction.

Figure 32.4: Transformation prediction involves training neural networks to predict a transformation applied
to the input. Context prediction involves predicting the position of a second crop relative to the first. The
jigsaw puzzle task involves predicting the way in which patches have been permuted. Rotation prediction
involves predicting the rotation that was applied to the input.

Some pretext tasks operate by cutting images into patches and training networks to recover the
spatial arrangement of the patches. In context prediction [DGE15], a network receives two adjacent
image patches as input and is trained to recover their spatial relationship by performing an eight-class
classification problem. To prevent the network from directly matching the pixels at the patch borders,
the two patches must be separated by a small variable gap. In addition, to prevent networks from
using chromatic aberration to localize the patches relative to the lens, color channels must be distorted
or stochastically dropped. Other work has trained networks to solve jigsaw puzzles by splitting
images into a 3 × 3 grid of patches [NF16]. The network receives shuffled patches as input and learns
to predict how they were permuted. By limiting the permutations to a subset of all possibilities, the
jigsaw puzzle task can be formulated as a standard classification task [NF16].
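The sketch below shows one way to generate jigsaw training examples. It assumes an image whose height and width are divisible by 3. For simplicity, it draws the fixed subset of permutations at random, whereas the original work chooses its subset more carefully. The index of the applied permutation within the subset serves as the class label. All names are hypothetical.

```python
import numpy as np


def make_permutation_set(num_perms, rng, tiles=9):
    # A fixed subset of tile orderings; each one defines a class.
    perms = set()
    while len(perms) < num_perms:
        perms.add(tuple(rng.permutation(tiles).tolist()))
    return sorted(perms)


def jigsaw_example(img, perms, rng, grid=3):
    # Cut the image into grid x grid tiles, shuffle them with a permutation
    # drawn from the fixed set, and return (shuffled tiles, class label).
    h, w = img.shape[:2]
    th, tw = h // grid, w // grid
    tiles = [img[r * th:(r + 1) * th, c * tw:(c + 1) * tw]
             for r in range(grid) for c in range(grid)]
    label = int(rng.integers(len(perms)))
    shuffled = np.stack([tiles[i] for i in perms[label]])
    return shuffled, label


rng = np.random.default_rng(0)
perms = make_permutation_set(100, rng)
img = np.arange(144, dtype=float).reshape(12, 12)
tiles, label = jigsaw_example(img, perms, rng)
# Inverting the labelled permutation restores the original tile order.
restored = tiles[np.argsort(perms[label])]
print(tiles.shape, label)
```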

Another widely used pretext task is rotation prediction [GSK18], where input images are rotated 0,
90, 180, or 270 degrees and networks are trained to classify which rotation was applied. Although this
task is extremely simple, the learned representations often perform better than those learned using
patch-based methods [GSK18; KZB19]. However, all approaches based on transformation prediction
currently underperform masked prediction and multiview approaches on standard benchmark datasets
such as CIFAR-10 and ImageNet.
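Rotation prediction needs only a few lines of data preparation. The sketch below is a hypothetical helper that assumes square images in (H, W, C) layout. It emits all four rotations of each image with labels 0 to 3, so that a standard 4-way image classifier with a cross-entropy loss can be trained on the result. `np.rot90` rotates counterclockwise, which is fine as long as the same convention is used throughout.

```python
import numpy as np


def rotation_batch(images):
    # For each (H, W, C) image, emit copies rotated by 0, 90, 180, and 270
    # degrees, labelled 0..3; the pretext task is 4-way classification.
    xs, ys = [], []
    for img in images:
        for k in range(4):
            xs.append(np.rot90(img, k=k, axes=(0, 1)))
            ys.append(k)
    return np.stack(xs), np.array(ys)


imgs = np.random.default_rng(0).random((2, 8, 8, 3))
x, y = rotation_batch(imgs)
print(x.shape, y)
```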


32.3.4 Multiview representation learning


The field of multiview representation learning aims to learn a representation where “similar”
inputs or views of an input are mapped nearby in the representation space, and “dissimilar” inputs are
mapped further apart. This representation space is often high-dimensional, and learning it relies on
collecting data or designing a task where one can generate “positive” pairs of examples that are similar,
and “negative” pairs of examples that are dissimilar. There are many motivations and objectives
for multiview representation learning, but all rely on coming up with sets of positive pairs, and a
mechanism to prevent all representations from collapsing to the same point. Here we use the term
multiview representation learning to encompass contrastive learning, which combines positive and
negative pairs; metric learning; and “non-contrastive” learning, which eliminates the need for negative
pairs.

[Figure 32.5 panels include SimCLR and CLIP; example text inputs shown: “happy dog”, “angry cat”.]

Figure 32.5: Positive and negative pairs used by different multiview representation learning methods.

Unlike generative methods for representation learning, multiview representation learning makes
it easy to incorporate prior knowledge about what inputs should be closer in the embedding space.
Furthermore, these inputs need not be from the same modality, and thus multiview representation
learning can be applied with rich multimodal datasets. The simplicity of the way in which prior 
knowledge can be incorporated into a model through data has made multiview representation learning 
one of the most powerful and performant methods for learning representations. 

While there are a variety of methods for multiview representation learning, they all involve an attraction component that pulls positive pairs closer together in embedding space, and a mechanism to prevent collapse of the representation to a single point in embedding space. We begin by discussing how views are selected, then describe loss functions for multiview representation learning and how they combine attractive and repulsive terms to shape the representation, and finally cover practical considerations in deploying multiview representation learning.


32.3.4.1 View selection 


Multiview representation learning depends on a datapoint or “anchor” $x$, a positive example $x^+$ that $x$ will be attracted to, and zero or more negative examples $x^-$ that $x$ is repelled from. We assume access to a data-generating process for the positive pair, $p^+(x, x^+)$, and a process that generates the negative examples given the datapoint $x$, $p^-(x^-|x)$. Typically $p^+(x, x^+)$ generates pairs $(x, x^+)$ that are different augmentations of an underlying image from the dataset, and $x^-$ represents an augmented view of a different random image from the dataset. The generative process for $x^-$ is then independent of $x$, i.e., $p^-(x^-|x) = p^-(x^-)$.

The choice of views used to generate positive and negative pairs is critical to the success of representation learning. Figure 32.5 shows the positive pair $(x, x^+)$ and negative $x^-$ for several methods which we discuss below: SimCLR, CMC, SupCon, and CLIP.

SimCLR [Che+20c] creates positive pairs by applying two different randomly sampled data augmentations, defined by transformations $t$ and $t'$, to the same initial image $x_0$: $x = t(x_0)$, $x^+ = t'(x_0)$. The data augmentations used are random crops (with horizontal flips and resize), color distortion, and Gaussian blur. The strengths of these augmentations (e.g., the amount of blur) impact performance and are typically treated as hyperparameters.
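As a rough illustration of how a SimCLR-style positive pair might be generated, the NumPy sketch below applies two independently sampled augmentations to the same image $x_0$. The helper names `random_view` and `positive_pair`, the crop size, and the jitter strength are hypothetical choices, and the augmentations are simplified stand-ins (a random crop without resizing, a horizontal flip, and brightness scaling in place of full color distortion, with no Gaussian blur) rather than the exact SimCLR pipeline.

```python
import numpy as np

def random_view(img, rng, crop=24, jitter=0.4):
    """One randomly augmented view of an H x W x C image with values in [0, 1].

    Simplified stand-ins for SimCLR's augmentations: random crop (no resize),
    random horizontal flip, and random brightness scaling.
    """
    h, w, _ = img.shape
    top = rng.integers(0, h - crop + 1)
    left = rng.integers(0, w - crop + 1)
    view = img[top:top + crop, left:left + crop]
    if rng.random() < 0.5:
        view = view[:, ::-1]                      # flip along the width axis
    scale = 1.0 + rng.uniform(-jitter, jitter)    # brightness jitter
    return np.clip(view * scale, 0.0, 1.0)

def positive_pair(img, rng):
    """Apply two independently sampled transformations t, t' to the same x0."""
    return random_view(img, rng), random_view(img, rng)

rng = np.random.default_rng(0)
x0 = rng.random((32, 32, 3))
x, x_pos = positive_pair(x0, rng)
print(x.shape, x_pos.shape)  # (24, 24, 3) (24, 24, 3)
```

Because $t$ and $t'$ are drawn independently, an encoder trained to map both views close together is pushed toward invariance to exactly these transformations, which is why their strengths act as hyperparameters.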

If we have access to additional information, such as a categorical label, we can use it to select positive pairs with the same label and negative pairs with different labels. The resulting objective, when used with a contrastive loss, is called SupCon [Kho+20], and resembles Neighborhood Component Analysis [Gol+04]. It was shown to improve robustness when compared to standard supervised learning.
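A minimal sketch, assuming integer class labels for a minibatch, of how SupCon-style pairs could be selected: every other example with the same label is a positive, and every example with a different label is a negative. The function name `supcon_masks` is hypothetical.

```python
import numpy as np

def supcon_masks(labels):
    """Boolean (B, B) masks of positive and negative pairs for a labeled minibatch.

    pos[i, j] is True when i != j and examples i, j share a label;
    neg[i, j] is True when their labels differ.
    """
    labels = np.asarray(labels)
    same = labels[:, None] == labels[None, :]
    not_self = ~np.eye(len(labels), dtype=bool)  # an example is not its own positive
    return same & not_self, ~same

pos, neg = supcon_masks([0, 1, 0, 2])
print(pos.astype(int))  # examples 0 and 2 form the only positive pair
```

These masks would then decide which similarity terms play the role of $x^+$ and which play the role of $x^-$ inside a contrastive loss.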

Contrastive Multiview Coding (CMC) [TKI20] generates views by splitting an initial image into orthogonal dimensions, such as the luma and chroma dimensions. These views no longer live in the same space (or have the same dimensionality), and thus we must learn different encoders for the different inputs. However, the outputs of these encoders all live in an embedding space of the same dimension, and can be used in contrastive losses. At test time, we can then combine embeddings from these different views through averaging or concatenation.

Views do not need to be from the same modality. CLIP [Rad+21] uses contrastive learning on image-text pairs, where $x$ is an image, and $x^+$ and $x^-$ are text descriptions. When applied to massive datasets of image-text pairs scraped from the Internet, CLIP is able to learn robust representations without any of the additional data augmentation needed by SimCLR or other image-only contrastive methods.

In most contrastive methods, negative examples are selected by randomly choosing $x^-$ from other elements in a minibatch. However, if the batch size is small it may be the case that none of the negative examples are close in embedding space to the positive example, and so learning may be slow. Instead of randomly choosing negatives, they may be chosen more intelligently through hard negative mining, which selects negative examples that are close to the positive example in embedding space [Fag+18]. This typically requires maintaining and updating a database of negative examples over the course of training; this incurs enough computational overhead that the technique is infrequently used. However, reweighting examples within a minibatch can also lead to improved performance [Rob+21].
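A minimal sketch of the selection step in hard negative mining, assuming the candidate negatives' embeddings are already available (maintaining and refreshing that database over training is not shown). The function name and the use of cosine similarity as the closeness measure are illustrative assumptions; the reference embedding could be the anchor or its positive.

```python
import numpy as np

def hardest_negatives(anchor, candidates, k):
    """Indices of the k candidates with the highest cosine similarity to the anchor."""
    a = anchor / np.linalg.norm(anchor)
    c = candidates / np.linalg.norm(candidates, axis=1, keepdims=True)
    sims = c @ a                     # cosine similarity of each candidate to the anchor
    return np.argsort(-sims)[:k]     # most similar (hardest) candidates first

anchor = np.array([1.0, 0.0])
candidates = np.array([[0.0, 1.0], [0.9, 0.1], [-1.0, 0.0], [0.7, 0.7]])
print(hardest_negatives(anchor, candidates, k=2))  # [1 3]
```

The opposite-facing candidate at index 2 is the easiest negative and would contribute little learning signal, which is the motivation for mining.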

The choice of positive and negative views directly impacts what features are learned and what invariances are encouraged. Tian et al. [Tia+20] discuss the role of view selection on the learned representations, showing how choosing positives based on shared attributes (as in SupCon) can lead to learning those attributes or ignoring them. They also present a method for learning views (whereas all prior approaches fix views) based on targeting a “sweet spot” in the level of mutual information between the views that is neither too high nor too low. However, understanding what views will work well for what downstream tasks remains an open area of study.

32.3.4.2 Contrastive losses

Given $p^+$ and $p^-$, we seek loss functions that learn an embedding $f_\theta(x)$ where $x$ and $x^+$ are close in the embedding space, while $x$ and $x^-$ are far apart. This is called metric learning.

Chopra, Hadsell, LeCun, et al. [CHL+05] present a family of objectives that implements this intuition by enforcing the distance between negative pairs to always be at least $\epsilon$ bigger than the distance between positive pairs. The contrastive loss as instantiated in [HCL06] is:


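$$\mathcal{L}_{\text{contrastive}} = \mathbb{E}_{x, x^+, x^-}\left[\|f_\theta(x) - f_\theta(x^+)\|_2^2 + \max\left(0, \epsilon - \|f_\theta(x) - f_\theta(x^-)\|_2^2\right)\right] \qquad (32.14)$$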


This loss pulls together the positive pairs by making the squared $\ell_2$ distance between them small, and tries to ensure that negative pairs are at least a distance of $\epsilon$ apart. One challenge with using the contrastive loss in practice is tuning the hyperparameter $\epsilon$.
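A minimal NumPy sketch of a minibatch estimate of Eq. (32.14), assuming the embeddings have already been computed by $f_\theta$; the function name `contrastive_loss` and the default `eps=1.0` are illustrative choices.

```python
import numpy as np

def contrastive_loss(fx, fpos, fneg, eps=1.0):
    """Minibatch estimate of Eq. (32.14).

    fx, fpos, fneg: (B, D) arrays holding f_theta(x), f_theta(x+), f_theta(x-).
    eps: margin on the squared distance between negative pairs.
    """
    d_pos = np.sum((fx - fpos) ** 2, axis=1)  # squared l2 distance of positive pairs
    d_neg = np.sum((fx - fneg) ** 2, axis=1)  # squared l2 distance of negative pairs
    # Attract positives; repel negatives only while they are closer than eps.
    return np.mean(d_pos + np.maximum(0.0, eps - d_neg))

fx = np.array([[0.0, 0.0]])
fpos = np.array([[0.0, 0.0]])
print(contrastive_loss(fx, fpos, np.array([[2.0, 0.0]])))  # 0.0: negative beyond the margin
print(contrastive_loss(fx, fpos, np.array([[0.5, 0.0]])))  # 0.75: negative inside the margin
```

The second call shows why $\epsilon$ needs tuning: it sets the scale at which negatives stop contributing, so its sensible value depends on the scale of the embeddings.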

Similarly, the triplet loss [SKP15] tries to ensure that the two elements of the positive pair $(x, x^+)$ are always at least some distance $\epsilon$ closer to each other than those of the negative pair $(x, x^-)$:


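$$\mathcal{L}_{\text{triplet}} = \mathbb{E}_{x, x^+, x^-}\left[\max\left(0, \|f_\theta(x) - f_\theta(x^+)\|_2^2 - \|f_\theta(x) - f_\theta(x^-)\|_2^2 + \epsilon\right)\right] \qquad (32.15)$$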
A downside to the triplet loss approach is that one has to be careful about choosing hard negatives: if the negative pair is already sufficiently far away, then the objective function is zero and no learning occurs.
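The NumPy sketch below evaluates Eq. (32.15), with an illustrative `eps=1.0`, and makes this downside concrete: an easy negative contributes exactly zero loss (and hence zero gradient), while a hard negative that violates the margin does not. The function name is hypothetical.

```python
import numpy as np

def triplet_loss(fx, fpos, fneg, eps=1.0):
    """Minibatch estimate of Eq. (32.15) for (B, D) embedding arrays."""
    d_pos = np.sum((fx - fpos) ** 2, axis=1)  # squared distance anchor-positive
    d_neg = np.sum((fx - fneg) ** 2, axis=1)  # squared distance anchor-negative
    return np.mean(np.maximum(0.0, d_pos - d_neg + eps))

fx = np.array([[0.0, 0.0]])
fpos = np.array([[1.0, 0.0]])  # squared distance 1 to the anchor
easy = np.array([[3.0, 0.0]])  # squared distance 9: margin already satisfied
hard = np.array([[0.0, 1.0]])  # squared distance 1: violates the margin
print(triplet_loss(fx, fpos, easy), triplet_loss(fx, fpos, hard))  # 0.0 1.0
```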

An alternative contrastive loss which has gained popularity due to its lack of hyperparameters and 
empirical effectiveness is known as the InfoNCE loss [OLV18b] or the Multi-Class N-pair loss 
[Soh16]: 


$$\mathcal{L}_{\text{InfoNCE}} = -\mathbb{E}_{x, x^+, x^-_{1:M}}\left[\log \frac{\exp\left(f_\theta(x)^\top g_\phi(x^+)\right)}{\exp\left(f_\theta(x)^\top g_\phi(x^+)\right) + \sum_{i=1}^{M} \exp\left(f_\theta(x)^\top g_\phi(x_i^-)\right)}\right] \qquad (32.16)$$

where $M$ is the number of negative examples. Typically the embeddings $f_\theta(x)$ and $g_\phi(x')$ are $\ell_2$-normalized, and an additional temperature hyperparameter $\tau$ can be introduced to rescale the inner products [Che+20c]. Unlike the triplet loss, which uses a hard threshold of $\epsilon$, $\mathcal{L}_{\text{InfoNCE}}$ can always be improved by pushing negative examples further away. Intuitively, the InfoNCE loss ensures that the positive pair is closer together than any of the $M$ negative pairs in the minibatch. The InfoNCE loss can be related to a lower bound on the mutual information between the input $x$ and the learned representation $z$ [OLV18b; Poo+19a]:

$$I(X; Z) \geq \log M - \mathcal{L}_{\text{InfoNCE}} \qquad (32.17)$$

and has also been motivated as a way of learning representations through the InfoMax principle [OLV18b; Hje+18; BHB19]. When applying the InfoNCE loss to parallel views that have the same modality and dimension, the encoder $f_\theta$ for the anchor $x$ and the encoder $g_\phi$ for the positive and negative examples can be shared.
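A minimal NumPy sketch of a minibatch estimate of Eq. (32.16), including the optional temperature $\tau$ (`tau=1.0` recovers the equation as written). As assumptions for clarity, each anchor receives its own explicit set of $M$ negatives, whereas implementations such as SimCLR typically reuse the other examples in the minibatch, and the function name `info_nce` is hypothetical. Placing the positive at index 0 shows that the loss is a softmax cross-entropy for picking out the positive among $M + 1$ candidates.

```python
import numpy as np

def info_nce(fx, gpos, gneg, tau=1.0):
    """Minibatch estimate of the InfoNCE loss, Eq. (32.16).

    fx:   (B, D) anchor embeddings f_theta(x)
    gpos: (B, D) positive embeddings g_phi(x+)
    gneg: (B, M, D) negative embeddings g_phi(x_i^-), M per anchor
    tau:  temperature dividing the inner products (tau=1.0 matches Eq. 32.16)
    """
    pos_logit = np.sum(fx * gpos, axis=-1) / tau                       # (B,)
    neg_logits = np.einsum("bd,bmd->bm", fx, gneg) / tau               # (B, M)
    logits = np.concatenate([pos_logit[:, None], neg_logits], axis=1)  # (B, M + 1)
    # Log of the denominator via a numerically stable log-sum-exp.
    m = logits.max(axis=1, keepdims=True)
    log_denom = m[:, 0] + np.log(np.exp(logits - m).sum(axis=1))
    # -log(exp(pos) / denominator), averaged over the minibatch.
    return np.mean(log_denom - pos_logit)

rng = np.random.default_rng(0)
fx = rng.standard_normal((8, 16))
gpos = fx + 0.1 * rng.standard_normal((8, 16))  # positives close to their anchors
gneg = rng.standard_normal((8, 31, 16))         # M = 31 unrelated negatives
print(info_nce(fx, gpos, gneg))
```

When every logit is equal (for instance, when the anchor embedding is zero), the loss is exactly $\log(M + 1)$, the chance level; it approaches zero as the positive logit dominates all the negatives.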


32.3.4.3 Negative-free losses 


Negative-free representation learning (sometimes called non-contrastive representation 
learning) learns representations using only positive pairs, without explicitly constructing negative 
pairs. Whereas contrastive methods prevent collapse by enforcing that positive pairs are closer 
together than negative pairs, negative-free methods make use of other mechanisms. One class of 
negative-free objectives includes both attractive terms and terms that prevent collapse. Another 
class of methods uses objectives that include only attractive terms, and instead relies on the learning 
dynamics to prevent collapse. 

The Barlow Twins loss [Zbo+21] is

$$\mathcal{L}_{\text{BT}} = \sum_{i=1}^{D} (1 - C_{ii})^2 + \lambda \sum_{i=1}^{D} \sum_{j \neq i} C_{ij}^2 \qquad (32.18)$$

where $C$ is the cross-correlation matrix between two batches of features that arise from the two views, $D$ is the feature dimension, and $\lambda > 0$ weights the second term. The first term is an attractive term that encourages high similarity between the representations of the two views, whereas the second term prevents collapse to a low-rank representation by decorrelating the feature dimensions. The loss is minimized when $C$ is the identity matrix. Similar losses that keep the variance of each feature away from zero have also been useful for preventing collapse [BPL21b]. The Barlow Twins loss can be related to kernel-based independence criteria such as HSIC, which have also been useful as losses for representation learning [Li+21; Tsa+21].
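The following NumPy sketch evaluates Eq. (32.18) for two batches of view features. Standardizing each feature along the batch before forming $C$ follows the Barlow Twins paper; the default `lam=5e-3` is an assumed setting, and the function name is hypothetical.

```python
import numpy as np

def barlow_twins_loss(z1, z2, lam=5e-3):
    """Barlow Twins loss, Eq. (32.18), for two (N, D) batches of features.

    Returns the loss and the D x D cross-correlation matrix C.
    """
    n = z1.shape[0]
    # Standardize each feature dimension along the batch.
    z1n = (z1 - z1.mean(axis=0)) / z1.std(axis=0)
    z2n = (z2 - z2.mean(axis=0)) / z2.std(axis=0)
    c = z1n.T @ z2n / n                              # cross-correlation matrix
    diag = np.diag(c)
    on_diag = np.sum((1.0 - diag) ** 2)              # attractive (invariance) term
    off_diag = np.sum(c ** 2) - np.sum(diag ** 2)    # decorrelation term over j != i
    return on_diag + lam * off_diag, c

rng = np.random.default_rng(0)
z = rng.standard_normal((256, 8))
loss, c = barlow_twins_loss(z, z + 0.1 * rng.standard_normal((256, 8)))
print(loss)
```

Feeding identical views makes every diagonal entry of $C$ equal to one, so only the off-diagonal (redundancy) term remains; driving that term to zero as well is what pushes $C$ toward the identity.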
BYOL (Bootstrap Your Own Latent) [Gri+20] and SimSiam [Che+20c] simply minimize the mean squared error between two representations:

$$\mathcal{L}_{\text{BYOL}} = \mathbb{E}_{x, x^+}\left[\|g_\phi(f_\theta(x)) - f_{\theta'}(x^+)\|^2\right] \qquad (32.19)$$

Following Grill et al. [Gri+20], $g_\phi$ is known as the predictor, $f_\theta$ is the online network, and $f_{\theta'}$ is the target network. When optimizing this loss function, gradients are backpropagated to update $\phi$ and $\theta$, but optimizing $\theta'$ directly leads the representation to collapse [Che+20c]. Instead, BYOL sets $\theta'$ as an exponential moving average of $\theta$, and SimSiam sets $\theta' \leftarrow \theta$ at each iteration of training without backpropagating through the target branch. The reasons why BYOL and SimSiam avoid collapse are not entirely clear, but Tian, Chen, and Ganguli [TCG21] analyze the gradient flow dynamics of a simplified linear BYOL model and show that collapse can indeed be avoided given properly set hyperparameters.
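The target-network updates that distinguish BYOL from SimSiam can be sketched in a few lines. In this NumPy sketch, parameters are stored as a dictionary of arrays (a hypothetical layout), `decay` is an assumed hyperparameter, and because NumPy has no automatic differentiation, the stop-gradient on the target branch is only indicated in a comment.

```python
import numpy as np

def ema_update(target, online, decay=0.99):
    """BYOL-style update theta' <- decay * theta' + (1 - decay) * theta.

    With decay=0.0 this copies the online weights, i.e. SimSiam's theta' <- theta.
    """
    return {name: decay * target[name] + (1.0 - decay) * online[name] for name in target}

def byol_loss(pred, target_out):
    """Minibatch estimate of Eq. (32.19).

    pred:       (B, D) predictor outputs g_phi(f_theta(x))
    target_out: (B, D) target-network outputs f_theta'(x+); in an autodiff
                framework these would be wrapped in a stop-gradient.
    """
    return np.mean(np.sum((pred - target_out) ** 2, axis=-1))

online = {"w": np.ones(3)}
target = {"w": np.zeros(3)}
target = ema_update(target, online, decay=0.99)
print(target["w"])  # moves 1% of the way toward the online weights
```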

DINO (self-distillation with no labels) [Car+21] is another non-contrastive loss that relies on 
the dynamics of learning to avoid collapse. Like BYOL, DINO uses a loss that consists only of an 
attractive term between an online network and a target network formed by an exponential moving 
average of the online network weights. Unlike BYOL, DINO uses a cross-entropy loss where the 
target network produces the targets for the online network, and avoids the need for a predictor 
network. The DINO loss is: 




$$\mathcal{L}_{\text{DINO}} = \mathbb{E}_{x, x^+}\left[H\left(f_\theta(x)/\tau, \; \text{center}(f_{\theta'}(x^+))/\tau'\right)\right] \qquad (32.20)$$

where, with some abuse of notation, $H$ denotes the cross-entropy between the softmax distributions defined by its two arguments (with the target network's output serving as the target), and center is a mean-centering operation applied across the minibatch that contains $x^+$. Centering the output of the target network is necessary to prevent collapse to a single “class”, whereas using a lower temperature $\tau' < \tau$ for the target network is necessary to prevent collapse to a uniform distribution. The DINO loss provides marginal gains over the BYOL loss when performing self-supervised representation learning with Vision Transformers on ImageNet [Car+21].
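A NumPy sketch of the DINO objective in Eq. (32.20): the target (teacher) outputs are centered, sharpened with the lower temperature $\tau'$, converted to a softmax distribution, and used as soft targets for the online (student) softmax. The temperatures `tau_s=0.1` and `tau_t=0.04`, the `momentum=0.9`, and the function names are assumed settings; the DINO paper maintains the center as a running average across minibatches, and `momentum=0.0` recovers the per-minibatch centering described above.

```python
import numpy as np

def log_softmax(x):
    x = x - x.max(axis=-1, keepdims=True)  # shift for numerical stability
    return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))

def dino_loss(student_out, teacher_out, center, tau_s=0.1, tau_t=0.04):
    """Minibatch estimate of Eq. (32.20).

    student_out: (B, K) online-network outputs f_theta(x)
    teacher_out: (B, K) target-network outputs f_theta'(x+), treated as constants
    center:      (K,) center subtracted from the teacher outputs
    tau_t < tau_s sharpens the teacher (target) distribution.
    """
    targets = np.exp(log_softmax((teacher_out - center) / tau_t))
    return -np.mean(np.sum(targets * log_softmax(student_out / tau_s), axis=-1))

def update_center(center, teacher_out, momentum=0.9):
    """Running average of the minibatch mean of teacher outputs.

    momentum=0.0 gives plain per-minibatch centering.
    """
    return momentum * center + (1.0 - momentum) * teacher_out.mean(axis=0)

rng = np.random.default_rng(0)
student, teacher = rng.standard_normal((2, 4, 5))  # B = 4 examples, K = 5 outputs
center = np.zeros(5)
loss = dino_loss(student, teacher, center)
center = update_center(center, teacher)  # refresh the center after each step
print(loss)
```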

32.3.4.4 Tricks of the trade


Beyond view selection and losses, there are a number of useful architectures and modifications that enable more effective multiview representation learning.

Normalizing the output of the encoders and computing cosine similarity, instead of predicting unconstrained representations, has been shown to improve performance [Che+20c]. This normalization bounds the similarity between points to the range $[-1, 1]$, so an additional temperature parameter $\tau$ is typically introduced and fixed or annealed over the course of learning.
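A short NumPy sketch of the normalized similarity used in this trick; the function name and `tau=0.1` are illustrative assumptions. Because normalized embeddings have cosine similarities in $[-1, 1]$, dividing by $\tau$ bounds the logits to $[-1/\tau, 1/\tau]$, which is what lets the temperature control how sharp a subsequent softmax (as in InfoNCE) can become.

```python
import numpy as np

def cosine_logits(a, b, tau=0.1):
    """Pairwise cosine similarities between rows of a (N, D) and b (M, D), divided by tau."""
    a = a / np.linalg.norm(a, axis=1, keepdims=True)  # l2-normalize each embedding
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    return (a @ b.T) / tau                              # entries lie in [-1/tau, 1/tau]

a = np.array([[1.0, 0.0]])
b = np.array([[2.0, 0.0], [-3.0, 0.0], [0.0, 5.0]])
print(cosine_logits(a, b, tau=0.1))
```

Note that the magnitudes of the raw embeddings (2, 3, and 5 above) no longer affect the logits; only directions do.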

While the representations learned with multiview learning are often useful for downstream tasks, these losses, when combined with data augmentation, typically lead to too much invariance for some tasks. Instead, one can extract an earlier layer in the encoder as the representation, or alternatively, add an additional layer known as a projection head to the encoder before computing the loss [Che+20c]. When training, we compute the loss on the output of the projection head, but when evaluating the quality of the representation we discard this additional layer.

Because the InfoNCE loss sums over negative examples in its denominator, it is often sensitive to the batch size used for training. In practice, large batch sizes of 4096 or more are needed to achieve good performance with this loss, which can be computationally burdensome. MoCo (Momentum Contrast) [He+20] introduced a memory queue that stores negative examples from previous minibatches, increasing the number of negatives available at each iteration. Additionally, MoCo uses a momentum encoder, where the encoder for the positive and negative examples uses an exponential moving average of the anchor encoder parameters. This momentum encoder approach was also found useful in BYOL to prevent collapse. As in BYOL, adding an extra predictor network that maps from the online network to the target network has been shown to improve the performance of MoCo, and removes the requirement of a memory queue [CXH21].
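The memory queue can be sketched as a fixed-size first-in, first-out buffer of key embeddings. In this NumPy sketch the class name and sizes are hypothetical, and the keys are assumed to have been computed elsewhere (in MoCo, by the momentum encoder); implementations typically write into a preallocated buffer with a moving pointer, whereas this version uses concatenation and slicing only to make the FIFO behavior explicit.

```python
import numpy as np

class NegativeQueue:
    """Fixed-size FIFO queue of key embeddings used as negatives.

    The number of stored negatives is decoupled from the minibatch size.
    """

    def __init__(self, size, dim):
        self.size = size
        self.keys = np.zeros((0, dim))

    def enqueue(self, new_keys):
        # Append the newest keys and drop the oldest beyond capacity.
        self.keys = np.concatenate([self.keys, new_keys], axis=0)[-self.size:]

    def negatives(self):
        return self.keys

queue = NegativeQueue(size=4, dim=2)
queue.enqueue(np.arange(6.0).reshape(3, 2))         # keys from minibatch 1
queue.enqueue(np.arange(6.0, 12.0).reshape(3, 2))   # keys from minibatch 2
print(queue.negatives())  # the 4 most recent keys; the oldest 2 were dropped
```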

The backbone architectures of the encoder networks play a large role in the quality of representations. 
For representation learning in vision, recent work has switched from ConvNet-based backbones 
to Vision Transformers, resulting in larger-scale models with improved performance on several 
downstream tasks [CXH21]. 


32.4 Theory of Representation Learning 


While deep representation learning has replaced hand-designed features for most applications, the theory behind what features are learned and what guarantees these methods provide remains limited. Here we review several theoretical directions in understanding representation learning: identifiability, information maximization, and transfer bounds.


32.4.1 Identifiability 


In this section, we assume that the data were generated by a latent-variable generative model, where $z \sim p(z)$ are the latent variables, and $x = g(z)$ is a deterministic generator that maps from the latent variables to observations. Our goal is to learn a representation $h = f_\theta(x)$ that inverts the generative model and recovers $h = z$. If we can do this, we say the model is identifiable. Often we are not able to recover the true latent variables exactly: for example, the dimensions of the latent variables may be permuted, or individual dimensions may be transformed versions of an underlying latent variable, $h_i = f_i(z_i)$. Thus most theoretical work on identifiability focuses on the case of learning a representation that can be permuted and elementwise transformed to match the true latent variables. Such representations are referred to as disentangled, as the dimensions of the learned representation do not mix together multiple dimensions of the true latent variables.

Methods for recovering disentangled representations are typically based on latent-variable models such as VAEs combined with various regularizers (see Section 21.3.1.1). While several publications showed promising empirical progress, a large-scale study by Locatello et al. [Loc+20a] of disentangled representation learning methods showed that several existing approaches cannot work without additional assumptions on the data or model. Their argument relies on the observation that we can form a bijection $f$ that takes samples from a factorial prior $p(z) = \prod_i p_i(z_i)$ and maps them to $z' = f(z)$ such that (a) $z'$ has the same marginal distribution as $z$, and (b) the latents are entirely entangled (each dimension of $z$ influences every dimension of $z'$). Transforming the latent variables in this way, and composing the generator with $f^{-1}$, changes the representation but preserves the marginal likelihood of the data, and thus one cannot use marginal likelihood alone to identify or distinguish between the entangled and disentangled model. Empirically, they show that past methods largely succeeded due to careful hyperparameter selection on the target disentanglement metrics, which require supervised labels. While further work has developed unsupervised methods for hyperparameter selection that address several of these issues [Dua+20], at this point there are no known robust methods for learning disentangled representations without further assumptions.
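The simplest instance of this argument can be checked numerically, assuming an isotropic Gaussian prior (Locatello et al.'s construction covers general factorial priors). A 45-degree rotation $z' = Rz$ is a bijection under which every coordinate of $z'$ depends on both coordinates of $z$, yet $z'$ has exactly the same distribution as $z$; pairing $z'$ with the generator $g \circ R^{-1}$ therefore yields the same data distribution as the disentangled model. The sample size and angle below are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)
z = rng.standard_normal((100_000, 2))     # factorial prior p(z) = N(z1; 0, 1) N(z2; 0, 1)

theta = np.pi / 4                         # a 45-degree rotation mixes both coordinates
R = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta),  np.cos(theta)]])
z_rot = z @ R.T                           # z' = R z for every sample

# Every entry of R is nonzero, so each z'_i is a mixture of z_1 and z_2 ...
print(R)
# ... but norms (hence the N(0, I) density) are unchanged, and the sample
# covariance of z' is still close to the identity.
print(np.cov(z_rot.T).round(2))
```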

To address the empirical and theoretical gap in learning disentangled representations, several papers have proposed using additional sources of information in the form of weakly-labeled data to provide guarantees. In theoretical work on nonlinear ICA [RMK21; Khe+20; Hal+21], this information comes in the form of additional observations for each datapoint that are related to the underlying latent variable through an exponential family. Work on causal representation learning has expanded the applicability of these methods and highlighted the settings where such strong assumptions on weakly-labeled data may be attainable [Sch+21c; WJ21; Rei+22]. Alternatively, one can assume access to pairs of observations where the relationship between latent variables is known. Shu et al. [Shu+19b] show that one can provably learn a disentangled representation of data when given access to pairs of data where only one of the latent variables is changed at a time. In real-world datasets, having access to pairs of data like this is challenging, as not all the latent variables of the model may be under the control of the data collector, and covering the full space of settings of the latent variables may be prohibitively expensive. Locatello et al. [Loc+20b] develop this method further, using a heuristic to detect which latent variable has changed; they show that this performs well empirically and, under some restricted settings, may lead to learning disentangled representations. It remains an open question what realistic assumptions can be made about the nature of a dataset that will easily enable learning a disentangled representation, and practical methods for arbitrary datasets remain out of reach.

32.4.2 Information maximization

When learning representations of an input $x$, one desideratum is to preserve as much information about $x$ as possible. Any information we discard cannot be recovered, and if that information is useful for a downstream task then performance will decrease. Early work on understanding biological learning by Linsker [Lin88c] and Bell and Sejnowski [BS95b] argued that information maximization, or InfoMax, is a good learning principle for biological systems, as it gives downstream processing systems access to as much sensory input as possible. However, these biological systems aim to communicate information subject to strong constraints, and these constraints can likely be tuned over time by evolution to sculpt the kinds of representations that are learned.

When applying information maximization to neural networks, we are often able to realize trivial solutions that biological systems may not face: losslessly copying the input. Information theory does not “color” the bits: it does not tell us which bits of an input are more important than others. Simply sending the image losslessly maximizes information, but does not provide a transformation of the input that can improve performance according to the metrics in Section 32.2. Architectural and optimization constraints can guide which bits we learn and which we discard, but we can also leverage additional sources of information, for example labels, to identify which bits to extract.

The information bottleneck method (Section 5.6) aims to learn representations $Z$ of an input $X$ that are predictive of another observed variable $Y$, while being as compressed as possible. The observed variable $Y$ guides the bits learned in the representation $Z$ towards those that are predictive, and penalizes content that does not predict $Y$. We can formalize the information bottleneck as an optimization problem [TPB00]:

$$\max_\theta \; I(Z; Y) - \beta I(X; Z) \qquad (32.21)$$

where $\theta$ parameterizes the encoder that maps $X$ to $Z$, and $\beta \geq 0$ controls the trade-off between predicting $Y$ and compressing $X$.

Estimating mutual information in high dimensions is challenging, but we can form variational bounds on mutual information that are amenable to optimization with modern neural networks, such as the variational information bottleneck (VIB, see Section 5.6.2). Approaches built on VIB have shown improved robustness to adversarial examples and natural variations [FA20].
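To make the trade-off in Eq. (32.21) concrete, the NumPy sketch below evaluates $I(Z;Y) - \beta I(X;Z)$ exactly for a small discrete problem, assuming a stochastic encoder given as a table `enc[x, z]` $= p(z \mid x)$ and the Markov chain $Y - X - Z$. The toy distribution (four equally likely inputs whose label is $y = \lfloor x/2 \rfloor$) and the helper names are hypothetical. Copying the input keeps all label information but pays $\beta \log 4$ nats of compression cost, while the encoder that keeps only the label-relevant bit pays $\beta \log 2$, so for any $\beta > 0$ the compressed encoder scores higher.

```python
import numpy as np

def mutual_info(p_joint):
    """I(A; B) in nats for a joint distribution given as a 2-D array p[a, b]."""
    p_a = p_joint.sum(axis=1, keepdims=True)
    p_b = p_joint.sum(axis=0, keepdims=True)
    nz = p_joint > 0                       # terms with p = 0 contribute 0
    return np.sum(p_joint[nz] * np.log(p_joint[nz] / (p_a @ p_b)[nz]))

def ib_objective(p_xy, enc, beta):
    """I(Z; Y) - beta * I(X; Z) for an encoder enc[x, z] = p(z | x)."""
    p_x = p_xy.sum(axis=1)
    p_xz = p_x[:, None] * enc              # p(x, z) = p(x) p(z | x)
    p_zy = enc.T @ p_xy                    # p(z, y) = sum_x p(z | x) p(x, y)
    return mutual_info(p_zy) - beta * mutual_info(p_xz)

# Four equally likely inputs; the label is y = x // 2.
p_xy = np.array([[0.25, 0.0], [0.25, 0.0], [0.0, 0.25], [0.0, 0.25]])
copy_enc = np.eye(4)                                                     # Z = X
label_enc = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])  # Z = X // 2

beta = 0.5
print(ib_objective(p_xy, copy_enc, beta), ib_objective(p_xy, label_enc, beta))
```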

Unlike information bottleneck methods, many recent approaches motivated by InfoMax have no 
explicit compression objective [Hje+18; BHB19; OLV18b]. They aim to maximize information subject 
to constraints, but without any explicit penalty on the information contained in the representation. 

In spite of the appeal of explaining representation learning with information theory, there are a number of challenges. One of the greatest challenges in applying information theory to understand the content in learned representations is that most learned representations have deterministic encoders, $z = f_\theta(x)$, that map from a continuous input $x$ to a continuous representation $z$. Such mappings can typically preserve infinite information about the input. As mutual information estimators scale poorly with the true mutual information, estimating MI in this setting is difficult and typically results in weak lower bounds.

In the absence of constraints, maximizing information between an input and a learned representation has trivial solutions that do not result in any interesting transformation of the input. For example, the identity mapping $z = x$ maximizes information but does not alter the input. Tschannen et al. [Tsc+19] show that for invertible networks, where the true mutual information between the input and representation is infinite, maximizing estimators of mutual information can still result in meaningful learned representations. This highlights that the geometric dependence and bias of these estimators may have more to do with their success for representation learning than the information itself (as it is infinite throughout training).

There have been several proposed methods for learning stochastic representations that constrain the amount of information in learned representations [Ale+17]. However, these approaches have not yet resulted in improved performance on most downstream tasks. Fischer and Alemi [FA20] show that constraining information can improve robustness on some benchmarks, but scaling up models and datasets with deterministic representations currently yields the best results [Rad+21]. More work is needed to identify whether constraining information can improve learned representations.
33 Interpretability


This chapter was written by Been Kim and Finale Doshi-Velez.


33.1 Introduction 


As machine learning models become increasingly commonplace, there exists increasing pressure to 
ensure that these models’ behaviors align with our values and expectations. It is essential that models 
that automate even mundane tasks (e.g. processing paperwork, flagging potential fraud) do not harm 
their users or society at large. Models with large impacts on health and welfare (e.g. recommending 
treatment, driving autonomously) must not only be safe but often also function collaboratively with 
their users. 

However, determining whether a model is harmful is not easy. Specific performance metrics may 
be too narrowly focused—e.g. just because an autonomous car stays in lane does not mean it is safe. 
Indeed, the narrow objectives used in common decision formalisms such as Bayesian decision theory 
(Main Section 34.1), multi-step decision problems (Main Chapter 34), and reinforcement learning 
(Main Chapter 35) can often be easily exploited (e.g., reward hacking). Incomplete sets of metrics 
also result in models that learn shortcuts that do not generalize to new situations (e.g., [Gei+20b]). 
Even when one knows the desired metrics, those metrics can be hard to estimate with limited data 
or a distribution shift (Main Chapter 19). Finally, normative concepts, such as fairness, may be 
impossible to fully formalize. As a result, not only may unexpected and irreversible harms occur (e.g., an adverse drug reaction), but more subtle harms may go unnoticed until sufficient reporting data accrues [Amo+16].

Interpretability allows human experts to inspect a model. Alongside traditional statistical measures 
of performance, this human inspection can help expose issues and thus mitigate potential harms. 
Exposing the workings of a model can also help people identify ways to incorporate information they 
have into a final decision. More broadly, even when we are satisfied with a model’s performance, we may be interested in understanding why it works to gain scientific and operational insights. For example, one might gain insights into language structure by asking why a language model performs so well; understanding why patient data cluster along particular axes may result in a better understanding of disease and the common treatment pathways. Ultimately, interpretation helps humans communicate better with machines so that, together, they accomplish tasks better.

In this chapter, we lay out the role of and terminology for interpretable ML before introducing interpretability methods, their properties, and how to evaluate them.
33.1.1 The Role of Interpretability: Unknowns and Under-Specifications


As noted above, ensuring that models behave as desired is challenging. In some cases, the desired 
behavior can be guaranteed by design, such as certain notions of privacy via differentially-private 
learning algorithms or some chosen mathematical metric of fairness. In other cases, tracking various 
metrics, such as adverse events or subgroup error rates, may be the appropriate and sufficient 
way to identify concerns. Much of this textbook deals with uncertainty quantification (basic models in Main Chapter 3, Bayesian neural networks in Main Chapter 17, Gaussian processes in Main Chapter 18). When well-calibrated uncertainties can be computed, they may provide sufficient warning that a model’s output may be suspect.

However, in many cases, the ultimate goal may be fundamentally impossible to fully specify and thus formalize. For example, Main Section 20.4.8 discusses the challenge of evaluating the quality of samples from a generative model. In such cases, human inspection of the machine learning model may be necessary. Below we describe several examples.
Blindspot Discovery. Inspection may reveal blindspots in our modeling, objective, or data [Bad+18; Zec+18b; Gur+18]. For example, suppose a company has trained a machine learning system for credit scoring. The model was trained on a relatively affluent, middle-aged population, and now the company is considering using it on a different, college-aged population. Suppose that inspection of the model reveals that it relies heavily on the applicant’s mortgage payments. Not only might this suggest that the model might not transfer well to the college population, but it might encourage us to check for bias in the existing application, because we know historical biases have prevented certain populations from achieving home ownership (something that a purely quantitative definition of fairness may not be able to recognize). Indeed, the most common application of interpretability in industry settings is for engineers to debug models and make deployment decisions [Pail].

Novel Insights. Inspection may catalyze the discovery of novel insights. For example, suppose an algorithm determines that surgical procedures fall into three clusters. The surgeries in one of the clusters seem to consistently take longer than expected. A human inspecting these clusters may determine that a common factor in the cluster with the delays is that those surgeries occur in a different part of the hospital, a feature not in the original dataset. This insight may result in ideas to improve on-time surgery performance.

Human+ML Teaming. Inspection may empower effective human+ML interaction and teaming. For example, suppose an anxiety treatment recommendation algorithm reveals that the patient’s comorbid insomnia constrained its recommendations. If the patient reports that they no longer have trouble sleeping, the algorithm could be re-run with that constraint removed to get additional treatment options. More broadly, inspection can reveal places where people may wish to adjust the model, such as correcting an incorrect input or assumption. It can also help people use only part of a model in their own decision-making, such as relying on a model’s computation of which treatments are unsafe but not on its assessment of which treatments are best. In these ways, the human+ML team may be able to produce better combined performance than either alone (e.g., [Ame+19; Kam16]).


Individual-Level Recourse. Inspection can help determine whether a specific harm or error happened in a specific context. For example, if a loan applicant knows what features were used to deny them a loan, they have a starting point to argue that an error might have been made, or that the algorithm denied them unjustly. For this reason, inspectability is sometimes a legal requirement [Zer+19; GF17; Cou16].


As we look at the examples above, we see that one common element is that interpretability is needed when we need to combine human insights with the ML algorithm to achieve the ultimate goal. However, looking at the list above also emphasizes that beyond this very basic commonality, each application and task represents very different needs. A scientist seeking to glean insights from a clustering of molecules may be interested in global patterns (such as that all molecules with certain loop structures are more stable) and be willing to spend hours puzzling over a model’s outputs. In contrast, a clinician seeking to make a specific treatment decision may only care about aspects of the model relevant to the specific patient; they must also reach their decision within the time pressure of an office visit. This brings us to our most important point: the best form of explanation depends on the context; interpretability is a means to an end.


33.1.2 Terminology and Framework 


In broad strokes, “to interpret means to explain or present in understandable terms,” [Mer] [to a 
human]. Understanding, in turn, involves an alignment of mental models. In interpretable machine 
learning, that alignment is between what (perhaps part of) the machine learning model is doing and 
what the user thinks the model is doing. 

As a result, the interpretable machine learning ecosystem includes not only standard machine learning (e.g., a prediction task) but also what information is provided to the human user, in what context, and the user’s ultimate goal. The broader socio-technical system, that is, the collection of interactions between human, social, organizational, and technical (hardware and software) factors, cannot be ignored [Sel+19]. The goal of interpretable machine learning is to help a user do their task, with their cognitive strengths and weaknesses, with their focus and distractions [Mil19]. Below we define the key terms of this expanded ecosystem and describe how they relate to each other. Before continuing, however, we note that the field of interpretable machine learning is relatively new, and a consensus around terminology is still evolving. Thus, it is always important to define terms.

Two core social or human-factors elements in interpretable machine learning are the context 
and the end-task. 


Context. We use the term context to describe the setting in which an interpretable machine 
learning system will be used. Who is the user? What information do they have? What constraints 
are present on their time, cognition, or attention? We will use the terms context and application 
interchangeably [Sta]. 


End-task. We use the term end-task to refer to the user’s ultimate goal. What are they ultimately 
trying to achieve? We will use the terms end-task and downstream tasks interchangeably. 


Three core technical elements in interpretable machine learning are the method, the metrics, and 
the properties of the methods. 


1. We emphasize that interpretability is different from manipulation or persuasion, where the goal is to intentionally 
deceive or convince users of a predetermined choice. 
Method. How does interpretability happen? We use the term explanation to mean
the output provided by the method to the user: interpretable machine learning methods provide 
explanations to the users. If the explanation is the model itself, we call the method inherently 
interpretable or interpretable by design. In other cases, the model may be too complex for a human 
to inspect it in its entirety: perhaps it is a large neural network that no human could expect to 
comprehend; perhaps it is a medium-sized decision tree that could be inspected if one had twenty 
minutes but not if one needs to make a decision in two minutes. In such cases, the explanation may 
be a partial view of the model, one that is ideally suited for performing the end-task in the given
context. Finally, we note that even inherently interpretable models do not reveal everything: one
might be able to fully inspect the function (e.g., a two-node decision tree) but not know what data it
was trained on or which data points were most influential.
Metrics. How is the interpretability method evaluated? Evaluation is one of the most essential
and challenging aspects of interpretable machine learning, because we are interested in the end-task
performance of the human when the explanation is provided. We call this the downstream performance.
Just as different goals in ML require different metrics (e.g., positive predictive value, log likelihood,
AUC), different contexts and end-tasks will have different metrics. For example, the model with
the best predictive performance (e.g., log likelihood loss) may not be the model that results in the
best downstream performance.


Properties. What characteristics does the explanation have in relation to the model, the context,
and the end-tasks? Different contexts and different end-tasks might require different properties.
For example, suppose that an explanation is being used to identify ways in which a denied loan
applicant could improve their application. Then, it may be important that the explanation only include
factors that, if changed, would change the outcome. In contrast, suppose the explanation is being used
to determine if the denial was fair. Then, it may be important that the explanation does not leave out
any relevant factors. In this way, properties serve as the glue between interpretability methods, contexts,
and end-tasks: properties allow us to specify and quantify aspects of the explanation relevant to our
ultimate end-task goals. Then we can make sure that our interpretability method has those properties.

How they all relate. Formulating an interpretable machine learning problem generally starts by
specifying the context and the end-task. Together, the context and the end-task imply which metrics
are appropriate for evaluating the downstream performance on the end-task and suggest which properties
will be important in the explanation. Meanwhile, the context also determines the data and training
metric for the ML model. The appropriate choice of explanation method will depend on the model
and the properties desired, and it will be evaluated with respect to the end-task metric to determine the
downstream performance. Figure 33.1 shows these relationships.

Interpretable machine learning involves many challenges, from computing explanations, optimizing
interpretable models, and creating explanations with certain properties to understanding
the associated human factors. That said, the grand challenge is to (1) understand what properties
are needed for different contexts and end-tasks and (2) identify and create interpretable machine
learning methods that have those properties.

Figure 33.1: The Interpretable Machine Learning ecosystem. While standard machine learning can often
abstract away elements of the context and consider only the process of learning models given a data distribution
and a loss, interpretable machine learning is inextricably tied to a socio-technical context.

A Simple Example. In the following sections, we will expand upon methods for interpretability,
metrics for evaluation, and types of properties. First, however, we provide a simple example connecting
all of the concepts we discussed above.

Suppose our context is that we have a lemonade stand, and our end-task is to understand when the
stand is most successful in order to decide on which days it is worth setting up. (We have heard
that machine learning models sometimes latch on to incorrect mechanisms, and we want to check the
model before using it to inform our business strategy.) Our metric for the downstream performance
is whether we correctly determine if the model can be trusted; this could be quantified as the amount
of profit that we make by opening on busy days and being closed on quiet days.

To train our model, we collect data on two input features—the average temperature for the day 
(measured in degrees Fahrenheit) and the cleanliness of the sidewalk near our stand (measured as a 
proportion of the sidewalk that is free of litter, between 0 and 1)—and the output feature of whether 
the day was profitable. Two models seem to fit the data approximately equally well: 


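Model 1:
$$p(\text{profit}) = 0.9 \cdot \mathbb{I}(\text{temperature} > 75) + 0.1 \cdot \text{howCleanSidewalk} \tag{33.1}$$
Model 2:
$$p(\text{profit}) = \sigma\big(0.9\,(\text{temperature} - 75)/\text{maxTemperature} + 0.1\,(\text{howCleanSidewalk} - 0.5)\big) \tag{33.2}$$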


These models are illustrated in Figure 33.2. Both of these models are inherently interpretable
in the sense that they are easy to inspect and understand. While we were not explicitly seeking
causal models (for that, see Main Chapter 36), both rely mostly on the temperature, which seems
reasonable.

Figure 33.2: Models described in the simple example. Both of these models have the same qualitative
characteristics, but different explanation methods will describe these models quite differently, potentially
causing confusion.


For the sake of this example, suppose that the models above were black boxes, and we could only
request partial views of them. We decide to ask each model for its most important features. Let us see
what happens when we consider two different ways of computing important features.²

Our first (feature-based) explanation method computes, for each training point, whether individually
changing each feature to its max or min value changes the prediction. Important features are those
that change the prediction for many training points. One can think of this explanation method as a
variant of computing feature importance based on how important a feature is to the coalition that
produces the prediction. In this case, both models will report temperature to be the dominating
feature. If we used this explanation to vet our models, we would correctly conclude that both models
use the features in a sensible way (and thus may be worth considering for deciding when to open our
lemonade stand).
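The first explanation method can be sketched in plain Python as follows. This is a minimal sketch, not code from the book's notebooks, and it rests on assumptions not stated in the text: maxTemperature = 100, a hypothetical training set on a grid of temperatures from 50 to 100 °F and cleanliness values from 0 to 1, and "the prediction" meaning the label p(profit) > 0.5. The names `model_1`, `model_2`, `DATA`, and `perturbation_importance` are illustrative.

```python
import math

MAX_TEMPERATURE = 100.0  # assumption: the text does not give maxTemperature

def model_1(temperature, clean):
    # Eq. (33.1): a step in temperature plus a small linear cleanliness term.
    return 0.9 * (temperature > 75) + 0.1 * clean

def model_2(temperature, clean):
    # Eq. (33.2): a logistic sigmoid of scaled temperature and cleanliness.
    z = 0.9 * (temperature - 75) / MAX_TEMPERATURE + 0.1 * (clean - 0.5)
    return 1.0 / (1.0 + math.exp(-z))

# Hypothetical training inputs: temperatures 50..100 F, cleanliness 0..1.
DATA = [(t, c) for t in range(50, 101, 5) for c in (0.0, 0.25, 0.5, 0.75, 1.0)]

def perturbation_importance(model, data):
    """For each feature, count the training points whose predicted label
    (p > 0.5) flips when that feature alone is set to its min or max."""
    lows = [min(x[j] for x in data) for j in range(2)]
    highs = [max(x[j] for x in data) for j in range(2)]
    counts = [0, 0]
    for x in data:
        label = model(*x) > 0.5
        for j in range(2):
            for extreme in (lows[j], highs[j]):
                perturbed = list(x)
                perturbed[j] = extreme
                if (model(*perturbed) > 0.5) != label:
                    counts[j] += 1
                    break  # one flip is enough for this feature at this point
    return dict(zip(("temperature", "cleanliness"), counts))

for name, model in (("Model 1", model_1), ("Model 2", model_2)):
    print(name, perturbation_importance(model, DATA))
```

On this hypothetical grid, moving temperature to an extreme flips the prediction at all 55 training points under both models, whereas moving cleanliness flips none of Model 1's predictions and only 15 of Model 2's (those within a few degrees of 75 °F).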

Our second (feature-based) explanation method computes the magnitude of the derivatives of the
output with respect to the inputs for each training point. Important features are those that have a
large sum of absolute derivatives across the training set. One can think of this explanation method as
a variant of computing feature importances based on local geometry. In this case, Model 2 will still
report that temperature has higher derivatives. However, Model 1, which has very similar behavior
to Model 2, will report that sidewalk cleanliness is the dominating feature because the derivative
with respect to temperature is zero nearly everywhere. If we used this explanation to vet our models,
we would incorrectly conclude that Model 1 relies on an unimportant feature (and that Models 1 and
2 rely on different features).
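The second method can be sketched in the same setting, again assuming maxTemperature = 100 and the same hypothetical grid. Rather than differentiating numerically (a finite difference straddling 75 °F would be dominated by the step), the sketch writes the derivatives out analytically and takes the derivative of Model 1's step to be 0, matching the "zero nearly everywhere" behavior described above. It also multiplies each derivative by the feature's range in the data (50 °F for temperature, 1 for cleanliness); this scaling is our assumption, made so that degrees and proportions are comparable. The function names are illustrative.

```python
import math

MAX_TEMPERATURE = 100.0  # assumption: the text does not give maxTemperature
# Assumed feature ranges in the hypothetical training data, used to put
# degrees Fahrenheit and proportions on a comparable scale.
FEATURE_RANGES = (50.0, 1.0)  # (temperature: 50..100 F, cleanliness: 0..1)

def grad_model_1(temperature, clean):
    # Eq. (33.1): the step 0.9 * 1[temperature > 75] is flat away from 75 F,
    # so its derivative is taken to be 0; the cleanliness term has slope 0.1.
    return (0.0, 0.1)

def grad_model_2(temperature, clean):
    # Eq. (33.2): chain rule through the logistic sigmoid.
    z = 0.9 * (temperature - 75) / MAX_TEMPERATURE + 0.1 * (clean - 0.5)
    s = 1.0 / (1.0 + math.exp(-z))
    ds = s * (1.0 - s)  # derivative of the sigmoid at z
    return (ds * 0.9 / MAX_TEMPERATURE, ds * 0.1)

# Same hypothetical training inputs as in the perturbation sketch.
DATA = [(t, c) for t in range(50, 101, 5) for c in (0.0, 0.25, 0.5, 0.75, 1.0)]

def gradient_importance(grad_fn, data):
    """Sum over the training set of |d p(profit) / d feature|, with each
    derivative multiplied by that feature's range."""
    totals = [0.0, 0.0]
    for t, c in data:
        for j, g in enumerate(grad_fn(t, c)):
            totals[j] += abs(g) * FEATURE_RANGES[j]
    return dict(zip(("temperature", "cleanliness"), totals))

for name, grad_fn in (("Model 1", grad_model_1), ("Model 2", grad_model_2)):
    print(name, gradient_importance(grad_fn, DATA))
```

With this scaling, Model 1 gives temperature a score of exactly 0 (and cleanliness about 5.5), while Model 2's temperature score is 4.5 times its cleanliness score. Without the range scaling, Model 2's per-unit derivatives (0.009 per degree vs. 0.1 per unit of cleanliness, times the same sigmoid slope) would favor cleanliness as well: a reminder that seemingly minor computational choices change an explanation's properties.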

What happened? The different explanations had different properties. The first explanation had
the property of fidelity with respect to identifying features that, if changed, will affect the prediction,
whereas the second explanation had the property of correctly identifying features with the largest
local derivatives. In this example, the first property is more important for the task of determining
whether our model can be used to determine our business strategy.³

2. In the remainder of the chapter we will describe many other ways of creating and computing explanations.

3. Other properties may be important for this end-task. This example is just the simplest one.
33.2 Methods for Interpretable Machine Learning


There exist many methods for interpretable machine learning. Each method has different properties 
and the right choice will depend on context and end-tasks. As we noted in Section 33.1.2, the grand 
challenge in interpretable machine learning is determining what kinds of properties are needed for 
what contexts, and what explanation methods satisfy those properties. Thus, one should consider 
this section a high-level snapshot of a rapidly changing set of methods that one may want to
choose from for interpretable machine learning.


33.2.1 Inherently Interpretable Models: The Model is its Explanation 


We consider certain classes of models inherently interpretable: a person can inspect the full model and, 
with reasonable effort, understand how inputs become outputs.⁴ Specifically, we define inherently
interpretable models as those that require no additional process or proxies in order for them to be
used as the explanation for the end-task. For example, suppose a model consists of a relatively small set
of rules. Then, those rules might suffice as the explanation for end-tasks that do not involve extreme 
time pressure. (Note: in this way, a model might be inherently interpretable for one end-task and 
not another.) 

Inherently interpretable models fall into two main categories: sparse (or otherwise compact) models 
and logic-based models. 


Compact or sparse feature-based models include various kinds of sparse regressions. Earlier in 
this textbook, we discussed simple models such as HMMs (Main Section 29.2), generalized linear 
models (Main Chapter 15), and various latent variable models (Main Chapter 28). When small 
enough, these are generally inherently interpretable. More advanced models in this category include 
super-sparse linear integer models and other checklist models [DMV15; UTR14]. 

Although sparse models are functionally simple, sparsity has its drawbacks when it comes to inspection
and interpretation. For example, if a model picks only one of several correlated features, it may be harder
to identify what signal is actually driving the prediction. A model might also assign correlated features
different signs that ultimately cancel, rendering an interpretation of the individual weights meaningless.
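The following toy computation, on hypothetical data, illustrates the correlated-feature problem: when one feature is an exact copy of another, two different sparse weight vectors and a weight vector with large, opposite-signed weights all make identical predictions, so none of them can be read as the importance of either feature.

```python
# Hypothetical data in which feature x2 is an exact copy of feature x1.
inputs = [(x, x) for x in (-2.0, -1.0, 0.0, 1.0, 2.0)]

def linear(w, x):
    # A linear model without intercept: w . x
    return sum(wi * xi for wi, xi in zip(w, x))

w_sparse_a = (1.0, 0.0)   # a sparse fit that keeps only x1
w_sparse_b = (0.0, 1.0)   # an equally good sparse fit that keeps only x2
w_cancel = (3.0, -2.0)    # large weights of opposite sign that cancel

predictions = {w: [linear(w, x) for x in inputs]
               for w in (w_sparse_a, w_sparse_b, w_cancel)}
print(predictions[w_sparse_a] == predictions[w_sparse_b] == predictions[w_cancel])  # True
```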

To handle these issues, as well as to express more complex functions, some models in this category 
impose hierarchical or modular structures in which each component is still relatively compact and 
can be inspected. Examples include topic models (e.g. [BNJ03b]), (small) discrete time series models
(e.g. [FHDV20]), generalized additive models (e.g. [HT17]) and monotonicity-enforced models (e.g., 
[Gup+16]). 

4. There may be other questions, such as how training data influenced the model, which may still require additional
computation or information.

Logic-based models use logical statements as their basis. Models in this category include decision
trees [Bre+17], decision lists [Riv87; WR15; Let+15a; Ang+18; DMV15], decision tables, decision
sets [Hau+10; Wan+17a; LBL16; Mal+17; Bén+21], and logic programming [MDR94]. A broader
discussion, as well as a survey of user studies on these methods, can be found in [Fre14]. Logic-based
models easily model non-linear relationships but can have trouble modeling continuous relationships
between the input and output (e.g., a step-wise constant function is much easier for them to express
than a linear function). Like the compact models, hierarchies and other forms of modularity can be
used to extend the expressivity of the model while keeping it human-inspectable. For example, one can
define a new concept as a formula based on some literals, and then use the new concept to build more
complex rules.
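As an illustration, a small decision list that first defines a concept from literals and then uses it in a rule might look like the sketch below. The features (temperature and sidewalk cleanliness) borrow from the lemonade example, but the concept `hot_and_clean`, the thresholds, and the rules are all hypothetical. Returning the index of the rule that fired makes the explanation for each individual prediction explicit.

```python
# Hypothetical concept defined as a conjunction of literals over raw features.
def hot_and_clean(day):
    return day["temperature"] > 75 and day["clean"] >= 0.5

# A hypothetical ordered decision list: the first rule whose condition holds
# determines the prediction; the last rule is the default.
DECISION_LIST = [
    (hot_and_clean, "open"),
    (lambda day: day["temperature"] > 85, "open"),
    (lambda day: True, "closed"),
]

def predict(day, rules=DECISION_LIST):
    """Return (prediction, index of the rule that fired)."""
    for index, (condition, outcome) in enumerate(rules):
        if condition(day):
            return outcome, index
    raise ValueError("decision list has no default rule")

print(predict({"temperature": 80, "clean": 0.9}))  # ('open', 0)
print(predict({"temperature": 90, "clean": 0.1}))  # ('open', 1)
print(predict({"temperature": 70, "clean": 1.0}))  # ('closed', 2)
```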

When using inherently interpretable models, three key decisions need to be made: the choice of 
the model class, how to manage uninterpretable input features, and the choice of optimization method. 


Decision: Model Class. Since the model is its own explanation, the decision on the model class
becomes the decision on the form of explanation. Thus, we need to consider both whether the model
class is a good choice for modeling the data and whether it provides the necessary information to the
user. For example, if one chooses to use a linear model to describe one’s data, then it is important
that the intended users can understand or manipulate the linear model. Moreover, if the fitting
process produces a model that is too large to be human-inspectable, then it is no longer inherently
interpretable, even if it belongs to one of the model classes described above.
Decision: Optimization Methods for Training. The kinds of model classes that are typically
inherently interpretable often require more advanced methods for optimization: compact, sparse, and
logic-based models all involve learning discrete parameters. Fortunately, there is a long and
continuing history of research on optimizing such models, including direct optimization programs,
relaxation and rounding techniques, and search-based approaches. Another popular optimization
approach is via distillation or mimics: one first trains a complex model (e.g., a neural network) and
then uses the complex model’s output to train a simpler model to mimic the more complex model.
The more complex model is then discarded. These optimization techniques are beyond the scope of
this chapter but are covered in optimization textbooks.
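A minimal sketch of distillation, under strong simplifying assumptions: the "complex" teacher is a hypothetical one-dimensional sigmoid, the student is a single threshold rule, and the student is fit by exhaustive search to agree with the teacher's thresholded outputs on a grid of unlabeled inputs. Practical distillation often uses richer students and the teacher's soft outputs; this only shows the shape of the procedure. The names `teacher`, `INPUTS`, and `distill_stump` are illustrative.

```python
import math

def teacher(x):
    # Hypothetical stand-in for a complex model: a steep sigmoid in x.
    return 1.0 / (1.0 + math.exp(-10.0 * (x - 0.4)))

# Unlabeled inputs on which the student is asked to mimic the teacher.
INPUTS = [i / 100 for i in range(101)]

def distill_stump(teacher_fn, inputs):
    """Fit the rule 'predict 1 iff x > theta' to the teacher's thresholded
    outputs. Returns (theta, fidelity), where fidelity is the fraction of
    inputs on which student and teacher agree."""
    targets = [teacher_fn(x) > 0.5 for x in inputs]
    best_theta, best_fid = None, -1.0
    for theta in inputs:  # candidate thresholds
        agree = sum((x > theta) == t for x, t in zip(inputs, targets))
        fid = agree / len(inputs)
        if fid > best_fid:
            best_theta, best_fid = theta, fid
    return best_theta, best_fid

theta, fidelity = distill_stump(teacher, INPUTS)
# The teacher can now be discarded; the rule "x > theta" is the model.
print(f"student rule: predict 1 iff x > {theta}; fidelity = {fidelity:.2f}")
```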


Decision: How to Manage Uninterpretable Input Features. Sometimes the input features
themselves are not directly interpretable (e.g., the pixels of an image or the individual amplitudes of a
spectrogram); only collections of inputs have semantic meaning for human users. This situation challenges
our ability to create not only inherently interpretable models but also explanations in general.

To address this issue, more advanced methods attempt to add a “concept” layer that first converts
the uninterpretable raw input to a set of human-interpretable concept features. Next, these concepts
are mapped to the model’s output [Kim+18a; Bau+17]. This second stage can still be inherently
interpretable. For example, one could first map a spectrogram pattern to a semantically meaningful
sound (e.g., people chatting, cups clinking) and then map those sounds to a scene classification (e.g.,
in a cafe). While promising, one must ensure that the initial data-to-concept mapping truly maps
the raw data to concepts as the user understands them, no more and no less. Creating machine-derived
concepts and validating that they correspond to semantically meaningful human concepts remains an
open research challenge.
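The two-stage structure can be sketched as follows. Everything here is hypothetical: the raw features (`speech_band_energy`, `transient_count`), the concept detectors, and their thresholds stand in for learned components whose correspondence to human concepts would still need to be validated, which is exactly the open challenge noted above.

```python
# Stage 1 (hypothetical concept detectors): raw input -> named, human-meaningful
# concepts. In practice these would be learned models whose agreement with the
# user's notion of each concept must be validated.
def detect_concepts(raw):
    return {
        "people_chatting": raw["speech_band_energy"] > 0.6,
        "cups_clinking": raw["transient_count"] >= 3,
    }

# Stage 2 (inherently interpretable): a short rule over the concepts only.
def classify_scene(concepts):
    if concepts["people_chatting"] and concepts["cups_clinking"]:
        return "in a cafe"
    return "other"

def predict_with_explanation(raw):
    """Return the scene label together with the concept values it was based on."""
    concepts = detect_concepts(raw)
    return classify_scene(concepts), concepts

print(predict_with_explanation({"speech_band_energy": 0.8, "transient_count": 5}))
```

Because the second stage only sees the concepts, the returned concept values fully explain the scene label, but only to the extent that the detectors capture what users mean by "people chatting" and "cups clinking".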

When might we want to consider inherently interpretable models? When not? Inherently
interpretable models have several advantages over other approaches. When the model is its
explanation, one need not worry about whether the explanation is faithful to the model or whether it
provides the right partial view of the model for the intended task. Relatedly, if a person vets the
model and finds nothing amiss, they might feel more confident about avoiding surprises. For all these
reasons, inherently interpretable models have been advocated for in high-stakes scenarios, as well as
generally being the first approach to try [Rud19].

That said, these models do have their drawbacks. They typically require more specialized
optimization approaches. With appropriate optimization, inherently interpretable models can
often match the performance of more complex models, but there are domains—in particular, images, 
waveforms, and language—in which deep models or other more complex models typically give 
significantly higher performance. Trying to fit complex behavior with a simple function may result 
not only in high bias in the trained model but also invite people to (incorrectly) rationalize why that 
highly biased model is sensible [Lun+20]. In an industry setting, an attempt to migrate away from a
legacy, uninterpretable, business-critical model that has been tuned over decades may run into
resistance.

Lastly, we note that being inherently interpretable does not guard a model against all
kinds of surprises: as noted in Section 33.1, interpretability is just one form of validation mechanism. 
For example, if the data distribution shifts, then one may observe unexpected model behavior. 


33.2.2 Semi-Inherently Interpretable Models: Example-Based Methods


Example-based models use examples as the basis for their predictions. For example, an example-based
classifier might predict the class of a new input by first identifying the outputs for similar
instances in the training set and then taking a vote. K-nearest neighbors is one of the best-known
models in this class. Extensions include methods to identify exemplars for predicted classes
and clusters (e.g. [KRS14; KL17b; JL15a; FD07b; RT16; Arn+10]), to generate exemplars (e.g.
[Li+17d]), to define similarities between instances via sophisticated embeddings (e.g. [PM18a]), and
to first decompose an instance into parts and then find neighbors or exemplars between the parts
(e.g. [Che+18b]). Like logic-based models, example-based models can describe highly non-linear
boundaries.
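A minimal k-nearest-neighbor classifier that returns its neighbors alongside the prediction shows why individual decisions look inspectable. The sketch uses plain Euclidean distance (`math.dist`, Python 3.8+) on hypothetical two-dimensional data; with a learned embedding in place of Euclidean distance, the notion of "similar" would be much harder to convey to a user.

```python
import math
from collections import Counter

# Hypothetical labeled training points: ((feature_1, feature_2), label).
TRAIN = [
    ((0.0, 0.0), "a"), ((0.0, 1.0), "a"), ((1.0, 0.0), "a"),
    ((5.0, 5.0), "b"), ((5.0, 6.0), "b"), ((6.0, 5.0), "b"),
]

def knn_predict(query, train, k=3):
    """Majority vote over the k nearest training points (Euclidean distance).
    The returned neighbors serve as the example-based explanation."""
    neighbors = sorted(train, key=lambda item: math.dist(query, item[0]))[:k]
    votes = Counter(label for _, label in neighbors)
    return votes.most_common(1)[0][0], neighbors

label, neighbors = knn_predict((0.2, 0.2), TRAIN)
print(label, neighbors)
```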

On one hand, individual decisions made by example-based methods seem fully inspectable: one can 
provide the user with exactly the training instances (including their output labels) that were used to 
classify a particular input in a particular way. However, it may be difficult to convey a potentially 
complex distance metric used to define “similarity.” As a result, the user may incorrectly infer what 
features or patterns made examples similar. It is also often difficult to convey the intuition behind 
the global decision boundary using examples. 


33.2.3 Post-hoc or Joint training: The Explanation gives a Partial View of the 
Model 


Inherently interpretable models are a subset of all machine learning models, and circumstances may 
require working with a model that is not inherently interpretable. As noted above, large neural 
models (Main Chapter 16) have demonstrated large performance benefits for certain kinds of data 
(e.g. images, waveforms, and text); one might have to work with a legacy, business-critical model that
has been tuned for decades; one might be trying to understand a system of interconnected models. 

In these cases, the view that the explanation gives into the model will necessarily be partial: the 
explanation may only be an approximation of the model or be otherwise incomplete. Thus, more 
decisions have to be made. Below, we split these decisions into two broad categories—what the 
explanation should consist of to best serve the context and how the explanation should be computed 
from the trained model. More detail on the abilities and limitations of these partial explanation 
methods can be found in [Sla+20; Yeh+19a; Kin+19; Ade+20a]. 
33.2.3.1 What does the explanation consist of?


One set of decisions centers on the content of the explanation and what properties it should have.
One choice is the form: Should the explanation be a list of important features? The top interactions? 
One must also choose the scope of the explanation: Is it trying to explain the whole model (global)? 
The model’s behavior near a specific input (local)? Something else? Determining what properties 
the explanation must have will help answer these and other questions. We expand on each of these 
points below; the right choice, as always, will depend on the user—whom the explanation is for—and 
their end-task.

Decision: Form of the Explanation. In the case of inherently interpretable models, the model
class used to fit the data was also the explanation. Now, the model class and the explanation are two
different entities. For example, the model could be a deep network and the explanation a decision
tree.

Works in interpretable machine learning have used a large variety of forms of explanations. The
form could be a list of “important” input features [RSG16b; Lun+20; STY17; Smi+17; FV17] or
“important” concepts [Kim+18a; Bau+20; Bau+18]. Or it could be a simpler model that approximates
the complex model (e.g., a local linear approximation, an approximating rule set) [FH17; BKB17;
Aga+21b; Yin+19c]. Another choice could be a set of similar or prototypical examples [KRS14;
AA18; Li+17d; JL15a; JL15b; Arn+10]. Finally, one can choose whether the explanation should
include a contrast against an alternative (also sometimes described as a counterfactual explanation)
[Goy+19; WMR18; Kar+20a] or consist of influential examples [KL17b].

Different forms of explanations will facilitate different tasks in different contexts. For example, a
contrastive explanation of why treatment A is better than treatment B may help a clinician decide
between treatments A and B. However, a contrast between treatments A and B may not be useful
when comparing treatments A and C. Given the large number of choices, literature on how people
communicate in the desired context can often provide some guidance. For example, if the domain
involves making quick, high-stakes decisions, one might turn to how trauma nurses and firefighters
explain their decisions (known as recognition-primed decision making [Kle17]).

Decision: Scope of the Explanation: Global or Local. Another major decision regarding
the parameters of the explanation is its scope.

Local explanation: In some cases, we may only need to interrogate an existing model about
a specific decision. For example, why was this image predicted as a bird? Why was this patient
predicted to have diabetes? Local explanations can help determine whether a consequential decision
was made incorrectly, or what could have been done differently to produce a different outcome (i.e.,
they can provide a recourse).

Local explanations can take many forms. A family of methods called saliency maps or attribution
maps [STY17; Smi+17; ZF14; Sel+17; Erh+09; Spr+14; Shr+16] estimate the importance of each
input dimension (e.g., via first-order derivatives with respect to the input). More generally, one might
locally fit a simpler model in the neighborhood of the input of interest (e.g., LIME [RSG16b]). A local
explanation may also consist of representative examples, including identifying which training points
were most influential for a particular decision [KL17b] or identifying nearby data points with different
predictions [MRW19; LHR20; Kar+20a].
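A LIME-style local surrogate can be sketched as a weighted least-squares fit to the model's outputs on random perturbations around the input of interest. This is a simplified sketch, not the LIME library: it assumes Gaussian perturbations with scale 0.1, a Gaussian proximity kernel, a hypothetical black-box function, and no feature selection or interpretable re-representation. Because features are centered at the input, the intercept approximates the model's output there and the slopes approximate its local gradient, which is exactly why such an explanation is only valid near the chosen input.

```python
import math
import random

def black_box(x1, x2):
    # Hypothetical nonlinear model standing in for the model to be explained.
    return math.sin(x1) + x2 ** 2

def solve(A, b):
    """Solve the small linear system A w = b by Gaussian elimination
    with partial pivoting."""
    n = len(b)
    M = [list(row) + [bi] for row, bi in zip(A, b)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[pivot] = M[pivot], M[col]
        for r in range(col + 1, n):
            factor = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= factor * M[col][c]
    w = [0.0] * n
    for r in reversed(range(n)):
        w[r] = (M[r][n] - sum(M[r][c] * w[c] for c in range(r + 1, n))) / M[r][r]
    return w

def local_linear_explanation(f, x0, n_samples=500, scale=0.1, seed=0):
    """Fit a proximity-weighted linear model to f on Gaussian perturbations
    around x0. Returns [intercept, slope_1, ..., slope_d]; features are
    centered at x0, so the intercept approximates f(x0)."""
    rng = random.Random(seed)
    rows, targets, weights = [], [], []
    for _ in range(n_samples):
        z = [xi + rng.gauss(0.0, scale) for xi in x0]
        offsets = [zi - xi for zi, xi in zip(z, x0)]
        rows.append([1.0] + offsets)
        targets.append(f(*z))
        dist2 = sum(o * o for o in offsets)
        weights.append(math.exp(-dist2 / (2 * scale ** 2)))  # proximity kernel
    # Weighted normal equations: (X^T W X) beta = X^T W y.
    d = len(rows[0])
    A = [[sum(w * r[i] * r[j] for w, r in zip(weights, rows)) for j in range(d)]
         for i in range(d)]
    b = [sum(w * r[i] * y for w, r, y in zip(weights, rows, targets)) for i in range(d)]
    return solve(A, b)

coef = local_linear_explanation(black_box, [0.0, 1.0])
# Near (0, 1) the true gradient is (cos 0, 2 * 1) = (1, 2).
print([round(c, 2) for c in coef])
```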

All local explanation methods are partial views because they only attempt to explain the model
around an input of interest. A key risk is that the user may overgeneralize the explanation to a wider
region than the one to which it applies. They may also construct an incorrect mental model of the model
from a few local explanations.


Global explanation: In other cases, we may desire insight into the model as a whole or for a 
collection of data points (e.g., all inputs predicted to one class). For example, suppose that our 
end-task is to decide whether to deploy a model. Then, we care about understanding the entire 
model. 

Global explanations can take many forms. One choice is to fit a simpler model (e.g. an inherently 
interpretable model) that approximates the original model (e.g. [HVD14]). One can also identify 
concepts or features that affect decisions across many inputs (e.g. [Kim+18b]). Another approach 
is to provide a carefully chosen set of representative examples [Yeh+18]. These examples might be
chosen to be somehow characteristic of, or providing coverage of, a class (e.g. [AA18]), to draw 
attention to decision boundaries (e.g. [Zhu+18]), or to identify inputs particularly influential in 
training the model. 

Unless the model is inherently interpretable, it is important to remember that a global explanation
is still a partial view. To make a complex model accessible to the user, the global explanation will
need to leave some things out.

Decision: Determining what Properties the Context Needs. Different forms of explanations
have different levels of expressivity. For example, an explanation listing important features,
or fitting a local linear model around a particular input, does not expose interactions—but fitting a 
local decision tree would. For each form, there will also be many ways to compute an explanation of 
that form (more on this in Section 33.2.3.2). How do we choose amongst all of these different ways to 
compute the explanation? We suggest that the first step in determining the form and computation 
of an explanation should be to determine what properties are needed from it. 

Specifying properties is especially important because not only may different forms of explanations 
have different intrinsic properties—e.g. can it model interactions?—but the properties may depend 
on the model being explained. For example, if the model is relatively smooth, then a feature-based 
explanation relying on local gradients may be fairly faithful to the original model. However, if the
model has spiky contours, the same explanation may not adequately capture the model’s behavior. 
Once the desired properties are determined, one can determine what kind of computation is necessary 
to achieve them. We will list commonly desirable properties in Section 33.3. 


33.2.3.2 How the explanation is computed 


Another set of decisions has to do with how the explanation is computed. 


Decision: Computation of Explanation. Once we make the decisions above, we must decide 
how the explanation will actually be computed. This choice will have a large impact on the 
explanation’s properties. Thus, it is crucial to carefully choose a computational approach that 
provides the properties needed for the context and end-task. 

For example, suppose one is seeking to identify the most “important” input features that change a
prediction. Different computations correspond to different definitions of importance. One definition
of importance might be the smallest region in an image that, when changed, changes the prediction:
a perturbation-based analysis. Even within this definition, we would need to specify how that
perturbation will be computed: Do we keep the pixel values within the training distribution? Do we
preserve correlations between pixels? Different works take different approaches [SVZ13; DG17; FV17;
DSZ16; Adl+18; Bac+15a].
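A perturbation-based definition of importance can be sketched on a one-dimensional "image": search for the smallest contiguous window that, when replaced by a baseline value, changes the prediction. The classifier and signal are hypothetical, and replacing values with 0 is an assumption; this choice of baseline is precisely the kind of decision (stay within the training distribution or not?) that the text says must be specified.

```python
def classifier(signal):
    # Hypothetical model: predicts 1 if the mean of positions 3..5 exceeds 0.5.
    return int(sum(signal[3:6]) / 3 > 0.5)

def smallest_flipping_window(model, x, baseline=0.0):
    """Return (start, end) of the smallest contiguous window whose replacement
    by `baseline` changes model(x), scanning widths from 1 upward; None if no
    window changes the prediction."""
    original = model(x)
    n = len(x)
    for width in range(1, n + 1):
        for start in range(n - width + 1):
            perturbed = list(x)
            perturbed[start:start + width] = [baseline] * width
            if model(perturbed) != original:
                return start, start + width
    return None

X = [0, 0, 0, 1, 1, 1, 0, 0]
print(smallest_flipping_window(classifier, X))  # (3, 5)
```

No single position changes the prediction here, so the search reports a two-position window; a different baseline (e.g., a local average) could report a different region.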

A related approach is to define importance in terms of sensitivity (e.g., largest gradients of the
output with respect to the input feature). Even then, there are many computational decisions to be
made [STY17; Smi+17; Sel+17; Erh+09; Shr+16]. Yet another common definition of importance is
how often the input feature is part of a “winning coalition” that drives the prediction, e.g., a Shapley
or Banzhaf score [LL17]. Each of these definitions has different properties and requires a different
amount of computation.
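For a small number of features, Shapley values can be computed exactly by enumerating coalitions. The sketch below assumes a common but not unique choice of value function: features outside a coalition are replaced by fixed baseline values. The model `f`, the input `X`, and `BASELINE` are hypothetical. Replacing the weight |S|!(n-|S|-1)!/n! by the uniform weight 1/2^(n-1) would give the Banzhaf score instead.

```python
from itertools import combinations
from math import factorial

def shapley_values(value, n):
    """Exact Shapley values for a set function `value` over players 0..n-1."""
    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley weight for coalitions S of this size not containing i.
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for coalition in combinations(others, size):
                s = frozenset(coalition)
                phi[i] += weight * (value(s | {i}) - value(s))
    return phi

# Hypothetical model with one additive feature and one pairwise interaction.
def f(x):
    return 2.0 * x[0] + x[1] * x[2]

X = (1.0, 1.0, 1.0)
BASELINE = (0.0, 0.0, 0.0)

def value(coalition):
    # Features in the coalition take their values from X; the rest use BASELINE.
    z = [X[i] if i in coalition else BASELINE[i] for i in range(len(X))]
    return f(z)

print(shapley_values(value, 3))
```

Here the additive feature receives its full contribution of 2, while the two features that act only through their product split that interaction equally (0.5 each); the values sum to f(X) - f(BASELINE) = 3. The number of coalitions grows exponentially with the number of features, which is one reason practical methods approximate these scores.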

Similar issues come up with other forms of explanations. For example, for an example-based
explanation, one has to define what it means to be similar or otherwise representative: Is it the cosine
similarity between activations? A uniform L2 ball of a certain size between inputs? Likewise, there
are many different ways to obtain counterfactuals. One can rely on distance functions to identify
nearby inputs with different outputs [WMR17; LHR20], causal frameworks [Kus+18], or SAT
formulations [Kar+20a], among other choices.
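A distance-based counterfactual can be sketched as a brute-force search for the closest input whose prediction differs. The loan model, the grid of candidate inputs, and the per-feature `scales` that define "closest" are all hypothetical; changing the scales changes which counterfactual is returned, which is the distance-function choice referred to above.

```python
import itertools

def loan_model(income, debt):
    # Hypothetical decision rule: approve when income - 2 * debt >= 30.
    return income - 2 * debt >= 30

# Hypothetical grid of candidate inputs: (income values, debt values).
GRID = (range(0, 101, 5), range(0, 51, 5))

def nearest_counterfactual(model, x, grid, scales):
    """Brute-force search over `grid` for the closest input, under a scaled L1
    distance, whose prediction differs from model(*x)."""
    original = model(*x)
    best, best_dist = None, float("inf")
    for candidate in itertools.product(*grid):
        if model(*candidate) != original:
            dist = sum(abs(c - xi) / s for c, xi, s in zip(candidate, x, scales))
            if dist < best_dist:
                best, best_dist = candidate, dist
    return best, best_dist

print(nearest_counterfactual(loan_model, (40, 10), GRID, scales=(1.0, 1.0)))
```

For a denied applicant with income 40 and debt 10 (hypothetical units), the search returns `((40, 5), 5.0)`: reducing the debt to 5 is the nearest change on the grid that flips the decision.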

Decision: Joint Training vs. Post-hoc Application. So far, we have described our partial
explanation techniques as extracting some information from an already-trained model. This approach
is called deriving a post-hoc explanation. As noted above, post-hoc, partial explanations may have
some limitations: for example, an explanation based on a local linear approximation may be great
if the model is generally smooth, but provide little insight if the model has high curvature. Note
that this limitation is not because the partial explanation is wrong, but because the view that local
gradients provide isn’t sufficient if the true decision boundary is curvy.

One approach to getting explanations with the desired properties is to train the model and the
explanation jointly. For example, a regularizer that penalizes violations of desired properties can
help steer the overall optimization process towards learning models that both perform well and are
amenable to the desired explanation [Plu+20]. It is often possible to find such a model because most
complex model classes have multiple high-performing optima [Bre01].

The choice of regularization will depend on the desired properties, the form of the explanation,
and its computation. For example, in some settings, we may want the explanation to use the same
features that people do for the task (e.g., lower frequency vs. higher frequency features in image
classifiers [Wan+20b]) and still be faithful to the model. In other settings, we may want to control
the input dimensions used or not used in the explanation, or for the explanation to be somehow
compact (e.g., a small decision tree) while still being faithful to the underlying model [RHDV17;
Shu+19a; Vel+17; Nei+18; Wu+19b; Plu+20]. (Certain attention models fall into this category
[JW19; WP19].) We may also have constraints on the properties of concepts or other intermediate
features [AMJ18b; Koh+20a; Hen+16; BH20; CBR20; Don+17b]. In all of these cases, these desired
properties could be included as a regularizer when training the model.
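To make joint training concrete, here is a minimal NumPy sketch under strong simplifying assumptions. We assume a linear logistic model, whose input-gradient explanation is proportional to its weight vector, so penalizing the squared weights on a user-chosen set of features discourages the explanation (and the model) from relying on them. This is loosely in the spirit of input-gradient penalties such as [RHDV17], but it is not the objective of any cited work; the synthetic data, the "shortcut" feature, the mask, and the values of `lam`, `lr`, and `steps` are all hypothetical.

```python
import numpy as np

def train_logreg(X, y, mask, lam, lr=0.1, steps=5000):
    """Gradient descent on mean logistic loss + lam * sum(mask * w**2).

    mask[j] = 1 marks feature j as one the explanation should not rely on.
    For a linear model the input gradient is proportional to w, so this
    penalty acts directly on the gradient-based explanation.
    """
    w = np.zeros(X.shape[1])
    b = 0.0
    n = len(y)
    for _ in range(steps):
        p = 1.0 / (1.0 + np.exp(-(X @ w + b)))              # predicted P(y=1 | x)
        grad_w = X.T @ (p - y) / n + 2.0 * lam * mask * w    # task loss + explanation penalty
        grad_b = np.mean(p - y)
        w -= lr * grad_w
        b -= lr * grad_b
    return w, b

rng = np.random.default_rng(0)
n = 500
y = rng.integers(0, 2, size=n).astype(float)
signed = 2.0 * y - 1.0
x_causal = signed + rng.normal(0.0, 1.0, size=n)    # predictive but noisy (hypothetical)
x_shortcut = signed + rng.normal(0.0, 0.5, size=n)  # cleaner spurious shortcut (hypothetical)
X = np.column_stack([x_causal, x_shortcut])
mask = np.array([0.0, 1.0])  # discourage reliance on the shortcut feature

w_plain, _ = train_logreg(X, y, mask, lam=0.0)
w_joint, _ = train_logreg(X, y, mask, lam=5.0)
print("unregularized weights:", w_plain)
print("jointly regularized weights:", w_joint)
```

With `lam=0.0`, the model leans heavily on the less noisy shortcut feature; with the penalty, that weight is pushed toward zero and the model relies on the other feature instead. The regularization strength controls the trade-off between fitting the data and honoring the desired explanation property.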

When choosing between a post-hoc explanation or joint training, one key consideration is that
joint training assumes that one can re-train the model or the system of interest. In many cases in
practice, this may not be possible. Replacing a complex, well-validated system that has been in deployment for
a decade may be impossible or may take a prohibitively long time. In that case, one can still extract
approximate explanations using post-hoc methods. Finally, a joint optimization, even when it can
be performed, is not a panacea: optimizing for some properties may result in unexpected violations
of other (unspecified but desired) properties. For this reason, explanations from jointly trained models
are still partial.
When might we want to consider post-hoc methods, and when not? The advantage of
post-hoc interpretability methods is that they can be applied to any model. This family of methods 
is especially useful in real-world scenarios where one needs to work with a system that contains many 
models as its parts, where one cannot expect to replace the whole system with one model. These 
approaches can also provide at least some broader knowledge about the model to identify unexpected 
concerns. 

That said, post-hoc explanations, as approximations of the true model, may not be fully faithful to 
the model nor cover the model completely. As such, an explanation method tailored for one context 
may not be transferable to another; even in the intended context, there may be blind spots about
the model that the explanation misses completely. For these reasons, in high stakes situations, one 
should attempt to use an inherently interpretable model first if possible [Rud19]. In all situations 
when post-hoc explanations are used, one must keep in mind that they are only one tool in a broader 
accountability toolkit and warn users appropriately. 


33.2.4 Transparency and Visualization 


Interpretable machine learning is concerned with methods that expose the process by which a
trained model makes a decision. However, the behavior of a model also depends on the objective 
function, the training data, how the training data were collected and processed, and how the model 
was trained and tested. Conveying to a human these other aspects of what goes into the creation 
of a model can be as important as explaining the trained model itself. While a full discussion of 
transparency and visualization is outside the scope of this chapter, we provide a brief discussion here 
to describe these important adjacent concepts. 


Transparency is an umbrella term for the many things that one could expose about the modeling 
process and its context. Interpreting models is one aspect. However, one could also be transparent 
about other aspects, such as the data collection process or the training process (e.g. [Geb+21; 
Mit+19; Dnp]). There are also situations in which a trained model is released (whether or not it is 
inherently interpretable), and thus the software can be inspected and run directly. 


Visualization is one way to create transparency. One can visualize the data directly or various 
aspects of the model’s process (e.g. [Str+17]). Interactive visualizations can convey more than text 
or code descriptions [ZF14; OMS17; MOT15; Ngu+16; Hoh+20]. Finally, in the specific context of
interpretable machine learning, how the explanation is presented—the visualization—can make a 
large difference in how easily users can consume it. Even something as simple as a rule list has many 
choices of layout, highlighting, and other organization. 


When might we want to consider transparency and visualization? When not? In many 
cases, the trouble with a model comes not from the model itself, but from parts of its training pipeline.
The problem might be the training data. For example, since policing data contain historical bias, 
predictions of crime hot spots based on that data will be biased. Similarly, if clinicians only order 
tests when they are concerned about a patient’s condition, then a model trained to predict risk based 
on tests ordered will only recapitulate what the clinicians already know. Transparency about the 
properties of the data, and how training and testing were performed, can help identify these issues. 
Of course, inspecting the data and the model generation process is something that takes time
and attention. Thus, visualizations of this kind and other descriptions to increase transparency 
are best-suited to situations in which a human inspector is not under time pressure to sift through 
potentially complex patterns for sources of trouble. These methods are not well-suited for situations 
in which a specific decision must be made in a relatively short amount of time, e.g. providing 
decision-support to a clinician at the bedside. 

Finally, transparency in the form of making code available can potentially assist in understanding 
how a model works, identifying bugs, and allowing independent testing by a third party (e.g., testing 
with a new set of inputs, evaluating counterfactuals in different testing distributions). However, if a
model is sufficiently complex, as many modern models are, then simply having access to the code is
likely not enough for a human to gain sufficient understanding for their task.
Recall from the Terminology and Framework in Section 33.1.2 that the context and end-task determine
what properties are needed for the explanation. For example, in a high-stakes setting, such as advising
on interventions for an unstable patient, it may be important that the explanation completely and
accurately reflects the model (fidelity). In contrast, in a discovery-oriented setting, it might be more
important for any explanation to allow for efficient iterative refinement, revealing different aspects of
the model in turn (interactivity). Not all contexts and end-tasks need all properties, but the lack of
a key property may result in poor downstream performance.

While the research is still evolving, there exists a growing informal understanding about how
properties may work as an abstraction between methods and contexts. Many interpretability methods
from Section 33.2 share the same properties, and methods with the same properties may have similar
downstream performance in a specific end-task and context. If two contexts and end-tasks require
the same properties, then a method that works well for one may work well for the other. Conversely,
a method whose properties are well matched to one context could fail miserably in another.

How to find desired properties? Of course, identifying what properties are important for a
particular context and end-task is not trivial. Indeed, identifying what properties are important for
what contexts, end-tasks, and downstream performance metrics is one facet of the grand challenge of
interpretable machine learning. For the present, the process of identifying the correct properties will
likely require iteration via user studies. However, the space of properties is still much smaller to
iterate over than the space of methods. For example, if one wants to test whether the sparsity of the
explanation is key to good downstream performance, one could intentionally create explanations of
varying levels of sparsity to test that hypothesis. This is a much more precise knob to test than
exhaustively trying out different explanation methods with different hyperparameters.

Below, we first describe examples of properties that have been discussed in the interpretable
machine learning literature. Many of these properties are purely computational; that is, they can
be determined purely from the model and the explanation. A few have some user-centric elements.
Next, we list examples of properties of explanations from cognitive science (on human-to-human
explanations) and human-computer interaction (on machine-to-human explanations). Some of these
properties have analogs in the machine learning list, while others may serve as inspiration for areas
to formalize.
33.3. PROPERTIES: THE ABSTRACTION BETWEEN CONTEXT AND METHOD


33.3.1 Properties of Explanations from Interpretable Machine Learning 


Many lists of potentially-important properties of interpretable machine learning models have been 
compiled, sometimes using different terms for similar concepts and sometimes using similar terms
for different concepts. Below we list some commonly-described properties of explanations, knowing 
that this list will evolve over time as the field advances. 


Faithfulness, Fidelity (e.g. as described in [JG20; JG21]). When the explanation is only a 
partial view of the model, how well does it match the model? There are many ways to make this 
notion precise. For example, suppose a mimic (simple model) is used to provide a global explanation 
of a more complex model. One possible measure of faithfulness could be how often the mimic gives 
the same outputs as the original. Another could be how often the mimic has the same first derivatives 
(local slope) as the original. In the context of a local explanation consisting of the ‘key’ features for a 
prediction, one could measure faithfulness by whether the prediction changes if the supposedly important
features are flipped. Another measure could check to make sure the prediction does not change if
a supposedly unimportant feature is flipped. The appropriate formalization will depend on the context. 
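As a minimal sketch of two of these formalizations, consider a hypothetical classifier on three binary features and a hypothetical mimic that looks only at feature 0. Global faithfulness is measured as the agreement rate over all inputs, and local faithfulness as the two flip checks described above; "flipping" here means toggling a binary feature, which is one assumed choice among many.

```python
import itertools
import numpy as np

def model(x):
    """Hypothetical black-box classifier on three binary features."""
    return int(3 * x[0] + x[1] + x[2] >= 4)

def mimic(x):
    """Hypothetical global mimic that only looks at feature 0."""
    return int(x[0])

def flip(x, features):
    """Return a copy of binary input x with the given features toggled."""
    x_new = x.copy()
    idx = list(features)
    x_new[idx] = 1 - x_new[idx]
    return x_new

# Global faithfulness: how often does the mimic reproduce the model's output?
inputs = [np.array(bits) for bits in itertools.product([0, 1], repeat=3)]
agreement = np.mean([model(x) == mimic(x) for x in inputs])

# Local faithfulness of a 'key features' explanation at one input.
x = np.array([1, 1, 1])
important, unimportant = [0], [1, 2]
changes_if_important_flipped = model(flip(x, important)) != model(x)
unchanged_if_each_unimportant_flipped = all(
    model(flip(x, [j])) == model(x) for j in unimportant
)
print(agreement, changes_if_important_flipped, unchanged_if_each_unimportant_flipped)
```

The mimic disagrees with the model on a single input, (1, 0, 0), so its agreement rate is 7/8, and the local explanation passes both flip checks.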


Compactness, Sparsity (e.g. as described in [Lip18; Mur+19]). In general, an explanation
must be small enough such that the user can process it within the constraints of the task (e.g. how 
quickly a decision must be made). Sparsity generally corresponds to some notion of smallness (a few 
features, a few parameters, L1 norm etc.). Compactness generally carries an additional notion of 
not including anything irrelevant (that is, even if the explanation is small enough, it could be made 
smaller). Each must be formalized for the context. 
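Here is a minimal sketch of a few possible formalizations for a feature-attribution explanation, assuming the explanation is a vector of per-feature scores. The L0 count and L1 norm correspond to the notions of smallness listed above; the "smallest k features carrying a given fraction of the L1 mass" is a hypothetical proxy for compactness (could the explanation be made smaller?), and the `mass` threshold is an assumption that would need to be set for the context.

```python
import numpy as np

def sparsity_measures(attribution, mass=0.9, tol=1e-12):
    """Return (L0 count, L1 norm, smallest k capturing `mass` of the L1 total)."""
    a = np.abs(np.asarray(attribution, dtype=float))
    l0 = int(np.sum(a > tol))  # number of features with nonzero weight
    l1 = float(np.sum(a))      # total attribution magnitude
    if l1 == 0.0:
        return l0, l1, 0
    # Cumulative share of the L1 total carried by the top-1, top-2, ... features.
    cumulative = np.cumsum(np.sort(a)[::-1]) / l1
    k = min(int(np.searchsorted(cumulative, mass)) + 1, a.size)
    return l0, l1, k

l0, l1, k = sparsity_measures([0.5, -0.3, 0.0, 0.1, 0.0], mass=0.8)
print(l0, l1, k)
```

Here three features are nonzero, but the top two already carry more than 80% of the attribution mass, so a compactness-minded user might prefer the two-feature version.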


Completeness (e.g. as described in [Yeh+19b]). If the explanation is not the model, does it still 
include all of the relevant elements? For example, if an explanation consists of important features 
for a prediction, does it include all of them, or leave some out? Moreover, if the explanation uses 
derived quantities that are not the raw input features—for example, some notion of higher-level 
concepts—are they expressive enough to explain all possible directions of variation that could change 
the prediction? Note that an explanation can be faithful in some ways but incomplete in
others. For example, an explanation may be faithful in the sense that flipping features considered
important flips the prediction and flipping features considered unimportant does not. However, the
explanation may fail to capture that flipping a set of unimportant features together does change the prediction.
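The failure mode above can be checked directly when the number of features is small. In this minimal sketch, with a hypothetical classifier on binary features, each "unimportant" feature passes the single-flip test, but flipping them together changes the prediction. Enumerating all subsets is exponential in the number of features, so this exhaustive version is only an illustration for tiny inputs.

```python
import itertools
import numpy as np

def model(x):
    """Hypothetical classifier: feature 0 dominates; features 1 and 2 matter jointly."""
    return int(3 * x[0] + x[1] + x[2] >= 4)

def flip(x, features):
    """Return a copy of binary input x with the given features toggled."""
    x_new = x.copy()
    idx = list(features)
    x_new[idx] = 1 - x_new[idx]
    return x_new

x = np.array([1, 1, 1])
unimportant = [1, 2]  # what the explanation claims

# Check every nonempty subset of the 'unimportant' features, not just singletons.
missed = [
    subset
    for r in range(1, len(unimportant) + 1)
    for subset in itertools.combinations(unimportant, r)
    if model(flip(x, subset)) != model(x)
]
print(missed)
```

Only the joint flip of features 1 and 2 changes the output, which is exactly the interaction an explanation of the form "feature 0 matters, features 1 and 2 do not" leaves out.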


Stability (e.g. as described in [AMJ18a]). To what extent are the explanations similar for similar
inputs? Note that the underlying model will naturally affect whether the explanation can be stable.
For example, if the underlying model has high curvature and the explanation has limited expressiveness,
then it may not be possible to have a stable explanation.
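One way to quantify stability, assumed here as a simple sampling-based sketch, is a local Lipschitz-style ratio: the largest change in the explanation per unit change in the input over random inputs near x. The two explanation functions below are gradients of hypothetical models, one smooth and one with high curvature; `radius` and `n_samples` are illustrative choices.

```python
import numpy as np

def local_lipschitz(explain, x, radius=0.1, n_samples=50, seed=0):
    """Estimate max ||e(x) - e(x')|| / ||x - x'|| over random x' near x."""
    rng = np.random.default_rng(seed)
    e_x = explain(x)
    ratios = []
    for _ in range(n_samples):
        x_prime = x + rng.uniform(-radius, radius, size=x.shape)
        ratios.append(np.linalg.norm(explain(x_prime) - e_x) / np.linalg.norm(x_prime - x))
    return max(ratios)

def smooth_grad(x):
    """Gradient of the smooth model sum(x**2)."""
    return 2.0 * x

def curvy_grad(x):
    """Gradient of the high-curvature model sum(sin(10 * x))."""
    return 10.0 * np.cos(10.0 * x)

x0 = np.zeros(2)
stable = local_lipschitz(smooth_grad, x0)
unstable = local_lipschitz(curvy_grad, x0)
print(stable, unstable)
```

For the smooth model, the ratio is exactly 2 everywhere; for the curvy model, the gradient explanation changes much faster than the input, matching the point above that high model curvature limits how stable a gradient explanation can be.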


Actionability (e.g. as described in [Kar+20b; Poy+20]). Actionability implies filtering the 
content of the explanation to focus on only aspects of the model that the user might be able to 
intervene on. For example, if a patient is predicted to be at high risk of heart disease, an actionable 
explanation might only include mutable factors such as exercise and not immutable factors such as 
age or genetics. The notion of recourse corresponds to actionability in a justice context. 
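In its simplest form, actionability can be sketched as a filter over an existing attribution explanation. The feature names and scores below are hypothetical, and which features count as mutable is a context-specific judgment supplied by people, not something the code can decide.

```python
def actionable_explanation(attributions, mutable):
    """Keep only attributions for features the user could intervene on.

    attributions: dict mapping feature name -> importance score.
    mutable: set of feature names considered changeable in this context.
    """
    kept = {name: score for name, score in attributions.items() if name in mutable}
    # Present the most influential actionable factors first.
    return dict(sorted(kept.items(), key=lambda item: abs(item[1]), reverse=True))

attributions = {"age": 0.9, "exercise_hours": -0.4, "genetic_marker": 0.5, "smoking": 0.6}
mutable = {"exercise_hours", "smoking"}
result = actionable_explanation(attributions, mutable)
print(result)
```

Note that age has the largest attribution but is dropped, which is the intended effect: the filtered explanation is less complete, by design, in exchange for being actionable.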
Modularity (e.g. as described in [Lip18; Mur+19]). Modularity implies that the explanation can
be broken down into understandable parts. While modularity does not guarantee that the user can 
explain the system as a whole, for more complex models, modular explanations—where the user can 
inspect each part—can be an effective way to provide a reasonable level of insight into the model’s 
workings. 


Interactivity (e.g., [Ten+20]). Does the explanation allow the user to ask questions, such as how
the explanation changes for a related input, or how an output changes given a change in input? In
some contexts, providing everything that a user might want or need to know from the start might be
overwhelming, but it might be possible to provide a way for the user to navigate the information
about the model in their own way.
Translucence (e.g. as described in [SF20; Lia+19]). Is the explanation clear about its limitations?
For example, if a linear model is locally fit to a deep model at a particular input, is there a mechanism 
that reports that this explanation may be limited if there are strong feature interactions around 
that input? We emphasize that translucence is about exposing limitations in the explanation, rather 
than the model. As with all accountability methods, however, the ultimate goal is to expose the
limitations of the model.
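A minimal sketch of one possible translucence mechanism, assumed here rather than taken from [SF20; Lia+19]: when fitting a local linear surrogate, also report how well it fits (its R^2 on the local samples). A low R^2 flags that the linear explanation is missing strong nonlinearity or interactions near this input. The two models, the sampling `radius`, and the sample count are hypothetical.

```python
import numpy as np

def local_linear_fit_r2(f, x, radius=0.5, n_samples=200, seed=0):
    """Fit a linear surrogate to f around x; return its slopes and R^2."""
    rng = np.random.default_rng(seed)
    Z = x + rng.uniform(-radius, radius, size=(n_samples, x.size))
    y = np.array([f(z) for z in Z])
    A = np.column_stack([Z, np.ones(n_samples)])  # features plus intercept
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    residual = y - A @ coef
    r2 = 1.0 - np.sum(residual**2) / np.sum((y - y.mean()) ** 2)
    return coef[:-1], r2

def additive(z):
    """Hypothetical model with no feature interactions."""
    return 2.0 * z[0] - z[1]

def interacting(z):
    """Hypothetical model driven entirely by a feature interaction."""
    return z[0] * z[1]

x0 = np.zeros(2)
_, r2_add = local_linear_fit_r2(additive, x0)
_, r2_int = local_linear_fit_r2(interacting, x0)
print(f"R^2 additive: {r2_add:.3f}, R^2 interacting: {r2_int:.3f}")
```

Both surrogates return slopes, but only the R^2 tells the user that, for the interacting model, the slopes near the origin explain almost none of the local behavior. The threshold at which to warn the user is left to the context.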
Simulability (e.g. as described in [Lip18; Mur+19]). A model is simulable if a user can take
the model and an input and compute the output (within any constraints of time and cognition). A
simulable explanation is an explanation that is a simulable model. For example, a list of features
is not simulable, because a list of features alone does not tell us how to compute the output. In
contrast, an explanation in the form of a decision tree does include a computation process: the
user can follow the logic of the tree, as long as it is not too deep. This example also points out an
important difference between compactness and simulability: if an explanation is too large, it may not
be simulable. However, just because an explanation is compact, such as a short list of features, does
not mean that a person can compute the model's output with it.

It may seem that simulability is different from the other properties because its definition involves
human input. However, in practice, we often know what kinds of explanations are easy for people
to simulate (e.g., decision trees with short path lengths, rule lists with small formulas, etc.). This
knowledge can be turned into a purely computational training constraint where we seek simulable
explanations.
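The sketch below illustrates both points with a tiny hand-built decision tree. `predict_with_trace` follows the tree the way a user would, one comparison per step. `is_simulable` turns "short path lengths" into a purely computational check on tree depth. The tree encoding, the feature thresholds, and the `max_depth=3` cutoff are all hypothetical; in practice the cutoff would come from knowledge of what the target users can follow.

```python
# A decision tree as nested tuples: ("leaf", label) or
# ("split", feature_index, threshold, left_subtree, right_subtree).

def predict_with_trace(tree, x):
    """Follow the tree as a user would, returning the label and the steps taken."""
    steps = []
    node = tree
    while node[0] == "split":
        _, feature, threshold, left, right = node
        go_left = x[feature] <= threshold
        steps.append(f"x[{feature}] {'<=' if go_left else '>'} {threshold}")
        node = left if go_left else right
    return node[1], steps

def depth(tree):
    """Length of the longest root-to-leaf path."""
    if tree[0] == "leaf":
        return 0
    return 1 + max(depth(tree[3]), depth(tree[4]))

def is_simulable(tree, max_depth=3):
    """Hypothetical computational proxy: every path is short enough to follow."""
    return depth(tree) <= max_depth

tree = ("split", 0, 50.0,
        ("leaf", "low risk"),
        ("split", 1, 1.5, ("leaf", "medium risk"), ("leaf", "high risk")))
label, steps = predict_with_trace(tree, [62.0, 2.0])
print(label, steps, is_simulable(tree))
```

During training, a check like `is_simulable` could be enforced as a hard constraint (e.g., a maximum depth) rather than evaluated after the fact.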


Alignment to the User's Vocabulary and Mental Model (e.g., as described in [Kim+18a]).
Is the content of the explanation designed for the user's vocabulary? For example, the explanation
could be given in the semantics a user knows, such as medical conditions vs. raw sensor readings.
Doing so can help the user more easily connect the explanation to their knowledge and existing
decision-making guidelines [Clo+19]. Of course, the right vocabulary will depend on the user: an
explanation in terms of parameter variances and influential points may be comprehensible to an
engineer debugging a lending model but not to a loan applicant.

Like simulability, mental-model alignment is more human-centric. However, just as before, we can
imagine an abstraction between eliciting vocabulary and mental models from users (i.e., determining
how they define their terms and how they think), and ensuring that an explanation is provided in
alignment with the elicited vocabulary and mental model.


Once desired properties are identified, we need to operationalize them. For example, if sparsity is a 
desired property, would using the L1 norm be enough? Or does a more sophisticated loss term need 
to be designed? This decision will necessarily be human-centric: how small an explanation needs to 
be, or in what ways it needs to be small, is a decision that needs to consider how people will be using 
the explanation. Once operationalized, most properties can be optimized computationally. That 
said, the properties should be evaluated in light of the context, end-task, model, and chosen explanation
methods. Once evaluated, one may revisit the choice of the explanation and model. 

Finally, we emphasize that the ability to achieve a particular property will depend on the intrinsic 
characteristics of the model. For example, the behavior of a highly nonlinear model with interactions 
between the inputs will, in general, be harder to understand than a linear model. No matter how we 
try to explain it, if we are trying to explain something complicated, then users will have a harder 
time understanding it. 


33.3.2 Properties of Explanations from Cognitive Science 


Above we focused on computational properties between models and explanations. The fields of 
cognitive science and human-computer interaction have long examined what people consider good 
properties of an explanation. These more human-centered properties may be ones that researchers in 
machine learning may be less aware of, yet they are essential for communicating information to people.

Unsurprisingly, the literature on human explanation concurs that the explanation must fit the 
context [VF+80]; different contexts require different properties and different explanations. That said,
human explanations are also social constructs, often including post-hoc rationalizations and other 
biases. We should focus on properties that help users achieve their goals, not ones simply “because 
people sometimes do it.” 

Below we list several of these properties. 


Soundness (e.g., as described in [Kul+13]). Explanations should contain nothing but the truth 
with respect to whatever they are describing. Soundness corresponds to notions of compactness and 
faithfulness above. 


Completeness (e.g., as described in [Kul+13]). Explanations should contain the whole truth 
with respect to whatever they are describing. Completeness corresponds to notions of completeness 
and faithfulness above. 


Generality (e.g., as described in [Mil19]). Overall, people understand that an explanation for one 
context may not apply in another. That said, there is an expectation that an explanation should 
reflect some underlying mechanism or principle and will thus apply to similar cases—for whatever 
notion of similarity is in the person’s mental model. Explanations that do not generalize to similar 
cases may be misinterpreted. Generality corresponds to notions of stability above. 


Simplicity (e.g., as described in [Mil19]). All of the above being equal, simpler explanations are 
generally preferred. Simplicity relates to notions of sparsity and compactness above.


Contrastiveness (e.g., as described in [Mil19]). Contrastive explanations provide information about
how something differs from an alternate decision or prediction. For example, instead of providing a
list of features for why a particular drug is recommended, it might provide a list of features that 
explain why one drug is recommended over another. Contrastiveness relates to notions of actionability 
above, and more generally explanation types that include counterfactuals. 


Finally, the cognitive science literature also notes that explanations are often goal directed. This 
matches the notion of explanation in ML as information that helps a person improve performance on 
their end-task. Different information may help with different goals, and thus human explanations 
take many forms. Examples include deductive-nomological forms (i.e., logical proofs) [HO48], forms
that provide a sense of an underlying mechanism [BA05; Gle02; CO06], and forms that convey
understanding [Kei06]. Knowing these forms can help us consider what options might be best among
different sets of interpretable machine learning methods.
33.4 Evaluation of Interpretable Machine Learning Models
One cannot formalize the notion of interpretability without specifying the context, the end-task, and
the downstream performance metric [VF+80]. If one explanation enables the human to perform better
on their end-task than another explanation does, then it is more useful. While the grand
challenge of interpretable machine learning is to develop a general understanding of what properties
are needed for good downstream performance on different end-tasks in different contexts, in this
section, we will focus on rigorous evaluation within one context [DVK17].

Specifically, we describe two major categories for evaluating interpretable machine learning methods:

Computational evaluations of properties (without people). Computational evaluations of
whether explanations have desired properties do not require user studies. For example, one can computationally
measure whether a particular explanation satisfies a definition of faithfulness under different
training and test data distributions or whether the outputs of one explanation are more sparse
than another. Such measures are valuable when one already knows that certain properties may be
important for certain contexts. Computational evaluations also serve as intermediate evaluations and
sanity checks to identify undesirable explanation behavior prior to a more expensive user study-based
evaluation.

User studies (with people). Ultimately, user studies are needed to measure how well an
interpretable machine learning method enables the user to complete their end-task in a given context.
Performing a rigorous, well-designed user study in a real context is significant work, much more
so than computing a test likelihood on benchmark datasets. It places significant demands not only on
the researchers but also on the target users. Methods for different contexts will also have different
evaluation challenges: while a system designed to assist with optimizing music recommendations
might be testable on a wide population, a system designed to help a particle physicist identify
new kinds of interactions might only be tested with one or two physicists because people with that
expertise are hard to find. In all cases, the evaluation can be done rigorously given careful attention
to experimental design.
33.4. EVALUATION OF INTERPRETABLE MACHINE LEARNING MODELS


33.4.1 Computational Evaluation: Does the Method have Desired Properties? 


While the ultimate measure of interpretability is whether the method successfully empowers the 
user to perform their task, properties can serve as a valuable abstraction. Checking whether an 
explanation has the right computational and desired properties can ensure that the method works as 
expected (e.g., no implementation errors, no obviously odd behaviors). One can iterate on novel, 
computationally-efficient methods to optimize the quantitative formalization of a property before 
conducting expensive human experiments. Computational checks can also verify whether properties
that held for one model continue to hold when applied to another model. Finally, checking for specific 
properties can also help pinpoint in what way an explanation is falling short, which may be less clear 
from a user study due to confounding. 

In some cases, one might be able to prove mathematically that an explanation has certain properties, 
while in others the test must be empirical. For empirical testing, one umbrella strategy is to use 
a hypothesis-based sanity check; if we think a phenomenon X should never occur (hypothesis), we 
can test whether we can create situations where X may occur. If it does, then the method fails 
this sanity check. Another umbrella strategy is to create datasets with known characteristics or 
ground truth explanations. These could be purely synthetic constructions (e.g. generated tables 
with intentionally correlated features), semi-synthetic approaches (e.g. intentionally changing the 
labels on an image dataset), or taking slices of a real dataset (e.g., introducing intentional bias
by selecting only real (image, label) pairs from outdoor environments). Note that these tests
can only tell us if something is wrong; if a method passes a check, there may still be blind spots.


Examples of Sanity Checks. One strategy is to come up with statements of the form “if this
explanation is working, then phenomenon X should not be occurring” and then try to create a 
situation in which phenomenon X occurs. If we succeed, then the sanity check fails. By asking about 
out-of-the-box phenomena, this strategy can reveal some surprising failure modes of explanation 
methods. 

For example, [Ade+20a] operates under the assumption that a faithful explanation should be a 
function of a model’s prediction. The hypothesis is that the explanation should significantly change 
when comparing a trained model to an untrained model (where prediction is random). They show 
that many existing methods fail to pass this sanity check (Figure 33.4). 
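The logic of this model-randomization check can be sketched in a few lines. We use a hypothetical linear model, for which gradient × input is simply `w * x`, and compare it against a hypothetical "explanation" that ignores the model entirely. Rank correlation between the explanations of the trained and randomized models is one simple similarity measure, assumed here for illustration; it is not the full set of metrics used in [Ade+20a].

```python
import numpy as np

def spearman(a, b):
    """Spearman rank correlation (assumes no ties)."""
    rank_a = np.argsort(np.argsort(a))
    rank_b = np.argsort(np.argsort(b))
    return np.corrcoef(rank_a, rank_b)[0, 1]

rng = np.random.default_rng(0)
d = 20
x = rng.normal(size=d)          # one hypothetical input
w_trained = rng.normal(size=d)  # stand-in for a trained linear model's weights
w_random = rng.normal(size=d)   # re-initialized ('randomized') weights

def grad_times_input(w, x):
    """Model-dependent explanation for a linear model f(x) = w @ x."""
    return w * x

def input_magnitude(w, x):
    """A hypothetical 'explanation' that ignores the model entirely."""
    return np.abs(x)

results = {
    explain.__name__: spearman(explain(w_trained, x), explain(w_random, x))
    for explain in (grad_times_input, input_magnitude)
}
print(results)
```

The model-independent explanation is perfectly correlated across the two models, so it fails the sanity check: its output cannot be telling us anything about what the trained model learned.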

In another example, [Kin+19] hypothesize that a faithful explanation should be invariant to input 
transformations that do not affect model predictions or weights, such as a constant shift of the inputs (e.g.,
adding 10 to every input). This hypothesis can be seen as testing both faithfulness and stability
properties. Their work shows that some methods fail this sanity check. 
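For a linear model, the input-shift check can be sketched as follows. We shift every input by a constant and let the bias absorb the shift, so the two models have identical weights and identical predictions on corresponding inputs. We then compare a plain gradient explanation (estimated by finite differences) with gradient × input. The weights, bias, and shift are hypothetical, and this toy case only illustrates the test, not the full set of methods examined in [Kin+19].

```python
import numpy as np

def numerical_gradient(fn, x, eps=1e-6):
    """Central finite-difference gradient of a scalar function."""
    basis = np.eye(x.size)
    return np.array([(fn(x + eps * e) - fn(x - eps * e)) / (2 * eps) for e in basis])

rng = np.random.default_rng(0)
w = rng.normal(size=5)
b = 0.3
shift = np.full(5, 10.0)  # constant shift added to every input

def f(x):
    """Original hypothetical linear model."""
    return w @ x + b

def f_shifted(x_shifted):
    """Same weights; the bias absorbs the shift so predictions are unchanged."""
    return w @ x_shifted + (b - w @ shift)

x = rng.normal(size=5)
x_shifted = x + shift

same_prediction = np.isclose(f(x), f_shifted(x_shifted))
grad, grad_s = numerical_gradient(f, x), numerical_gradient(f_shifted, x_shifted)
gradient_invariant = np.allclose(grad, grad_s)
grad_x_input_invariant = np.allclose(grad * x, grad_s * x_shifted)
print(same_prediction, gradient_invariant, grad_x_input_invariant)
```

The gradient is the same for both models, but gradient × input changes by `w * shift` even though nothing about the models' decisions has changed, so it fails this sanity check in this toy setting.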

Adversarial attacks on explanations also fall into this category. For example, [GAZ19] shows that 
two perceptually indistinguishable inputs with the same predicted label can be assigned very different
explanations. 


Examples using (semi-)Synthetic Datasets. Constructed datasets can also help score proper- 
ties of various methods. We use the work of [YK19] as an example. Here, the authors were interested 
in explanations with the properties of compactness and faithfulness: it should not identify features 
as important if they are not. To test for these properties, the authors generate images with known 
correlations. Specifically, they generate multiple datasets, each with a different rate of how often 
each particular foreground object co-occurs with each particular background (see Figure 33.3). Each 
Figure 33.3: An example of computational evaluation using (semi-)synthetic datasets from [YK19]: foreground
images (e.g. dogs, backpacks) are placed on different backgrounds (e.g. indoors, outdoors) to test what an
explanation is looking for.

dataset comes with two labels per image: one for the object and one for the background.

Now, the authors compare two models: one trained to classify objects and one trained to classify
backgrounds (left, Figure 33.3). If a model is trained to classify objects and the objects all happen to
have the same background, the background should be less important than in a model trained to
classify backgrounds ([YK19] call this the ‘Model Contrast Score’). They also checked that the model
trained to predict backgrounds was not providing attributions to the foreground objects (see the right side of
Figure 33.3). Other works using similar strategies include [Wil+20b; Gha+21; PMT18; KPT21;
Yeh+19b; Kim+18b].
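The comparison between the two models can be sketched as a difference in how much attribution each model's explanation places on the object region. The formula below is a simplified, assumed version for illustration, not the exact Model Contrast Score defined in [YK19]. The 4×4 attribution maps and object mask are hypothetical.

```python
import numpy as np

def region_share(attribution, mask):
    """Fraction of total absolute attribution that falls inside a binary mask."""
    a = np.abs(attribution)
    return a[mask].sum() / a.sum()

def contrast_score(attr_object_model, attr_background_model, object_mask):
    """Simplified, assumed score: how much more the object-trained model's
    explanation focuses on the object than the background-trained model's does."""
    return region_share(attr_object_model, object_mask) - region_share(
        attr_background_model, object_mask
    )

# Hypothetical 4x4 image whose foreground object occupies the top-left 2x2 block.
object_mask = np.zeros((4, 4), dtype=bool)
object_mask[:2, :2] = True

attr_obj = np.where(object_mask, 1.0, 0.1)  # object model: mostly on the object
attr_bg = np.where(object_mask, 0.1, 1.0)   # background model: mostly elsewhere
score = contrast_score(attr_obj, attr_bg, object_mask)
print(score)
```

An explanation method that produced the same map for both models, regardless of what they were trained on, would score 0, which is the kind of failure this test is designed to catch.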

Examples with Real Datasets. While more difficult, it is possible to at least partially check
for certain kinds of properties on real datasets that have no ground truth.

For example, suppose an explanation ranks features from most to least important. We want to
determine if this ranking is faithful. Further, suppose we can assume that the features do not interact.
Then, one can attempt to make the prediction just with the top-1 most important feature, just the
top-2 ranked features, etc., and observe whether the change in prediction accuracy exhibits diminishing
returns as more features are added. (If the features do interact, this test will not work. For example,
if features A, B, C are the top-3 features, but C is only important if feature B is present, the test
above would overestimate the importance of feature C.)

Figure 33.5 shows an example of this kind of test [Gho+19]. Their method outputs a set of
Figure 33.4: Interpretability methods (each row) and their explanations as we randomize layers starting from
the logits, and cumulatively to the bottom layer (each column), for an image classification task. The
rightmost column shows a completely randomized network. Most methods output similar explanations
for the first two columns, even though in one the model predicts the bird and in the other it predicts randomly.
This sanity check tests the hypothesis that the explanation should significantly change (quantitatively and
qualitatively) when comparing a trained model and an untrained model [Ade+20a].


image patches (e.g., a set of connected pixels) that correlate with the prediction. They add top-n
image patches provided by the explanation one by one and observe the desired trend in accuracy. A 
similar experiment in reverse direction (i.e., deleting top-n most important image patches one by 
one) provides additional evidence. Similar experiments are also conducted in [FV17; RSG16a]. 
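The adding and deleting experiments can be sketched for a single input as insertion and deletion curves. Top-ranked features are inserted into a baseline one at a time, or removed from the input one at a time, and the model output is recorded after each step. The additive model, the all-zeros baseline, and the gradient × input ranking are hypothetical choices. Because this model has no interactions, ranking by |w_j x_j| yields diminishing returns by construction, which is exactly the assumption the test relies on. In the image setting, features would be patches and a dataset-level accuracy would replace the single output.

```python
import numpy as np

def insertion_deletion_curves(f, x, baseline, ranking):
    """Model output as ranked features are inserted into the baseline
    (insertion) or replaced by the baseline in the input (deletion)."""
    inserted = baseline.copy()
    deleted = x.copy()
    insertion, deletion = [f(inserted)], [f(deleted)]
    for j in ranking:
        inserted[j] = x[j]
        deleted[j] = baseline[j]
        insertion.append(f(inserted))
        deletion.append(f(deleted))
    return np.array(insertion), np.array(deletion)

w = np.array([0.1, 2.0, -0.5, 1.0])  # hypothetical additive model weights

def f(z):
    return float(w @ z)

x = np.ones(4)
baseline = np.zeros(4)
ranking = np.argsort(-np.abs(w * x))  # most important first (gradient x input)
insertion_curve, deletion_curve = insertion_deletion_curves(f, x, baseline, ranking)
gains = np.abs(np.diff(insertion_curve))  # effect of each added feature
print(insertion_curve, deletion_curve, gains)
```

The per-step gains shrink monotonically (2.0, 1.0, 0.5, 0.1), the pattern one hopes to see for a faithful ranking. A random or reversed ranking would spread the large steps out instead.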

As another example, in [DG17], the authors first define properties in plain English: the smallest sufficient
region (the smallest region of the image that alone allows a confident classification) and the smallest
destroying region (the smallest region of the image that, when removed, prevents a confident classification).
These properties are then carefully operationalized so that they become the objective for optimization.
Separately, an evaluation metric of saliency is defined to be "the tightest rectangular crop
that contains the entire salient region and to feed that rectangular region to the classifier to directly
verify whether it is able to recognise the requested class". While the "rectangular" constraint may
introduce artifacts, it is a neat trick to make evaluation possible. By checking expected behavior as
described above, the authors confirm that, compared to baselines, the method's behavior on real data is
aligned with the defined properties.
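The cropping step of this metric is simple to sketch: find the smallest axis-aligned rectangle containing every salient pixel and cut it out of the image. The classifier that would then score the crop is omitted here, and the image and mask are hypothetical. How a continuous saliency map is thresholded into a binary mask is a separate choice this sketch does not make.

```python
import numpy as np

def tightest_crop(image, saliency_mask):
    """Crop `image` to the smallest axis-aligned rectangle containing every
    salient pixel; the crop would then be fed to the classifier."""
    rows = np.flatnonzero(saliency_mask.any(axis=1))
    cols = np.flatnonzero(saliency_mask.any(axis=0))
    if rows.size == 0:
        raise ValueError("saliency mask is empty")
    r0, r1 = rows[0], rows[-1] + 1
    c0, c1 = cols[0], cols[-1] + 1
    return image[r0:r1, c0:c1], (r0, r1, c0, c1)

image = np.arange(8 * 10).reshape(8, 10)  # hypothetical 8x10 single-channel image
mask = np.zeros((8, 10), dtype=bool)
mask[2, 3] = mask[5, 7] = True            # two salient pixels
crop, box = tightest_crop(image, mask)
print(box, crop.shape)
```

The example also shows where the rectangle constraint introduces artifacts: two isolated salient pixels produce a 4×5 crop, most of which the explanation never marked as salient.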


Evaluating the Evaluations. As we have seen so far, there are many ways to formalize a given 
property and many empirical tests to determine whether a property is present. Each empirical test 
will have different qualities. As an illustration, in [Tom+20], the authors ask whether popular saliency
metrics give consistent results across the literature. They tested whether different metrics for assessing
the quality of saliency-based explanations (explanations that identify important pixels or regions in
images) evaluate similar properties. In other words, this work tests consistency and stability


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 




[Figure 33.5: top row, example images for "Adding top-rated patches" (Top-1, Top-5, Top-10) and "Deleting top-rated patches" (Top-1, Top-5, Top-10, Top-15); bottom row, SSC and SDC plots with curves for the Most Important, Least Important, and Random concepts and a Baseline performance reference, against the Number of added concepts (SSC) and the Number of deleted concepts (SDC).]
Used with kind permission of Yarin Gal. Bottom row is from Figure 4 of [Gho+19]. One would expect that
adding or deleting patches rated as most ‘relevant’ for an image classification would have a large effect on the
classification compared to patches not rated as important.

33 properties of metrics. They show many metrics are statistically unreliable and inconsistent. While 
34each metric may still have a particular use [Say+19], knowing this inconsistency exists helps us better 
35 understand the landscape and limitations of evaluation approaches. Developing good evaluations for 
36computational properties is an ongoing area of research. 
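One simple way to probe whether two metrics agree is to check whether they rank the same set of saliency methods in the same order. A standard way to quantify that agreement is Spearman's rank correlation.

The sketch below is a simplified, illustrative check under stated assumptions. It is not a reproduction of the analysis in [Tom+20], and the scores are hypothetical. It assumes that higher is better for both metrics.

```python
from typing import List, Sequence


def ranks(values: Sequence[float]) -> List[float]:
    """Average ranks (1 = smallest); tied values share the mean of their ranks."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    r = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        avg = (i + j) / 2 + 1
        for k in range(i, j + 1):
            r[order[k]] = avg
        i = j + 1
    return r


def spearman(x: Sequence[float], y: Sequence[float]) -> float:
    """Spearman's rho: the Pearson correlation of the two rank vectors."""
    rx, ry = ranks(x), ranks(y)
    n = len(x)
    mx, my = sum(rx) / n, sum(ry) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(rx, ry))
    vx = sum((a - mx) ** 2 for a in rx)
    vy = sum((b - my) ** 2 for b in ry)
    return cov / (vx * vy) ** 0.5


# Hypothetical scores that two metrics assign to the same four saliency methods.
metric_a = [0.9, 0.7, 0.4, 0.2]
metric_b = [0.1, 0.3, 0.8, 0.6]
print(spearman(metric_a, metric_b))  # strongly negative: the metrics largely disagree
```

A value near +1 would mean the two metrics order the methods the same way. A value near 0 or below suggests they are measuring different properties.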

33.4. EVALUATION OF INTERPRETABLE MACHINE LEARNING MODELS


33.4.2 User Study-based Evaluation: Does the Method Help a User Perform a Target Task?


User study-based evaluations measure whether an interpretable machine learning method helps
a human perform some task. This task could be the ultimate end-task of interest (e.g., does a
method help a doctor make better treatment decisions?) or a synthetic task that mirrors contexts of
interest (e.g., a simplified situation with artificial diseases and symptoms). In both cases, rigorous
experimental design is critical to ensuring that the experiment measures what we want it to measure.
Understanding experimental design for user studies is essential for research in interpretable machine
learning.


33.4.2.1 User Studies in Real Contexts


The gold standard for testing whether an explanation is useful is to test it in the intended context: 
Do clinicians make better decisions with a certain kind of decision support? Do programmers debug 
code faster with a certain kind of explanation about model errors? Do product teams create fairer
models for their businesses? A complete guide on how to design and conduct user studies is out
of scope for this chapter; below we point out some basic considerations. 


33.4.2.2 Basic elements of user studies 


Performing a high-quality user study is a nuanced and non-trivial endeavor. There are many sources 
of bias, some obvious (e.g., learning and fatigue effects during a study) and some less obvious (e.g.,
participants who are willing to work with us are more optimistic about AI technology than those we could
not recruit, or different participants may have different levels of need for cognition).


Interface Design. The explanation must be presented to the user. Unlike the intrinsic difficulty 
of explaining a model (i.e., complex models are harder to explain than simple ones), the design of 
the interface is an extrinsic source of difficulty that can confound the experimental results. For 
example, it may be easier, in general, to scan a list of features ordered by importance than one
ordered alphabetically.

When we perform an evaluation with respect to an end-task, intrinsic and extrinsic difficulties 
can get conflated. Does one explanation type work better because it does a better job of explaining 
the complex system? Or does it work better simply because it was presented in a way that was 
easier for people to use? Especially if the goal is to test the difference between one explanation and
another in the experiment, it is important that the interface for each is designed to tease apart the
effect of the explanation itself from the effect of its presentation. (Note that good presentation and
visualization are an important but different object of study.) Moreover, if using the interface requires
training, it is important to deliver the training to users in a way that is neutral across testing
conditions. For example, how the end-task and goals of the study are described during training (e.g.,
with practice questions) will have a large impact on how users approach the task.


Baselines. The mere presence of an explanation may change the way in which people interact
with an ML system. Thus, it is often important to consider how a human performs with no ML
system, with an ML system and no explanation, with an ML system and a placebo explanation (an
explanation that provides no information), and with an ML system and hand-crafted explanations
(manually generated by humans who are presumably good communicators).


Experimental Design and Hypothesis Testing. Independent and dependent variables, hypotheses,
and inclusion and exclusion criteria must be clearly defined prior to the start of the study.
For example, suppose that one hypothesizes that a particular explanation will help a developer
debug an image classifier. In this case, the independent variable would be the form of assistance: the
particular explanation, competing explanation methods, and the baselines above. The dependent
variable would be whether the developer can identify bugs. Inclusion and exclusion criteria might
include a requirement that the developer has sufficient experience training image classifiers (as
determined by an initial survey or a pre-test), demonstrates engagement (as measured by a base
level of performance on practice rounds), and does not have prior experience with the particular
explanation types (as determined by an initial survey). Another exclusion criterion could be the removal
of outliers. For example, one could decide, in advance, to exclude data from any participant who takes
an unusually long or short time to perform the task, as a proxy for lack of engagement.
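The point of deciding such rules in advance is that they can be written down and applied mechanically once the data are in. Below is a minimal sketch.

- The specific rule is hypothetical: a window on each participant's median task time, plus a minimum practice-round accuracy.
- All field names are made up for illustration.

```python
from dataclasses import dataclass
from statistics import median
from typing import List, Tuple


@dataclass(frozen=True)
class ExclusionRule:
    """Fixed before data collection; all thresholds here are hypothetical."""
    min_median_seconds: float     # faster suggests clicking through
    max_median_seconds: float     # slower suggests disengagement
    min_practice_accuracy: float  # base level of performance on practice rounds


@dataclass
class Participant:
    pid: str
    task_seconds: List[float]
    practice_accuracy: float


def is_included(p: Participant, rule: ExclusionRule) -> bool:
    t = median(p.task_seconds)
    return (rule.min_median_seconds <= t <= rule.max_median_seconds
            and p.practice_accuracy >= rule.min_practice_accuracy)


def apply_rule(participants: List[Participant],
               rule: ExclusionRule) -> Tuple[List[Participant], List[str]]:
    """Split participants into (kept, ids of excluded) under the pre-registered rule."""
    kept = [p for p in participants if is_included(p, rule)]
    excluded = [p.pid for p in participants if not is_included(p, rule)]
    return kept, excluded


rule = ExclusionRule(min_median_seconds=5.0, max_median_seconds=300.0,
                     min_practice_accuracy=0.6)
kept, excluded = apply_rule(
    [Participant("a", [30, 40, 50], 0.9), Participant("b", [2, 3, 1], 0.9)], rule)
print(excluded)  # ['b']
```

Reporting the excluded IDs alongside the rule makes it easy to show that the exclusions were not chosen after seeing the results.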

As noted in Section 33.2, there are many decisions that go into any interpretable machine learning
method, and each context is nuanced. Studies of the form “Does explanation X (computed via some
pipeline Y) help users in context Z compared to explanation X′?” may not provide much insight
as to why that particular explanation is better or worse, making it harder not only to iterate on
a particular explanation but also to generalize to other explanations or contexts. There are many
factors of potential variation in the results, ranging from the properties of the explanation and its
presentation to the difficulty of the task.

To reduce this variance, and to get more useful and generalizable insights, we can manipulate
some factors of variation directly. For example, suppose the research question is whether complete
explanations are better than incomplete explanations in a particular context. One might write out
hand-crafted explanations that are complete in what features they implicate, explanations in which
one important feature is missing, and explanations in which several important features are missing.
Doing so ensures even coverage of the different experimental regimes of interest, which may not occur
if the explanations were simply output from a pipeline. As another example, one might intentionally
create an image classifier with known bugs, or simply pretend to have an image classifier that makes
certain predictions (as done in [Ade+20b]). These kinds of studies are called Wizard-of-Oz studies,
and they can help us more precisely uncover the science of why an explanation is useful (e.g., as done
in [Jac+21]).

Once the independent and dependent variables, hypotheses, and participant criteria (including
how the independent and dependent variables may be manipulated) are determined, the next step
is setting up the study design itself. Broadly speaking, randomization marginalizes over potential
confounds. For example, randomization in assigning subjects to tasks marginalizes out the subject’s
prior knowledge; randomization in the order of tasks marginalizes out learning effects. Matching and
repeated measures reduce variance. An example of matching would be asking the same subject to
perform the same end-task with two different explanations. An example of repeated measures would
be asking the subject to perform the end-task for several different inputs.
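The two kinds of randomization mentioned above can be sketched with Python's seeded `random.Random`. This is a minimal sketch, and the condition names are hypothetical labels for the baselines listed earlier.

- Participants are shuffled and then dealt to conditions round-robin, so group sizes stay balanced.
- Each participant independently gets a shuffled task order.

```python
import random
from typing import Dict, List, Sequence


def assign_conditions(participant_ids: Sequence[str], conditions: Sequence[str],
                      seed: int) -> Dict[str, str]:
    """Balanced random assignment: shuffle participants, then deal conditions
    round-robin so that group sizes differ by at most one."""
    rng = random.Random(seed)  # record the seed so the assignment can be audited
    ids = list(participant_ids)
    rng.shuffle(ids)
    return {pid: conditions[k % len(conditions)] for k, pid in enumerate(ids)}


def task_orders(participant_ids: Sequence[str], tasks: Sequence[str],
                seed: int) -> Dict[str, List[str]]:
    """An independent random task order per participant, to marginalize out learning effects."""
    rng = random.Random(seed)
    orders = {}
    for pid in participant_ids:
        order = list(tasks)
        rng.shuffle(order)
        orders[pid] = order
    return orders


# Hypothetical between-subjects conditions mirroring the baselines discussed earlier.
CONDITIONS = ["no_ml", "ml_no_explanation", "ml_placebo_explanation", "ml_explanation"]

pids = [f"p{i}" for i in range(12)]
assignment = assign_conditions(pids, CONDITIONS, seed=0)
orders = task_orders(pids, ["t1", "t2", "t3"], seed=1)
```

Matching and repeated measures would instead give each participant several conditions or inputs. That turns the assignment above into a within-subjects design.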

Other techniques for designing user studies include block randomized designs or Latin square designs
that randomize the order of explanation types while keeping tasks associated with each explanation
type grouped together. This can be used to marginalize the effects of learning and fatigue without
too much context switching. Careful consideration should be given to what will be compared within
subjects and across subjects. Comparisons of task performance within subjects will have lower
variance but a potential bias from learning effects from the first task to the second. Comparisons
across subjects will have higher variance and also potential bias from population shift during
experimental recruitment. Finally, each of these study designs, as well as the choice of independent
and dependent variables, will imply an appropriate significance test. It is essential to choose the
right test and multiple hypothesis correction to avoid inflated significance values while retaining power.
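Two of the ingredients named above can be sketched concretely. Both are standard techniques chosen for illustration, not a procedure prescribed by the text.

- **Cyclic Latin square of explanation orders.** Each participant group gets one row, so every explanation type appears once in each position.
- **Holm's step-down correction.** This is one common multiple-hypothesis correction; it controls the family-wise error rate.

A simple cyclic square balances position but not first-order carryover; a Williams design is needed for that.

```python
from typing import List, Sequence


def latin_square(conditions: Sequence[str]) -> List[List[str]]:
    """Cyclic Latin square: row r is the presentation order for participant group r.
    Every condition appears exactly once in each row and in each column (position)."""
    n = len(conditions)
    return [[conditions[(r + c) % n] for c in range(n)] for r in range(n)]


def holm_bonferroni(p_values: Sequence[float], alpha: float = 0.05) -> List[bool]:
    """Holm's step-down procedure: returns a reject flag per hypothesis while
    controlling the family-wise error rate at level alpha."""
    m = len(p_values)
    order = sorted(range(m), key=lambda i: p_values[i])
    reject = [False] * m
    for k, i in enumerate(order):
        if p_values[i] <= alpha / (m - k):
            reject[i] = True
        else:
            break  # once one test fails, all larger p-values are retained
    return reject


print(latin_square(["A", "B", "C"]))                # [['A', 'B', 'C'], ['B', 'C', 'A'], ['C', 'A', 'B']]
print(holm_bonferroni([0.01, 0.04, 0.03, 0.005]))  # [True, False, False, True]
```

In the usage line, 0.04 and 0.03 would pass an uncorrected 0.05 threshold. Holm's procedure retains them because four hypotheses are being tested.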
Qualitative Studies. So far, we have described the standard approach for the design of a
quantitative user study: one in which the dependent variable is numerically measured (e.g., time
taken to correctly identify a bug, or the percentage of bugs detected). While quantitative studies provide value by
demonstrating that there is a consistent, quantifiable effect across many users, they usually do not 
tell us why a certain explanation worked. In contrast, qualitative studies, often performed with a 
“think-aloud” or other discussion-based protocol in which users expose their thought process as they 
perform the experiment, can help identify why a particular form of explanation seems to be useful or 
not. The experimenter can gain insights by hearing how the user was using the information, and 
depending on the protocol, can ask for clarifications. 

For example, suppose one is interested in how people use an example-based explanation to 
understand a video-game agent’s policy. The idea is to show a few video clips of an automated 
agent in the video game, and then ask the user what the agent might do in novel situations. In a 
think-aloud study, the user would perform this task while talking through how they are connecting 
the videos they have seen to the new situation. By hearing these thoughts, a researcher might not 
only gain deeper insight into how users make these connections—e.g. users might see the agent 
collect coins in one video and presume that the agent will always go after coins—but they might also 
identify surprising bugs: for example, a user might see the agent fall into a pit and attribute it to
a one-off case of sloppy fingers, not internalizing that an automated agent might make that mistake
every time.

While a participant in a think-aloud study is typically more engaged in the study than they might 
be otherwise (because they are describing their thinking), knowing their thoughts can provide insight 
into the causal process between what information is being provided by the explanation and the 
action that the human user takes, ultimately helping advance the science of how people interact with 
machine-provided information. 


Pilot Studies. The above descriptions are just a very high-level overview of the many factors
that must be designed properly for a high-quality evaluation. In practice, one does not typically get
all of these right the first time. Small-scale pilot studies are essential for checking factors such as
whether participants attend to the provided information in unexpected ways or whether instructions
are clear and well-designed. Modifying the experiments after iterative small-scale pilot studies can
save a lot of time and energy down the road. In these pilots, one should collect not only the usual
information about users and the dependent variables, but also discuss with the participants how they 
approached the study tasks and whether any aspects of the study were confusing. These discussions 
will lead to insights and confidence that the study is testing what it is intended to test. The results 
from pilot studies should not be included in the final results. 

Finally, as the number of factors to test increases (e.g., baselines, independent variables), the study 
design becomes more complex and may require more participants and longer participation times to 
determine if the results are significant—which can in turn increase costs and effects of fatigue. Pilots, 
think-aloud studies, and careful thinking about what aspects of the evaluation require user studies 
and what can be completed computationally can all help distill down a user-based evaluation to the 
most important factors. 
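A power analysis can estimate, before recruiting, how quickly participant needs grow. The sketch below uses the standard normal-approximation formula for a two-sided, two-sample comparison of means:

n per group = 2 (z_{1-α/2} + z_{1-β})² / d²

Here d is an assumed standardized effect size (Cohen's d), and 1 - β is the power.

- Dividing α by the number of planned comparisons (a Bonferroni adjustment) shows how adding conditions raises n.
- This is an approximation: a t-based calculation gives slightly larger numbers.
- The effect sizes are hypothetical.

```python
import math
from statistics import NormalDist


def n_per_group(effect_size: float, alpha: float = 0.05, power: float = 0.8,
                comparisons: int = 1) -> int:
    """Approximate participants per group for a two-sided two-sample test,
    with alpha Bonferroni-adjusted for the number of planned comparisons."""
    z = NormalDist().inv_cdf  # standard normal quantile function
    a = alpha / comparisons
    n = 2 * (z(1 - a / 2) + z(power)) ** 2 / effect_size ** 2
    return math.ceil(n)


print(n_per_group(0.5))                 # 63 per group for a medium effect
print(n_per_group(0.5, comparisons=3))  # more per group once three comparisons are planned
print(n_per_group(0.2))                 # a small effect needs far more participants
```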


33.4.2.3 User Studies in Synthetic Contexts 


It is not always appropriate or possible to test an interpretable machine learning method in the real
context: for example, it would be unethical to test a prototype explanation system on patients each
time one has a new way to convey information about a treatment recommendation. In such cases, we
might want to run an experiment in which clinicians perform a task on made-up patients, or in some
analogous non-medical context where the participant pool is bigger and more affordable. Similarly,
one might create a relatively accessible image classification debugging context where one can control
the incorrect labels, distribution shifts, etc. (e.g., [Ade+20b]) and see what explanations help users
detect problems in this simpler setting. The convenience and scalability of using a simpler setting
could shed light on what properties of explanations are important generally (e.g., for debugging
image classification). For example, we can test how different forms of explanation have different
cognitive loads or how a particular property affects performance with a relatively large pool of
subjects (e.g., [Lag+19]). The same principles we outlined above for user studies in real contexts
continue to apply, but there are some important cautions.

Cautions regarding synthetic contexts. While user studies with synthetic contexts can be
valuable for identifying scientific principles, one must be cautious. For example, experimental subjects
in a synthetic high-stakes context may not treat the stakes of the problem as seriously, may be
relatively unburdened with respect to distractions or other demands on their time and attention (e.g.,
a quiet study environment vs. a chaotic hospital floor), and may ignore important factors of the task
(e.g., clicking through to complete the task as quickly as possible). Moreover, small differences in task
definition can have big effects: even the difference between asking users to simply perform a task
with an explanation available and asking users to answer some questions about the explanation first
may produce very different results, as the latter forces the user to pay attention to the explanation and
the former does not. Priming users by giving them a specific scenario in which they can put themselves
into a mindset could help. For example: “Imagine now you are an engineer at a company selling a
risk calculator. A deadline is approaching and your boss wants to make sure the product will work
for a new client. Describe how you would use the following explanation.”

33.5 Discussion: How to Think about Interpretable Machine Learning


Interpretable machine learning is a young, interdisciplinary field of study. As a result, consensus
on definitions, evaluation methods, and appropriate abstractions is still forming. The goal of this
section is to lay out a core set of principles about interpretable machine learning. While specifics in
the previous sections may change, the principles below will be durable.

There is no universal, mathematical definition of interpretability, and there never will be. Defining
a downstream performance metric (and justifying it) for each context is a must. The information
that best communicates to the human what is needed to perform a task will necessarily vary: for
example, what a clinical expert needs to determine whether to try a new treatment policy is very
different from what a person needs to determine how to get a denied loan approved. Similarly, methods
to communicate characteristics of models built on pixel data may not be appropriate for communicating
characteristics of models built on language data. We may hope to identify desired properties in
explanations to maximize downstream task performance for different classes of end tasks (that is the
grand challenge of interpretable machine learning), but there will never be one metric for all contexts.

While this lack of a universal metric may feel disappointing, other areas of machine learning
also lack universal metrics. For example, not only is it impossible to satisfy the many metrics of
fairness at the same time [KMR16], but also, in a particular situation, none may exactly match the
desires of the stakeholders. Even in a standard classification setting, there are many metrics that
correspond to making the predicted and true labels as close as possible. Does one care about overall
accuracy? Precision? Recall? It is unlikely that one objective captures everything that is needed
in one situation, much less across different contexts. Evaluation can still be rigorous as long as
assumptions and requirements are made precise.
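A small example shows why the choice among accuracy, precision, and recall matters. On the hypothetical labels below, two classifiers have identical accuracy but very different precision and recall.

By our own convention, precision and recall are reported as 0.0 when their denominator is zero. Libraries differ on this choice.

```python
from typing import Dict, Sequence, Tuple


def confusion(y_true: Sequence[int], y_pred: Sequence[int],
              positive: int = 1) -> Tuple[int, int, int, int]:
    """Return (true positives, false positives, false negatives, true negatives)."""
    tp = sum(t == positive and p == positive for t, p in zip(y_true, y_pred))
    fp = sum(t != positive and p == positive for t, p in zip(y_true, y_pred))
    fn = sum(t == positive and p != positive for t, p in zip(y_true, y_pred))
    tn = len(y_true) - tp - fp - fn
    return tp, fp, fn, tn


def metrics(y_true: Sequence[int], y_pred: Sequence[int]) -> Dict[str, float]:
    tp, fp, fn, tn = confusion(y_true, y_pred)
    return {
        "accuracy": (tp + tn) / len(y_true),
        "precision": tp / (tp + fp) if tp + fp else 0.0,  # convention: 0.0 if nothing predicted positive
        "recall": tp / (tp + fn) if tp + fn else 0.0,
    }


y_true = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0]         # hypothetical: 2 positives out of 10
always_negative = [0] * 10
flags_four = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
print(metrics(y_true, always_negative))  # accuracy 0.8, precision 0.0, recall 0.0
print(metrics(y_true, flags_four))       # accuracy 0.8, precision 0.5, recall 1.0
```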

What sets interpretable machine learning apart from other areas of machine learning, however, is 
that a large class of evaluations requires human input. As a necessarily interdisciplinary area, rigorous
work in interpretable machine learning requires not only knowledge of computation and statistics but 
also experimental design and user studies. 


Interpretability is only a part of the solution for fairness, calibrated trust, accountability, causality 
and other important problems. Learning models that are fair, safe, causal, or engender calibrated
trust are all goals, whereas interpretability is one means toward those goals.

In some cases, we don’t need interpretability. For example, if the goal can be fully formalized in 
mathematical terms (e.g., a regulatory requirement may mandate that a model satisfy certain fairness
metrics), we do not need any human input. If a model behaves as expected across an exhaustive 
set of pre-defined inputs, then it may be less important to understand how it produced its outputs. 
Similarly, if a model performs well across a variety of regimes, that might (appropriately) increase 
one’s trust in it; if it makes errors, that might (appropriately) decrease trust without an inspection 
of any of the system’s internals. 
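When a requirement is fully formalized, checking it is purely computational and needs no explanation or human judgment. Below is a minimal sketch under hypothetical assumptions.

- The mandated criterion is demographic parity: positive-decision rates across groups may differ by at most `max_gap`.
- The criterion, the tolerance, and the data are illustrative, not a statement about any actual regulation.

```python
from typing import Hashable, Sequence


def positive_rate(decisions: Sequence[int], groups: Sequence[Hashable], g: Hashable) -> float:
    """Fraction of positive (1) decisions among members of group g."""
    members = [d for d, grp in zip(decisions, groups) if grp == g]
    return sum(members) / len(members)


def demographic_parity_gap(decisions: Sequence[int], groups: Sequence[Hashable]) -> float:
    """Largest difference in positive-decision rate between any two groups."""
    rates = [positive_rate(decisions, groups, g) for g in set(groups)]
    return max(rates) - min(rates)


def satisfies_parity(decisions: Sequence[int], groups: Sequence[Hashable],
                     max_gap: float) -> bool:
    return demographic_parity_gap(decisions, groups) <= max_gap


decisions = [1, 0, 1, 1, 1, 0, 0, 0]           # hypothetical model decisions
groups = ["a"] * 4 + ["b"] * 4
print(demographic_parity_gap(decisions, groups))  # 0.5
print(satisfies_parity(decisions, groups, 0.1))   # False
```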

In other cases, human input is needed to achieve the end-task. For example, while there is 
much work in the identification of causal models (see Chapter 36), under many
circumstances, it is not possible to learn a model that is guaranteed to be causal from a dataset alone. 
Here, interpretability could assist the end-task of “Is the model causal?” by allowing a human to 
inspect the model’s prediction process. 

As another example, one could measure the safety of a clinical decision support system by tracking 
how often its recommendations cause harm to patients, and stop using the system if it causes too
much harm. However, if we use this approach to safety, we will only discover that the system is 
unsafe after a significant number of patients have been harmed. Here, interpretability could support 
the end-task of safety by allowing clinical experts to inspect the model’s decision process for red flags 
prior to deployment. 

In general, complex contexts and end-tasks will require a constellation of methods (and people) to 
achieve them. For example, formalizing a complex notion such as accountability will require a broad 
collection of people—from policy makers and ethicists to corporations, engineers, and users—unifying 
vocabularies, exchanging domain knowledge, and identifying goals. Evaluating or monitoring it will 
involve various empirical measures of quality and insights from interpretability. 


Interpretability is not about understanding everything about the model; it is about understanding 
enough to do the end-task. The ultimate measure of an interpretable machine learning method is 
whether it helps the user perform their end-task. Suppose the end-task is to fix an overheating laptop. 
An explanation that lists the likely sources of heat is probably sufficient to address the issue, even 
if one does not know the chemical properties of its components. On the other hand, if the laptop 
keeps freezing up, knowing about the sources of heat may not be the right information. Importantly, 
both end-tasks have clear downstream performance metrics: we can observe whether the information 
helped the user perform actions that make the laptop overheat or freeze up less. 
As another example, consider AlphaGo, Google DeepMind’s AI Go player that beat the human
world champion, Lee Sedol. The model is so complex that one cannot fully understand its decision
process, including surprising moves like its famous move 37 [Met16]. That said, partial probes (e.g.,
does AlphaGo believe the same move would have had a different impact if it had been made earlier in the
game, in a similar position?) might still help a Go expert gain insights into the rationale for the move
in the context of what they already know about the game.

Relatedly, interpretability is distinct from full transparency into the model or knowing the model’s
code. Staring at the weights of every neuron in a large network is likely to be as effective as taking
one’s laptop apart to understand a bug in one’s code. There are many good reasons for open-source
projects and models, but open-source code itself may or may not be sufficient for a user to accomplish
their end-task. For example, a typical user will not be able to reason through 1004 lines of parameters
despite having all the pieces available.


That said, any partial view of a model is, necessarily, only a partial view; it does not tell the full
story. While we just argued that many end-tasks do not require knowing everything about a model,
we also must acknowledge that a partial view does not convey the full model. For example, the set of
features needed to change a loan decision may be the right partial view for a denied applicant, but
convey nothing about whether the model is discriminatory. Any probe will only return what it is
designed to compute (e.g., an approximation of a complex function with a simpler one). Different
probes may be able to reveal different properties at different levels of quality. Incorrectly believing
that the partial view is the full story could result in incorrect insights.


Partial views can lack stability and enable attacks. Relatedly, any explanation that reveals only
certain parts of a model can lack stability (e.g., see [AMJ18a]) and can be more easily attacked
(e.g., see [Yeh+19a; GAZ19; Dom+19; Sla+20]). Especially when models are overparameterized,
such as neural networks, it is possible to learn models whose explanations say one thing (e.g., a
feature is not important, according to some formalization of feature importance) while the model
does another (e.g., uses the prohibited feature). Joint training can also exacerbate the issue, as it
allows the model to learn boundaries that pass some partial-view test while in reality violating the
underlying constraint. Other adversarial approaches can work on the input, minimally perturbing
it to change the explanation’s partial view while keeping the prediction constant, or to change the
prediction while keeping the explanation constant.

These concerns highlight an important open area: we need to improve ways to endow explanations
with the property of translucence, that is, explanations that communicate what they can and cannot
say about the model. Translucence is important because misinterpreted explanations that happen to
favor a user’s views create a false basis for trust.


Trade-offs between inherently interpretable models and performance often do not exist; partial views
can help when they do.

While some have claimed that there exists an inherent trade-off between using an inherently
interpretable model and performance (defined as a model’s performance on some test data), this
trade-off does not always exist in practice, for several reasons [Rud19].

First, in many cases, the data can be surprisingly well-fit by a fairly simple model (due to high
noise, for example) or by a model that can be decomposed into interpretable parts. One can often find a
combination of architecture, regularizer, and optimizer that produces inherently interpretable models
with performance comparable to, or sometimes even better than, blackbox approaches [Wan+17a;
LCG12; Car+15; Let+15b; UR16; FHDV20; KRS14]. In fact, interpretability and performance can
be synergistic: methods for encoding a preference for simpler models (e.g., an L1 regularizer for the
sparsity property) were initially developed to increase performance and avoid overfitting, and interpretable
models are often more robust [RDV 18]. 
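The sketch below shows how an L1 penalty yields exactly-zero weights, and therefore a model that is easier to inspect. It is a minimal lasso solver using cyclic coordinate descent with soft-thresholding.

- It minimizes (1/(2n))‖y - Xw‖² + λ‖w‖₁ with no intercept.
- The toy data are hypothetical.
- It is a teaching sketch, not a tuned implementation.

```python
from typing import List


def soft_threshold(z: float, t: float) -> float:
    """Proximal operator of t*|w|: shrink z toward zero by t, snapping small values to 0."""
    if z > t:
        return z - t
    if z < -t:
        return z + t
    return 0.0


def lasso_cd(X: List[List[float]], y: List[float], lam: float,
             n_iters: int = 200) -> List[float]:
    """Cyclic coordinate descent for (1/(2n)) * ||y - X w||^2 + lam * ||w||_1."""
    n, d = len(X), len(X[0])
    w = [0.0] * d
    col_sq = [sum(X[i][j] ** 2 for i in range(n)) / n for j in range(d)]
    for _ in range(n_iters):
        for j in range(d):
            # Correlation of feature j with the residual that excludes feature j's own contribution.
            rho = sum(
                X[i][j] * (y[i] - sum(X[i][k] * w[k] for k in range(d) if k != j))
                for i in range(n)
            ) / n
            w[j] = soft_threshold(rho, lam) / col_sq[j] if col_sq[j] > 0 else 0.0
    return w


# Hypothetical data: y depends strongly on feature 0, weakly on feature 1, not at all on feature 2.
X = [[1.0, 1.0, 1.0], [1.0, -1.0, -1.0], [-1.0, 1.0, -1.0], [-1.0, -1.0, 1.0]]
y = [3.2, 2.8, -2.8, -3.2]
print(lasso_cd(X, y, lam=0.5))   # about [2.5, 0.0, 0.0]: only one feature survives
print(lasso_cd(X, y, lam=0.01))  # about [2.99, 0.19, 0.0]: weaker penalty keeps the small effect
```

Larger λ trades a little fit for fewer nonzero weights, which is the preference for simpler models mentioned above.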

Second, a narrow focus on the trade-off between using inherently interpretable models and a 
predefined metric of performance, as usually measured on a validation set, overlooks a broader issue: 
that predefined metric of performance may not tell the full story about the quality of the model. For 
example, using an inherently interpretable model may enable a person to realize that a prediction is 
based on confounding, not causation—or other ways it might fail in deployment. In this way, one 
might get better performance with an inherently interpretable model in practice even if a blackbox 
appears to have better performance numbers in validation. An inherently interpretable model may 
also enable better human+model teaming by allowing the human user to step in and override the 
system appropriately. 


Human factors are essential. All machine learning systems ultimately connect to broader
socio-technical contexts. However, many aspects of model construction and optimization can often
be performed in a purely computational setting: there are techniques to check for appropriate model
capacity and techniques for tuning gradient descent or convex optimization. In contrast, interpretable
machine learning must consider human factors from the beginning: there is no point optimizing
an explanation to have various properties if it still fails to improve the user’s performance on the
end-task.

Over-reliance. Just because an explanation is present does not mean that the user will analytically
and reasonably incorporate the information provided into their ultimate decision-making task. The
presence of any explanation can increase a user’s trust in the model, exacerbating the general issue 
of over-trust in human+ML teams. Recent studies have found that even data scientists over-trust 
explanations in unintended ways [Kau+20]; their excitement about the tool led them to take it 
at face-value rather than dig deeper. [LM20] reports a similar finding, noting that inaccurate but 
evocative presentations can create a feeling of comprehension. 

Over-reliance can be combated with explicit measures to force the user to engage analytically 
and skeptically with the information in the explanation. For example, one could ask the user to 
submit their decision first and only then show the recommendation and accompanying explanation,
to pique their interest in why their choice and the recommendation might disagree (and to prompt them
to consider whether they want to change their choice). Another option is to ask the user some basic questions
about the explanation prior to submitting their decision to force them to look at the explanation 
carefully. Yet another option is to provide only the relevant information (the explanation) without 
the recommendation, forcing the user to synthesize the additional information on their own. However, 
in all these cases, there is a delicate balance: users will often be amenable to expending additional 
cognitive effort if they can see it achieves better results, but if they feel the effort is too much, they 
may start ignoring the information entirely. 


Potential for Misuse. A malicious version of over-reliance is when explanations are used to manipulate
a user rather than to facilitate the user’s end-task. Further, users may report that they like
explanations that are simple, require little cognitive effort, etc., even when those explanations do not
help them perform their end-task. As creators of interpretable machine learning methods, we must
be on alert to ensure that the explanations help the user achieve what they want to (ideally in a way
that they also like).

[Figure panel: Original Image]
Figure 33.6: (Potential) perception issues: an explanation from a trained network (left) is visually
indistinguishable to humans from one from an untrained network (right), even if they are not exactly identical.


Misunderstandings from a lack of understanding of machine learning. Even when correctly engaged,
users in different contexts will have different levels of knowledge about machine learning. For example,
not everyone may understand concepts such as additive factors or Shapley values [Sha16]. Users may
also attribute more understanding to a model than it actually has. For example, if they see a set of
pixels highlighted around a beak, or a set of topic model terms about a disease, they may mistakenly
believe that the machine learning model has some notion of concepts that matches theirs, when the
truth might be quite different.


Related: Perception issues in image explanations. The nature of our visual processing system
adds another layer of nuance when it comes to interpreting and misinterpreting explanations. In
Figure 33.6, two explanations (in terms of important pixels in a bird image) seem to communicate a
similar message; for most people, both explanations seem to suggest that the belly and cheek of the
bird are the important parts for this prediction. However, only one of them is generated from a trained
network (left); the other is from a network that returns random predictions (right). While
the two saliency maps aren’t identical to machines, they look similar because humans don’t parse an
image as pixel values, but as a whole: they see a bird in both pictures.


Another common issue with pixel-based explanations is that explanation creators often multiply the
original image with an importance "mask" (a black-and-clear saliency mask, where a black pixel represents
zero importance and a clear pixel represents maximum importance), introducing the arbitrary artifact
that black objects never appear important [Smi+17]. In addition, this binary mask is produced
by clipping important pixels at a certain percentile (e.g., only taking the 99th percentile), which can
introduce another artifact [Sun+19c]. The balancing act between artifacts introduced by visualization
for ease of understanding and faithfully representing the explanation remains a challenge.
P Together, all of these points on human factors emphasize what we said from the start: we cannot 
45 dvorce the study and practice of interpretable machine learning from its intended socio-technical 


46 context. 


47 
34 Decision making under uncertainty


34.1 Bayesian decision theory 


Bayesian inference provides the optimal way to update our beliefs about hidden quantities H given 
observed data X = x by computing the posterior p(H|x). However, at the end of the day, we 
need to turn our beliefs into actions that we can perform in the world. How can we decide which 
action is best? This is where Bayesian decision theory comes in. In this section, we give a brief 
introduction. For more details, see e.g., [DeG70; KWW22]. 


34.1.1 Basics 


In decision theory, we assume the decision maker, or agent, has a set of possible actions, A, to choose from. Each of these actions has costs and benefits, which will depend on the underlying state of nature h ∈ H. We can encode this information into a loss function ℓ(h, a), which specifies the loss we incur if we take action a ∈ A when the state of nature is h ∈ H.

Once we have specified the loss function, we can compute the posterior expected loss or risk 
for each possible action: 


$$R(a|x) \triangleq \mathbb{E}_{p(h|x)}\left[\ell(h, a)\right] = \sum_{h \in \mathcal{H}} \ell(h, a)\, p(h|x) \tag{34.1}$$


The optimal policy (also called the Bayes estimator) specifies what action to take for each 
possible observation so as to minimize the risk: 


$$\pi^*(x) = \operatorname*{argmin}_{a \in \mathcal{A}} \mathbb{E}_{p(h|x)}\left[\ell(h, a)\right] \tag{34.2}$$


An alternative, but equivalent, way of stating this result is as follows. Let us define a utility function U(h, a) to be the desirability of each possible action in each possible state. If we set U(h, a) = −ℓ(h, a), then the optimal policy is as follows:


$$\pi^*(x) = \operatorname*{argmax}_{a \in \mathcal{A}} \mathbb{E}_{p(h|x)}\left[U(h, a)\right] \tag{34.3}$$


This is called the maximum expected utility principle. 
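When the sets of states and actions are finite, Equations (34.1) to (34.3) reduce to a weighted sum over states followed by an argmin or an argmax. Below is a minimal sketch in plain Python. The posterior and loss values are invented for illustration, and the function names (`posterior_risk`, `bayes_action`, `meu_action`) are our own.

```python
def posterior_risk(posterior, loss):
    """R(a|x) = sum_h l(h, a) p(h|x) for every action a (Eq. 34.1).

    posterior[h] = p(h|x); loss[h][a] = l(h, a).
    """
    n_actions = len(loss[0])
    return [sum(p_h * loss[h][a] for h, p_h in enumerate(posterior))
            for a in range(n_actions)]


def bayes_action(posterior, loss):
    """Optimal action under a loss: argmin_a R(a|x) (Eq. 34.2)."""
    risk = posterior_risk(posterior, loss)
    return min(range(len(risk)), key=risk.__getitem__)


def meu_action(posterior, utility):
    """Optimal action under a utility: argmax_a E[U(h, a)] (Eq. 34.3)."""
    n_actions = len(utility[0])
    eu = [sum(p_h * utility[h][a] for h, p_h in enumerate(posterior))
          for a in range(n_actions)]
    return max(range(n_actions), key=eu.__getitem__)


# Hypothetical numbers: 2 states of nature, 3 actions.
posterior = [0.7, 0.3]
loss = [[0.0, 1.0, 0.2],
        [5.0, 0.0, 0.2]]
utility = [[-l for l in row] for row in loss]      # U(h, a) = -l(h, a)

print(posterior_risk(posterior, loss))             # approximately [1.5, 0.7, 0.2]
print(bayes_action(posterior, loss), meu_action(posterior, utility))
```

Because U(h, a) = −ℓ(h, a), both rules choose the same action (here action 2), which matches the equivalence stated above.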


34.1.2 Classification

Suppose the states of nature correspond to class labels, so H = Y = {1, ..., C}. Furthermore, suppose the actions also correspond to class labels, so A = Y. In this setting, a very commonly used loss function is the zero-one loss ℓ_01(y*, ŷ), defined as follows:

$$\begin{array}{c|cc} & \hat{y} = 0 & \hat{y} = 1 \\ \hline y^* = 0 & 0 & 1 \\ y^* = 1 & 1 & 0 \end{array} \tag{34.4}$$

We can write this more concisely as follows:

$$\ell_{01}(y^*, \hat{y}) = \mathbb{I}(y^* \neq \hat{y}) \tag{34.5}$$

In this case, the posterior expected loss is

$$R(\hat{y}|x) = \sum_{y^* \neq \hat{y}} p(y^*|x) = 1 - p(y^* = \hat{y}|x) \tag{34.6}$$

Hence the action that minimizes the expected loss is to choose the most probable label:

$$\pi^*(x) = \operatorname*{argmax}_{y \in \mathcal{Y}} p(y|x) \tag{34.7}$$

This corresponds to the mode of the posterior distribution, also known as the maximum a posteriori or MAP estimate.

We can generalize the loss function to associate different costs with false positives and false negatives. We can also allow for a “reject action”, in which the decision maker abstains from classifying when it is not sufficiently confident. This is called selective prediction; see Section 19.3.3 for details.
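The sketch below (plain Python, with hypothetical numbers and helper names) checks three things. First, the zero-one loss recovers the MAP label of Equation (34.7). Second, an asymmetric loss matrix can change the decision. Third, an extra “reject” action with a fixed cost, chosen arbitrarily here, makes the agent abstain when no label is probable enough.

```python
def min_risk_action(probs, loss):
    """argmin_a sum_y p(y|x) loss[y][a]; columns of loss may include extra actions."""
    n_actions = len(loss[0])
    risk = [sum(p * loss[y][a] for y, p in enumerate(probs)) for a in range(n_actions)]
    return min(range(n_actions), key=risk.__getitem__)


def zero_one_loss(C):
    """loss[y][yhat] = 1 if y != yhat else 0 (Eq. 34.5)."""
    return [[0.0 if y == yhat else 1.0 for yhat in range(C)] for y in range(C)]


def with_reject(loss, reject_cost):
    """Append a 'reject' action (index C) whose cost ignores the true label."""
    return [row + [reject_cost] for row in loss]


probs = [0.2, 0.5, 0.3]            # hypothetical posterior p(y|x), C = 3 classes
C = len(probs)

# Zero-one loss: the Bayes action is the MAP label.
assert min_risk_action(probs, zero_one_loss(C)) == max(range(C), key=probs.__getitem__)

# Asymmetric costs (hypothetical): missing a true class 2 costs 10 instead of 1.
costly = zero_one_loss(C)
costly[2] = [10.0, 10.0, 0.0]
print(min_risk_action(probs, costly))                          # 2

# Reject option: abstaining (action 3) costs 0.4, which beats the best label's risk of 0.5.
print(min_risk_action(probs, with_reject(zero_one_loss(C), 0.4)))   # 3
```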

34.1.3 Regression 

Now suppose the hidden state of nature is a scalar h ∈ ℝ, and the corresponding action is also a scalar, y ∈ ℝ. The most common loss for continuous states and actions is the ℓ_2 loss, also called squared error or quadratic loss, which is defined as follows:

$$\ell_2(h, y) = (h - y)^2 \tag{34.8}$$

In this case, the risk is given by


$$R(y|x) = \mathbb{E}\left[(h - y)^2 | x\right] = \mathbb{E}\left[h^2 | x\right] - 2y\,\mathbb{E}\left[h | x\right] + y^2 \tag{34.9}$$


The risk is a convex quadratic in y, so the optimal action must satisfy the condition that the derivative of the risk (at that point) is zero (as explained in Chapter 6). Hence the optimal action is to pick the posterior mean:

$$\frac{\partial}{\partial y} R(y|x) = -2\,\mathbb{E}\left[h|x\right] + 2y = 0 \;\Rightarrow\; \pi^*(x) = \mathbb{E}\left[h|x\right] = \int h\, p(h|x)\, dh \tag{34.10}$$

This is often called the minimum mean squared error estimate or MMSE estimate.
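As a numerical check of Equation (34.10), the sketch below represents a hypothetical, deliberately skewed posterior p(h|x) by Monte Carlo samples. An exponential distribution is used purely for illustration. The code then searches a grid of actions for the smallest empirical squared-error risk. The empirical risk equals the sample variance plus (sample mean − y)², so the grid point that minimizes it is the one nearest the sample mean.

```python
import random

random.seed(0)
# Hypothetical posterior p(h|x): an exponential distribution represented by samples.
samples = [random.expovariate(1.0) for _ in range(5_000)]


def empirical_risk(y, hs):
    """Monte Carlo estimate of R(y|x) = E[(h - y)^2 | x] (Eq. 34.9)."""
    return sum((h - y) ** 2 for h in hs) / len(hs)


post_mean = sum(samples) / len(samples)
grid = [i / 100 for i in range(0, 301)]            # candidate actions y in [0, 3]
best_on_grid = min(grid, key=lambda y: empirical_risk(y, samples))
print(round(post_mean, 3), best_on_grid)           # agree up to the grid spacing
```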
Figure 34.1: Spectrograms for three different spoken sentences. The x-axis shows progression of time and the y-axis shows different frequency bands. The energy of the signal in different bands is shown as intensity in grayscale values with progression of time. (A) and (B) show spectrograms of the same sentence “How to recognize speech with this new display” spoken by two different speakers, male and female. Although the frequency characterization is similar, the formant frequencies are much more clearly defined in the speech of the female speaker. (C) shows the spectrogram of the utterance “How to wreck a nice beach with this nudist play” spoken by the same female speaker as in (B). (A) and (B) are not identical even though they are composed of the same words. (B) and (C) are similar to each other even though they are not the same sentences. From Figure 1.2 of [Gan07]. Used with kind permission of Madhavi Ganapathiraju.


34.1.4 Structured prediction 


In some problems, such as natural language processing or computer vision, the desired action is to return an output object y ∈ Y, such as a set of labels or body poses, that not only is probable given the input x, but is also internally consistent. For example, suppose x is a sequence of phonemes and y is a sequence of words. Although x might sound more like y = “How to wreck a nice beach” on a word-by-word basis, if we take the sequence of words into account then we may find (under a language model prior) that y = “How to recognize speech” is more likely overall. (See Figure 34.1.) We can capture this kind of dependency amongst outputs, given inputs, using a structured prediction model, such as a conditional random field (see Section 4.4).

In addition to modeling dependencies in p(y|x), we may prefer certain action choices ŷ, which we capture in the loss function ℓ(y, ŷ). For example, referring to Figure 34.1, we may be reluctant to assume the user said ŷ_t = “nudist” at step t unless we are very confident of this prediction, since the cost of mis-categorizing this word may be higher than for other words.

Given a loss function, we can pick the optimal action using minimum Bayes risk decoding: 


$$\hat{y} = \operatorname*{argmin}_{\hat{y} \in \mathcal{Y}} \sum_{y \in \mathcal{Y}} p(y|x)\, \ell(y, \hat{y}) \tag{34.11}$$


We can approximate the expectation empirically by sampling M solutions y^m ~ p(y|x) from the posterior predictive distribution. (Ideally these are diverse from each other.) We use the same set of M samples to approximate the minimization to get

$$\hat{y} \approx \operatorname*{argmin}_{y^i,\; i \in \{1, \ldots, M\}} \sum_{j \in \{1, \ldots, M\}} p(y^j|x)\, \ell(y^j, y^i) \tag{34.12}$$

This is called empirical MBR; [Pre+17a] applied it to computer vision problems. A similar approach was adopted in [Fre+22], who applied it to neural machine translation.
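Once a loss between outputs has been chosen, Equation (34.12) is simple to sketch. The candidate outputs, their weights p(y^j|x) and the word-level Hamming loss below are hypothetical choices made for illustration. If the candidates were plain posterior samples, the weights would usually be uniform instead. In this example the MBR output is not the most probable candidate, because it agrees with more of the probability mass.

```python
def hamming(a, b):
    """Word-level loss between two equal-length hypotheses (an illustrative choice)."""
    return sum(x != y for x, y in zip(a, b))


def empirical_mbr(candidates, weights, loss):
    """Return the y^i minimizing sum_j w_j * loss(y^j, y^i) over the candidates (Eq. 34.12)."""
    def risk(i):
        return sum(w * loss(y_j, candidates[i]) for y_j, w in zip(candidates, weights))
    return candidates[min(range(len(candidates)), key=risk)]


# Hypothetical hypotheses and posterior weights p(y^j|x).
cands = [
    ("wreck", "a", "beach"),
    ("recognize", "speech", "today"),
    ("recognize", "speech", "now"),
]
weights = [0.4, 0.32, 0.28]

print(empirical_mbr(cands, weights, hamming))    # ('recognize', 'speech', 'today')
```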


34.1.5 Fairness


Models trained with ML are increasingly being used in high-stakes applications, such as deciding whether someone should be released from prison or not. In such applications, it is important that we focus not only on accuracy, but also on fairness. A variety of definitions for what is meant by fairness have been proposed (see e.g., [VR18]), many of which entail conflicting goals [Kle18]. Below we mention a few common definitions, which can all be interpreted decision theoretically.

We consider a binary classification problem with true label Y, predicted label Ŷ and sensitive attribute S (such as gender or race). The concept of equal opportunity requires equal true positive rates across subgroups, i.e., p(Ŷ = 1|Y = 1, S = 0) = p(Ŷ = 1|Y = 1, S = 1). The concept of equal odds requires equal true positive rates across subgroups, and also equal false positive rates across subgroups, i.e., p(Ŷ = 1|Y = 0, S = 0) = p(Ŷ = 1|Y = 0, S = 1). The concept of statistical parity requires positive predictions to be unaffected by the value of the protected attribute, regardless of the true label, i.e., p(Ŷ = 1|S = 0) = p(Ŷ = 1|S = 1).

A simple and generic way to achieve these fairness goals is to use constrained Bayesian optimization (Section 6.8), in which we maximize accuracy subject to achieving one or more of the above goals. See [Per+21] for details.
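Each of the three criteria above compares a conditional rate between the two subgroups. The following sketch (plain Python, with toy data and hypothetical function names) estimates those rates from labels, predictions and the sensitive attribute, then reports the absolute gaps. Summarizing equal odds by the larger of its two gaps is our own choice and is not part of the definitions.

```python
def group_rates(y_true, y_pred, s, group):
    """Empirical TPR, FPR and positive-prediction rate within subgroup S = group."""
    idx = [i for i in range(len(s)) if s[i] == group]
    pos = [i for i in idx if y_true[i] == 1]
    neg = [i for i in idx if y_true[i] == 0]
    tpr = sum(y_pred[i] for i in pos) / len(pos)     # p(Yhat=1 | Y=1, S=group)
    fpr = sum(y_pred[i] for i in neg) / len(neg)     # p(Yhat=1 | Y=0, S=group)
    ppr = sum(y_pred[i] for i in idx) / len(idx)     # p(Yhat=1 | S=group)
    return tpr, fpr, ppr


def fairness_gaps(y_true, y_pred, s):
    tpr0, fpr0, ppr0 = group_rates(y_true, y_pred, s, 0)
    tpr1, fpr1, ppr1 = group_rates(y_true, y_pred, s, 1)
    return {
        "equal_opportunity": abs(tpr0 - tpr1),                   # TPR gap
        "equal_odds": max(abs(tpr0 - tpr1), abs(fpr0 - fpr1)),   # TPR and FPR gaps
        "statistical_parity": abs(ppr0 - ppr1),                  # positive-rate gap
    }


# Toy data (hypothetical): 8 people, sensitive attribute s in {0, 1}.
y_true = [1, 1, 0, 0, 1, 1, 0, 0]
y_pred = [1, 0, 0, 0, 1, 1, 1, 1]
s      = [0, 0, 0, 0, 1, 1, 1, 1]
print(fairness_gaps(y_true, y_pred, s))
```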

34.2 Decision (influence) diagrams


When dealing with structured multi-stage decision problems, it is useful to use a graphical notation called an influence diagram [HM81; KM08], also called a decision diagram. This extends directed probabilistic graphical models (Chapter 4) by adding decision nodes (also called action nodes), represented by rectangles, and utility nodes (also called value nodes), represented by diamonds. The original random variables are called chance nodes, and are represented by ovals, as usual.


34.2.1 Example: oil wildcatter 


As an example (from [Rai68]), consider creating a model for the decision problem faced by an “oil wildcatter”, which is a person who drills wildcat wells, i.e., exploration wells drilled in areas not known to be oil fields.

Suppose you have to decide whether to drill an oil well or not at a given location. You have two possible actions: d = 1 means drill, d = 0 means don’t drill. You assume there are 3 states of nature: o = 0 means the well is dry, o = 1 means it is wet (has some oil), and o = 2 means it is soaking (has a lot of oil). We can represent this as a decision diagram as shown in Figure 34.2(a).

Suppose your prior beliefs are p(o) = [0.5, 0.3, 0.2], and your utility function U(d, o) is specified by the following table:



Figure 34.2: Influence diagrams for the oil wildcatter problem. Ovals are random variables (chance nodes), squares are decision (action) nodes, diamonds are utility (value) nodes. (a) Basic model. (b) An extension in which we have an information arc from the Sound chance node to the Drill decision node. (c) An extension in which we get to decide whether to perform a test or not, as well as whether to drill or not.


        o = 0   o = 1   o = 2
d = 0      0       0       0
d = 1    -70      50     200

We see that if you don’t drill, you incur no costs, but also make no money. If you drill a dry well, 
you lose $70; if you drill a wet well, you gain $50; and if you drill a soaking well, you gain $200. 
What action should you take if you have no information beyond your prior knowledge? Your prior 
expected utility for taking action d is 


$$\mathrm{EU}(d) = \sum_{o=0}^{2} p(o)\, U(d, o) \tag{34.13}$$

We find EU(d = 0) = 0 and EU(d = 1) = 0.5 × (−70) + 0.3 × 50 + 0.2 × 200 = 20, and hence the maximum expected utility is

$$\mathrm{MEU} = \max\{\mathrm{EU}(d=0), \mathrm{EU}(d=1)\} = \max\{0, 20\} = 20 \tag{34.14}$$


Thus the optimal action is to drill, d* = 1. 
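The prior analysis above takes only a few lines of code. Here is a minimal sketch in plain Python that uses the prior and utility table given in this section. The variable names are our own.

```python
prior = [0.5, 0.3, 0.2]              # p(o) for o = 0 (dry), 1 (wet), 2 (soaking)
U = {0: [0, 0, 0],                   # U(d=0, o): don't drill
     1: [-70, 50, 200]}              # U(d=1, o): drill


def expected_utility(d):
    """EU(d) = sum_o p(o) U(d, o), Eq. (34.13)."""
    return sum(p * u for p, u in zip(prior, U[d]))


eu = {d: expected_utility(d) for d in U}
d_star = max(eu, key=eu.get)         # action achieving the MEU, Eq. (34.14)
print(eu, d_star)                    # EU(1) is 20 up to float rounding; d* = 1
```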


34.2.2 Information arcs 


Now let us consider a slight extension to the model, in which you have access to a measurement (called a “sounding”), which is a noisy indicator about the state of the oil well. Hence we add an O → S arc to the model. In addition, we assume that the outcome of the sounding test will be available before we decide whether to drill or not; hence we add an information arc from S to D. This is illustrated in Figure 34.2(b). Note that the utility depends on the action and the true state of the world, but not the measurement.

We assume the sounding variable can be in one of 3 states: s = 0 is a diffuse reflection pattern, suggesting no oil; s = 1 is an open reflection pattern, suggesting some oil; and s = 2 is a closed reflection pattern, indicating lots of oil. Since S is caused by O, we add an O → S arc to our model. Let us model the reliability of our sensor using the following conditional distribution for p(S|O):
        s = 0   s = 1   s = 2
o = 0    0.6     0.3     0.1
o = 1    0.3     0.4     0.3
o = 2    0.1     0.4     0.5


Suppose the sounding observation is s. The posterior expected utility of performing action d is 


$$\mathrm{EU}(d|s) = \sum_{o=0}^{2} p(o|s)\, U(d, o) \tag{34.15}$$

We need to compute this for each possible observation, s ∈ {0, 1, 2}, and each possible action, d ∈ {0, 1}. If s = 0, we find the posterior over the oil state is p(o|s = 0) = [0.732, 0.219, 0.049], and hence EU(d = 0|s = 0) = 0 and EU(d = 1|s = 0) = -30.5. If s = 1, we similarly find EU(d = 0|s = 1) = 0 and EU(d = 1|s = 1) = 32.9. If s = 2, we find EU(d = 0|s = 2) = 0 and EU(d = 1|s = 2) = 87.5. Hence the optimal policy d*(s) is as follows: if s = 0, choose d = 0 and get $0; if s = 1, choose d = 1 and get $32.9; and if s = 2, choose d = 1 and get $87.5.
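The sketch below (plain Python, with our own variable names) repeats these posterior computations. It starts from the prior, the utility table and the sensor model p(S|O) above, and applies Bayes' rule for each sounding s.

```python
prior = [0.5, 0.3, 0.2]                  # p(o) for o = 0 (dry), 1 (wet), 2 (soaking)
U_drill = [-70, 50, 200]                 # U(d=1, o); U(d=0, o) = 0 for every o
p_s_given_o = [[0.6, 0.3, 0.1],          # p(s|o): rows indexed by o, columns by s
               [0.3, 0.4, 0.3],
               [0.1, 0.4, 0.5]]


def posterior_over_oil(s):
    """p(o|s) by Bayes' rule: p(o) p(s|o) / sum_o' p(o') p(s|o')."""
    joint = [prior[o] * p_s_given_o[o][s] for o in range(3)]
    z = sum(joint)                       # this is p(s)
    return [j / z for j in joint]


policy, eu_drill = {}, {}
for s in range(3):
    post = posterior_over_oil(s)
    eu_drill[s] = sum(p * u for p, u in zip(post, U_drill))   # EU(d=1|s), Eq. (34.15)
    policy[s] = 1 if eu_drill[s] > 0 else 0                   # compare with EU(d=0|s) = 0
    print(s, [round(p, 3) for p in post], round(eu_drill[s], 1), policy[s])
```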

The maximum expected utility of the wildcatter, before seeing the experimental sounding, can be 
computed using 


$$\mathrm{MEU} = \sum_{s} p(s)\, \mathrm{EU}(d^*(s)|s) \tag{34.16}$$

where the prior marginal on the outcome of the test is p(s) = Σ_o p(o) p(s|o) = [0.41, 0.35, 0.24]. Hence the MEU is


$$\mathrm{MEU} = 0.41 \times 0 + 0.35 \times 32.9 + 0.24 \times 87.5 \approx 32.5 \tag{34.17}$$

These numbers can be summarized in the decision tree shown in Figure 34.3.

34.2.3 Value of information

Now suppose you can choose whether to do the test or not. This can be modelled as shown in Figure 34.2(c), where we add a new test node T. If T = 1, we do the test, and S can enter states {0, 1, 2}, determined by O, exactly as above. If T = 0, we don’t do the test, and S enters a special unknown state. There is also some cost associated with performing the test.

Is it worth doing the test? This depends on how much our MEU changes if we know the outcome of the test (namely the state of S). If you don’t do the test, you have MEU = 20 from Equation (34.14). If you do the test, you have MEU = 32.5 from Equation (34.17). So the improvement in utility if you do the test (and act optimally on its outcome) is $12.5. This is called the value of perfect information (VPI). So we should do the test as long as it costs less than $12.5.

In terms of graphical models, the VPI of a variable S can be determined by computing the MEU for the base influence diagram G (the model of Figure 34.2(b) without the information arc from S to D), then computing the MEU for the same influence diagram with an information arc added from S to the action node, and finally taking the difference. In other words,

$$\mathrm{VPI} = \mathrm{MEU}(G + S \to D) - \mathrm{MEU}(G) \tag{34.18}$$


where D is the decision node and S is the variable we are measuring. This will tell us whether it is worth obtaining measurement S.
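On a problem this small, Equation (34.18) can be checked by brute force. Without the information arc, the drilling decision must be the same for every sounding. With the arc, any of the 2³ deterministic rules from s to d is allowed. The sketch below (plain Python, our own naming) enumerates both sets of policies. Exhaustive enumeration like this is only practical for very small diagrams.

```python
from itertools import product

prior = [0.5, 0.3, 0.2]                          # p(o)
U = [[0, 0, 0], [-70, 50, 200]]                  # U[d][o]
p_s_given_o = [[0.6, 0.3, 0.1], [0.3, 0.4, 0.3], [0.1, 0.4, 0.5]]


def eu_of_policy(policy):
    """Expected utility sum_{o,s} p(o) p(s|o) U(policy[s], o) of a rule s -> d."""
    return sum(prior[o] * p_s_given_o[o][s] * U[policy[s]][o]
               for o in range(3) for s in range(3))


# MEU(G): D cannot see S, so only constant policies are available.
meu_without = max(eu_of_policy((d, d, d)) for d in (0, 1))
# MEU(G + S -> D): every deterministic mapping from s to d is available.
meu_with = max(eu_of_policy(pol) for pol in product((0, 1), repeat=3))
vpi = meu_with - meu_without                     # Eq. (34.18)
print(round(meu_without, 2), round(meu_with, 2), round(vpi, 2))   # 20.0 32.5 12.5
```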


Figure 34.3: Decision tree for the oil wildcatter problem. Black circles are chance variables, black squares are 
decision nodes, diamonds are the resulting utilities. Green leaf nodes have higher utility than red leaf nodes. 


34.2.4 Computing the optimal policy 


In general, given an influence diagram, we can compute the optimal policy automatically by modifying the variable elimination algorithm (Section 9.4), as explained in [LN01; KM08]. The basic idea is to work backwards from the final action, computing the optimal decision at each step, assuming all following actions are chosen optimally. When the influence diagram has a simple chain structure, as in a Markov decision process (Section 34.5), the result is equivalent to Bellman’s equation (Section 34.5.5).
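The idea of working backwards from the final action can be shown on the version of the oil problem with a test decision (Figure 34.2(c)). The sketch below is a hand-written backward induction for that small case. It is not the general variable elimination algorithm of [LN01; KM08]. For illustration only, we assume the test cost is subtracted directly from the utility; the function names and cost values are hypothetical.

```python
prior = [0.5, 0.3, 0.2]                  # p(o)
U_drill = [-70, 50, 200]                 # U(d=1, o); not drilling is worth 0
p_s_given_o = [[0.6, 0.3, 0.1], [0.3, 0.4, 0.3], [0.1, 0.4, 0.5]]


def best_drill_value(belief):
    """Last decision (Drill), solved first: max_d E[U(d, o)] under a belief over o."""
    return max(0.0, sum(b * u for b, u in zip(belief, U_drill)))


def solve_test_decision(test_cost):
    """Earlier decision (Test), solved second, assuming Drill is chosen optimally later.

    Returns (t_star, value), where t_star = 1 means "do the test".
    """
    value_no_test = best_drill_value(prior)        # S stays in its "unknown" state
    value_test = -test_cost                        # hypothetical: cost paid up front
    for s in range(3):
        joint = [prior[o] * p_s_given_o[o][s] for o in range(3)]   # p(o, s)
        p_s = sum(joint)
        value_test += p_s * best_drill_value([j / p_s for j in joint])
    if value_test > value_no_test:
        return 1, value_test
    return 0, value_no_test


for cost in (5.0, 10.0, 15.0):                     # hypothetical test costs
    print(cost, solve_test_decision(cost))
```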


34.3 A/B testing 


Suppose you are trying to decide which version of a product is likely to sell more, or which version of 
a drug is likely to work better. Let us call the versions you are choosing between A and B; sometimes 
version A is called the control, and version B is called the treatment. (Sometimes the different 
actions are called “arms”.) 

A very common approach to such problems is to use an A/B test, in which you try both actions out for a while by randomly assigning each action to a different subset of the population, measure the resulting accumulated reward from each action, and pick the winner. (This is sometimes called a “test and roll” approach, since you test which method is best, and then roll it out for the rest of the population.)

A key problem in A/B testing is to come up with a decision rule, or policy, for deciding which action is best, after obtaining potentially noisy results during the test phase. Another problem is to choose how many people to assign to the treatment, n_1, and how many to the control, n_0. The fundamental tradeoff is that using larger values of n_1 and n_0 will help you collect more data and hence be more confident in picking the best action, but this incurs an opportunity cost, because the testing phase involves performing actions that may not result in the highest reward. (This is an example of the exploration-exploitation tradeoff, which we discuss more in Section 34.4.3.) In this section, we give a simple Bayesian decision theoretic analysis of this problem, following the presentation of [FB19].¹ More details on A/B testing can be found in [KTX20].


34.3.1 A Bayesian approach


We assume the i’th reward for action j is given by Y_ij ~ N(μ_j, σ_j²) for i = 1 : n_j and j = 0 : 1, where j = 0 corresponds to the control (action A), j = 1 corresponds to the treatment (action B), and n_j is the number of samples you collect from group j. The parameters μ_j are the expected reward for action j; our goal is to estimate these parameters. (For simplicity, we assume the σ_j² are known.) We will adopt a Bayesian approach, which is well suited to sequential decision problems. For simplicity, we will use Gaussian priors for the unknowns, μ_j ~ N(m_j, τ_j²), where m_j is the prior mean reward for action j, and τ_j is the prior standard deviation, which encodes how uncertain we are about this prior mean. We assume the prior parameters are known. (In practice we can use an empirical Bayes approach, as we discuss in Section 34.3.2.)


34.3.1.1 Optimal policy
Initially we assume the sample size of the experiment (i.e., the values n_1 for the treatment and n_0 for the control) is known. Our goal is to compute the optimal policy or decision rule π(y_1, y_0), which specifies which action to deploy, where y_j = (y_1j, ..., y_{n_j j}) is the data from action j.

The optimal policy is simple: choose the action with the greater posterior expected reward:


$$\pi^*(y_1, y_0) = \begin{cases} 1 & \text{if } \mathbb{E}[\mu_1|y_1] \geq \mathbb{E}[\mu_0|y_0] \\ 0 & \text{if } \mathbb{E}[\mu_1|y_1] < \mathbb{E}[\mu_0|y_0] \end{cases} \tag{34.19}$$


All that remains is to compute the posterior over the unknown parameters, μ_j. Applying Bayes’ rule for Gaussians (Equation (2.82)), we find that the corresponding posterior is given by

$$p(\mu_j | y_j, n_j) = \mathcal{N}(\mu_j | \hat{\mu}_j, \hat{\tau}_j^2) \tag{34.20}$$

$$1/\hat{\tau}_j^2 = n_j/\sigma_j^2 + 1/\tau_j^2 \tag{34.21}$$

$$\hat{\mu}_j/\hat{\tau}_j^2 = n_j \bar{y}_j/\sigma_j^2 + m_j/\tau_j^2 \tag{34.22}$$


Here ȳ_j is the empirical mean of the data from action j. We see that the posterior precision (inverse variance) is the sum of the prior precision and n_j units of measurement precision. We also see that the posterior precision-weighted mean is the sum of the prior precision-weighted mean and the measurement precision-weighted mean.


Given the posterior, we can plug μ̂_j into Equation (34.19). In the fully symmetric case, where n_1 = n_0 = n, m_1 = m_0 = m, τ_1 = τ_0 = τ, and σ_1 = σ_0 = σ, we find that the optimal policy is to simply “pick the winner”, which is the arm with higher empirical performance:

$$\pi^*(y_1, y_0) = \mathbb{I}\left(\frac{m}{\tau^2} + \frac{n \bar{y}_1}{\sigma^2} > \frac{m}{\tau^2} + \frac{n \bar{y}_0}{\sigma^2}\right) = \mathbb{I}(\bar{y}_1 > \bar{y}_0) \tag{34.23}$$

1. For a similar set of results in the time-discounted setting, see https://chris-said.io/2020/01/10/optimizing-sample-sizes-in-ab-testing-part-I.


However, when the problem is asymmetric, we need to take into account the different sample sizes 
and/or different prior beliefs. 
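Below is a minimal sketch of Equations (34.19) to (34.22) in plain Python. The data, priors and noise levels in the usage lines are invented. The first example is asymmetric: the treatment has the higher sample mean, but it has only two samples and a pessimistic prior, so the posterior means favour the control.

```python
import math


def gaussian_posterior(y, sigma, m, tau):
    """Posterior N(mu_hat, tau_hat^2) of mu_j given rewards y (Eqs. 34.20 to 34.22)."""
    n = len(y)
    ybar = sum(y) / n
    post_prec = n / sigma**2 + 1 / tau**2                  # 1 / tau_hat^2
    mu_hat = (n * ybar / sigma**2 + m / tau**2) / post_prec
    return mu_hat, math.sqrt(1 / post_prec)


def pick_arm(y1, y0, sigma1, sigma0, m1, m0, tau1, tau0):
    """pi*(y1, y0) = 1 iff E[mu_1|y1] >= E[mu_0|y0] (Eq. 34.19)."""
    mu1, _ = gaussian_posterior(y1, sigma1, m1, tau1)
    mu0, _ = gaussian_posterior(y0, sigma0, m0, tau0)
    return 1 if mu1 >= mu0 else 0


# Hypothetical data: treatment (arm 1) has 2 samples, control (arm 0) has 50.
y1 = [1.2, 0.9]
y0 = [0.8] * 50
print(gaussian_posterior(y1, 1.0, 0.0, 0.5))            # approximately (0.35, 0.408)
print(pick_arm(y1, y0, 1.0, 1.0, 0.0, 0.8, 0.5, 0.5))   # 0, despite ybar_1 > ybar_0
```

In the symmetric case, the same function reduces to comparing ȳ_1 with ȳ_0, as Equation (34.23) states.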


34.3.1.2 Optimal sample size 


We now discuss how to compute the optimal sample size for each arm of the experiment, i.e., the values n_0 and n_1. We assume the total population size is N, and we cannot reuse people from the testing phase.

The prior expected reward in the testing phase is given by

$$\mathbb{E}[R_{\text{test}}] = n_0 m_0 + n_1 m_1 \tag{34.24}$$


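The expected reward in the roll phase depends on the decision rule π(y_1, y_0) that we use:

\begin{align}
\mathbb{E}_{\pi}[R_{\text{roll}}] &= \int_{\mu_1}\int_{\mu_0}\int_{y_1}\int_{y_0} (N - n_1 - n_0)\left(\pi(y_1, y_0)\,\mu_1 + (1 - \pi(y_1, y_0))\,\mu_0\right) \tag{34.25}\\
&\quad \times p(y_0|\mu_0)\, p(y_1|\mu_1)\, p(\mu_0)\, p(\mu_1)\, dy_0\, dy_1\, d\mu_0\, d\mu_1 \tag{34.26}
\end{align}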
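For π = π*, one can show that this equals

$$\mathbb{E}[R_{\text{roll}}] \triangleq \mathbb{E}_{\pi^*}[R_{\text{roll}}] = (N - n_1 - n_0)\left(m_1 + e\,\Phi\!\left(\frac{e}{v}\right) + v\,\phi\!\left(\frac{e}{v}\right)\right) \tag{34.27}$$

where φ is the Gaussian pdf, Φ is the Gaussian cdf, e = m_0 − m_1 and

$$v = \sqrt{\frac{\tau_1^4}{\tau_1^2 + \sigma_1^2/n_1} + \frac{\tau_0^4}{\tau_0^2 + \sigma_0^2/n_0}} \tag{34.28}$$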
In the fully symmetric case, Equation (34.27) simplifies to

$$\mathbb{E}[R_{\text{roll}}] = \underbrace{(N - 2n)\, m}_{R_a} + \underbrace{(N - 2n)\, \frac{\sqrt{2}\,\tau^2}{\sqrt{\pi}\sqrt{2\tau^2 + 2\sigma^2/n}}}_{R_b} \tag{34.29}$$

This has an intuitive interpretation. The first term, R_a, is the prior reward we expect to get before we learn anything about the arms. The second term, R_b, is the reward we expect to see by virtue of picking the optimal action to deploy.
Let us write R_b = (N − 2n) R_i, where R_i is the incremental gain. We see that the incremental gain increases with n, because we are more likely to pick the correct action with a larger sample size; however, this gain can only be accrued for a smaller number of people, as shown by the N − 2n prefactor. (This is a consequence of the explore-exploit tradeoff.)
The total expected reward is given by adding Equation (34.24) and Equation (34.29):

$$\mathbb{E}[R] = \mathbb{E}[R_{\text{test}}] + \mathbb{E}[R_{\text{roll}}] = Nm + (N - 2n)\, \frac{\sqrt{2}\,\tau^2}{\sqrt{\pi}\sqrt{2\tau^2 + 2\sigma^2/n}} \tag{34.30}$$

(The equation for the non-symmetric case is given in [FB19].)
We can maximize the expected reward in Equation (34.30) to find the optimal sample size for the testing phase, which (from symmetry) satisfies n_1* = n_0* = n*, and from (d/dn*) E[R] = 0 satisfies

$$n^* = \sqrt{\frac{N}{4} u^2 + \left(\frac{3}{4} u^2\right)^2} - \frac{3}{4} u^2 \tag{34.31}$$

where u² = σ²/τ². Thus we see that the optimal sample size n* increases as the observation noise σ increases, since we need to collect more data to be confident of the right decision. However, the optimal sample size decreases with τ, since a prior belief that the effect size δ = μ_1 − μ_0 will be large implies we expect to need less data to reach a confident conclusion.
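The sketch below evaluates the closed form of Equation (34.31). It then cross-checks the result by brute-force maximization of Equation (34.30) over integer n. The parameter values (N = 100,000, m = 0.68, τ = 0.03, σ² = m(1 − m)) are taken from the example in Section 34.3.2, and the function names are our own.

```python
import math


def expected_total_reward(n, N, m, tau, sigma):
    """E[R] from Eq. (34.30) for n samples per arm (symmetric case)."""
    return N * m + (N - 2 * n) * math.sqrt(2) * tau**2 / (
        math.sqrt(math.pi) * math.sqrt(2 * tau**2 + 2 * sigma**2 / n))


def optimal_n(N, tau, sigma):
    """Closed form of Eq. (34.31), with u^2 = sigma^2 / tau^2."""
    u2 = sigma**2 / tau**2
    return math.sqrt(N / 4 * u2 + (0.75 * u2) ** 2) - 0.75 * u2


N, m, tau, sigma = 100_000, 0.68, 0.03, math.sqrt(0.68 * 0.32)
n_star = optimal_n(N, tau, sigma)
n_grid = max(range(1, N // 2), key=lambda n: expected_total_reward(n, N, m, tau, sigma))
print(round(n_star, 1), n_grid)
```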
34.3.1.3 Regret


Given a policy, it is natural to wonder how good it is. We define the regret of a policy to be the difference between the expected reward given perfect information (PI) about the true best action and the expected reward due to our policy. Minimizing regret is equivalent to making the expected reward of our policy as close as possible to the best possible reward (which may be high or low, depending on the problem).

An oracle with perfect information about which μ_j is bigger would pick the highest scoring action, and hence get an expected reward of N E[max(μ_1, μ_0)]. Since we assume μ_j ~ N(m, τ²), we have

$$\mathbb{E}[R|\text{PI}] = N\left(m + \frac{\tau}{\sqrt{\pi}}\right) \tag{34.32}$$

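Therefore the regret from the optimal policy is given by

$$\mathbb{E}[R|\text{PI}] - \left(\mathbb{E}[R_{\text{test}}|n^*] + \mathbb{E}[R_{\text{roll}}|n^*]\right) = \frac{N\tau}{\sqrt{\pi}}\left(1 - \frac{\sqrt{2}\,\tau}{\sqrt{2\tau^2 + 2\sigma^2/n^*}}\right) + \frac{2 n^* \sqrt{2}\,\tau^2}{\sqrt{\pi}\sqrt{2\tau^2 + 2\sigma^2/n^*}} \tag{34.33}$$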


One can show that the regret is O(√N), which is optimal for this problem when using a time horizon (population size) of N [AG13].
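The O(√N) behaviour can be seen numerically. The sketch below evaluates Equation (34.33) at n = n* for increasing N, assuming the prior values from Section 34.3.2. After dividing by √N, the result should be roughly constant. A quick large-N expansion of Equation (34.33), which is our own derivation, suggests that the constant tends to 2σ/√π.

```python
import math


def n_opt(N, tau, sigma):
    """Eq. (34.31) with u^2 = sigma^2 / tau^2."""
    u2 = sigma**2 / tau**2
    return math.sqrt(N / 4 * u2 + (0.75 * u2) ** 2) - 0.75 * u2


def regret(N, tau, sigma, n):
    """Eq. (34.33): E[R|PI] minus the expected reward of test-and-roll with n per arm."""
    root = math.sqrt(2 * tau**2 + 2 * sigma**2 / n)
    return (N * tau / math.sqrt(math.pi) * (1 - math.sqrt(2) * tau / root)
            + 2 * n * math.sqrt(2) * tau**2 / (math.sqrt(math.pi) * root))


tau, sigma = 0.03, math.sqrt(0.68 * 0.32)       # values assumed from Section 34.3.2
for N in (10_000, 100_000, 1_000_000, 10_000_000):
    r = regret(N, tau, sigma, n_opt(N, tau, sigma))
    # The last column should settle near 2 * sigma / sqrt(pi) as N grows.
    print(N, round(r, 1), round(r / math.sqrt(N), 3))
print("large-N limit:", round(2 * sigma / math.sqrt(math.pi), 3))
```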


34.3.1.4 Expected error rate


Sometimes the goal is posed as best arm identification, which means identifying whether μ_1 > μ_0 or not. That is, if we define δ = μ_1 − μ_0, we want to know if δ > 0 or δ < 0. This is naturally phrased as a hypothesis test. However, this is arguably the wrong objective, since it is usually not worth spending money on collecting a large sample size to be confident that δ > 0 (say) if the magnitude of δ is small. Instead, it makes more sense to optimize total expected reward, using the method in Section 34.3.1.1.

Nevertheless, we may want to know the probability that we have picked the wrong arm if we use 
the policy from Section 34.3.1.1. In the symmetric case, this is given by the following: 


$$\Pr(\pi(y_1, y_0) = 1 | \mu_1 < \mu_0) = \Pr(\bar{y}_1 - \bar{y}_0 > 0 | \mu_1 < \mu_0) = 1 - \Phi\left(\frac{\mu_0 - \mu_1}{\sigma\sqrt{1/n_1 + 1/n_0}}\right) \tag{34.34}$$

The above expression assumed that the μ_j are known. Since they are not known, we can compute the expected error rate using E[Pr(π(y_1, y_0) = 1 | μ_1 < μ_0)]. By symmetry, the quantity E[Pr(π(y_1, y_0) = 0 | μ_1 > μ_0)] is the same. One can show that both quantities are given by

$$\text{Prob. error} = \frac{1}{4} - \frac{1}{2\pi} \arctan\left(\frac{\sqrt{2}\,\tau}{\sigma\sqrt{1/n_1 + 1/n_0}}\right) \tag{34.35}$$


As expected, the error rate decreases with the sample sizes n_1 and n_0, increases with the observation noise σ, and decreases with τ, the prior spread of the arm means (and hence of the effect size). Thus a policy that minimizes the classification error will also maximize expected reward, but it may pick an overly large sample size, since it does not take into account the magnitude of δ.
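The sketch below evaluates Equation (34.35) for a few sample sizes, using the prior from Section 34.3.2 (m = 0.68, τ = 0.03, σ² = m(1 − m)). We read Equation (34.35) as the probability of each of the two symmetric ways of choosing the wrong arm, so the code also prints their sum. This interpretation is ours. For n = 2284 the sum is about 0.10, which agrees with the 10% quoted for that example.

```python
import math


def prob_error(n1, n0, tau, sigma):
    """Eq. (34.35): 1/4 - arctan(sqrt(2) tau / (sigma sqrt(1/n1 + 1/n0))) / (2 pi)."""
    z = math.sqrt(2) * tau / (sigma * math.sqrt(1 / n1 + 1 / n0))
    return 0.25 - math.atan(z) / (2 * math.pi)


tau, sigma = 0.03, math.sqrt(0.68 * 0.32)
for n in (500, 2284, 10_000, 50_000):
    p = prob_error(n, n, tau, sigma)
    # 2 * p adds the two symmetric error events (our reading of Eq. 34.35).
    print(n, round(p, 4), round(2 * p, 4))
```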


34.3.2 Example 


In this section, we give a simple example of the above framework. Suppose our goal is to do website testing, where we have two different versions of a webpage that we want to compare in terms of their click-through rate. The observed data is now binary, y_ij ~ Ber(μ_j), so it is natural to use a beta prior, μ_j ~ Beta(α, β) (see Section 3.2.1). However, in this case the optimal sample size and decision rule are harder to compute (see [FB19; Sta+17] for details). As a simple approximation, we can assume y_ij ~ N(μ_j, σ²), where μ_j ~ N(m, τ²), m = α/(α + β), τ² = αβ/((α + β)²(α + β + 1)), and σ² = m(1 − m).

To set the Gaussian prior, [FB19] used empirical data from about 2000 prior A/B tests. For each test, they observed the number of times the page was served with each of the two variations, as well as the total number of times a user clicked on each version. Given this data, they used a hierarchical Bayesian model to infer μ_j ~ N(m = 0.68, τ = 0.03). This prior implies that the expected effect size is quite small, E[|μ_1 − μ_0|] = 0.023. (This is consistent with the results in [Aze+20], who found that most changes made to the Microsoft Bing EXP platform had negligible effect, although there were occasionally some “big hits”.)

With this prior, and assuming a population of N = 100,000, Equation (34.31) says that the optimal number of trials to run is n_1* = n_0* = 2284. The expected reward (number of clicks or conversions) in the testing phase is E[R_test] = 3106, and in the deployment phase E[R_roll] = 66430, for a total reward of 69536. The expected error rate is 10%.
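These numbers can be reproduced from Equations (34.24), (34.29) and (34.31) in a few lines of plain Python. The sketch below assumes the symmetric Gaussian approximation described above.

```python
import math

N, m, tau = 100_000, 0.68, 0.03
sigma = math.sqrt(m * (1 - m))                       # sigma^2 = m (1 - m)
u2 = sigma**2 / tau**2
n = round(math.sqrt(N / 4 * u2 + (0.75 * u2) ** 2) - 0.75 * u2)   # Eq. (34.31), rounded

r_test = 2 * n * m                                   # Eq. (34.24) with n1 = n0 = n, m1 = m0 = m
r_roll = (N - 2 * n) * (m + math.sqrt(2) * tau**2 / (
    math.sqrt(math.pi) * math.sqrt(2 * tau**2 + 2 * sigma**2 / n)))   # Eq. (34.29)
print(n, round(r_test), round(r_roll), round(r_test + r_roll))   # 2284 3106 66430 69536
```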

In Figure 34.4a, we plot the expected reward vs the size of the test phase n. We see that the 
reward increases sharply with n to the global maximum at n* = 2284, and then drops off more slowly. 
This indicates that it is better to have a slightly larger test than one that is too small by the same 
amount. (However, when using a heavy tailed model, [Aze+20] finds that it is better to do lots of 
smaller tests.) 
Figure 34.4: Total expected profit (a) and error rate (b) as a function of the sample size used for website testing. Generated by ab_test_demo.ipynb.
In Figure 34.4b, we plot the probability of picking the wrong action vs n. We see that tests that are larger than optimal only reduce this error rate marginally. Consequently, if you want to make the misclassification rate low, you may need a large sample size, particularly if |μ_1 − μ_0| is small, since then it will be hard to detect the true best action. However, it is also less important to identify the best action in this case, since both actions have very similar expected reward. This explains why classical methods for A/B testing based on frequentist statistics, which use hypothesis testing methods to determine if A is better than B, may often recommend sample sizes that are much larger than necessary. (See [FB19] and references therein for further discussion.)
34.4 Contextual bandits

This section was co-authored with Lihong Li.


In Section 34.3, we discussed A/B testing, in which the decision maker tries two different actions, a_1 and a_0, a fixed number of times, n_1 and n_0, measures the resulting sequence of rewards, y_1 and y_0, and then picks the best action to use for the rest of time (or the rest of the population) so as to maximize expected reward.

We can obviously generalize this beyond two actions. More importantly, we can generalize this beyond a one-stage decision problem. In particular, suppose we allow the decision maker to try an action a_t, observe the reward r_t, and then decide what to do at time step t + 1, rather than waiting until n_1 + n_0 experiments are finished. This immediate feedback allows for adaptive policies that can result in much higher expected reward (lower regret). We have converted a one-stage decision problem into a sequential decision problem. There are many kinds of sequential decision problems, but in this section, we consider the simplest kind, known as a bandit problem (see e.g., [LS19; Sli19]).


34.4.1 Types of bandit 


In a multi-armed bandit problem (MAB) there is an agent (decision maker) that can choose an action from some policy a_t ~ π_t at each step, after which it receives a reward sampled from the environment, r_t ~ p_R(a_t), with expected value R(a) = E[R|a].²

We can think of this in terms of an agent at a casino who is faced with multiple slot machines, each of which pays out rewards at a different rate. A slot machine is sometimes called a one-armed bandit, so a set of K such machines is called a multi-armed bandit; each different action corresponds to pulling the arm of a different slot machine, a_t ∈ {1, ..., K}. The goal is to quickly figure out which machine pays out the most money, and then to keep playing that one until you become as rich as possible.

We can extend this model by defining a contextual bandit, in which the input to the policy at each step is a randomly chosen state or context s_t ∈ S. The states evolve over time according to some arbitrary process, s_t ~ p(s_t|s_{1:t−1}), independent of the actions of the agent. The policy now has the form a_t ~ π_t(a_t|s_t), and the reward function now has the form r_t ~ p_R(r_t|s_t, a_t), with expected value R(s, a) = E[R|s, a]. At each step, the agent can use the observed data, D_{1:t}, where D_t = (s_t, a_t, r_t), to update its policy, to maximize expected reward.

In the finite horizon formulation of (contextual) bandits, the goal is to maximize the expected cumulative reward:

$$J \triangleq \sum_{t=1}^{T} \mathbb{E}_{p_R(r_t|s_t, a_t)\, \pi_t(a_t|s_t)\, p(s_t|s_{1:t-1})}\left[r_t\right] = \sum_{t=1}^{T} \mathbb{E}[r_t] \tag{34.36}$$


(Note that the reward is accrued at each step, even while the agent updates its policy; this is sometimes called “earning while learning”.) In the infinite horizon formulation, where T = ∞, the cumulative reward may be infinite. To prevent J from being unbounded, we introduce a discount factor 0 < γ < 1, so that

$$J \triangleq \sum_{t=1}^{\infty} \gamma^{t-1}\, \mathbb{E}[r_t] \tag{34.37}$$

The quantity 1 − γ can be interpreted as the probability that the agent is terminated at any given step (in which case it will cease to accumulate reward).
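Another way to write this is as follows:

$$J = \sum_{t=1}^{\infty} \gamma^{t-1}\, \mathbb{E}[r_t] = \sum_{t=1}^{\infty} \gamma^{t-1}\, \mathbb{E}\left[\sum_{a=1}^{K} R_a(s_t, a_t)\right] \tag{34.38}$$

where we define

$$R_a(s_t, a_t) = \begin{cases} R(s_t, a) & \text{if } a_t = a \\ 0 & \text{otherwise} \end{cases} \tag{34.39}$$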


Thus we conceptually evaluate the reward for all arms, but only the one that was actually chosen 
(namely a;) gives a non-zero value to the agent, namely r. 


2. This is known as a stochastic bandit. It is also possible to allow the reward, and possibly the state, to be chosen 
in an adversarial manner, where nature tries to minimize the reward of the agent. This is known as an adversarial 
bandit. 
There are many extensions of the basic bandit problem. A natural one is to allow the agent to perform multiple plays, choosing $M < K$ distinct arms at once. Let $\mathbf{a}_t$ be the corresponding action vector which specifies the identity of the chosen arms. Then we define the reward to be

$$r_t = \sum_{a=1}^{K} R_a(s_t, \mathbf{a}_t) \tag{34.40}$$

where

$$R_a(s_t, \mathbf{a}_t) = \begin{cases} R(s_t, a) & \text{if } a \in \mathbf{a}_t \\ 0 & \text{otherwise} \end{cases} \tag{34.41}$$

This is useful for modeling resource allocation problems.
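As a small illustration of Equations (34.40) and (34.41), the following Python sketch computes the multiple-play reward from a hypothetical list of expected rewards, one entry per arm, standing in for $R(s_t, a)$ at a fixed context. The function name and the example numbers are illustrative assumptions, not part of any library.

```python
def multiple_play_reward(expected_reward, chosen_arms):
    """Sketch of Eq. (34.40): r_t = sum_a R_a(s_t, a_t) for a multiple-play bandit.

    expected_reward[a] stands in for R(s_t, a) at the current context s_t.
    chosen_arms is the set of M distinct arm indices making up the action vector a_t.
    """
    if len(chosen_arms) >= len(expected_reward):
        raise ValueError("multiple plays require M < K distinct arms")
    total = 0.0
    for a, r in enumerate(expected_reward):
        # Eq. (34.41): arm a contributes R(s_t, a) only if it is among the chosen arms
        total += r if a in chosen_arms else 0.0
    return total


# K = 4 arms, of which M = 2 are played at once
print(multiple_play_reward([0.1, 0.5, 0.2, 0.9], {1, 3}))
```

Only the two chosen arms contribute, so the result is the sum of their expected rewards (about 1.4 here).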
Another variant is known as a restless bandit [Whi88]. This is the same as the multiple play formulation, except we additionally assume that each arm has its own state vector $s_t^a$ associated with it, which evolves according to some stochastic process, regardless of whether arm $a$ was chosen or not. We then define

$$r_t = \sum_{a=1}^{K} R_a(s_t^a, \mathbf{a}_t) \tag{34.42}$$

where $s_t^a \sim p(s_t^a|s_{1:t-1}^a)$ is some arbitrary distribution, often assumed to be Markovian. (The fact that the states associated with each arm evolve even if the arm is not picked is what gives rise to the term “restless”.) This can be used to model serial dependence between the rewards given by each arm.
34.4.2 Applications


Contextual bandits have many applications. For example, consider an online advertising system. In this case, the state $s_t$ represents features of the web page that the user is currently looking at, and the action $a_t$ represents the identity of the ad which the system chooses to show. Since the relevance of the ad depends on the page, the reward function has the form $R(s_t, a_t)$, and hence the problem is contextual. The goal is to maximize the expected reward, which is equivalent to the expected number of times people click on ads; this is known as the click through rate or CTR. (See e.g., [Gra+10; Li+10; McM+13; Aga+14; Du+21; YZ22] for more information about this application.)

Another application of contextual bandits arises in clinical trials [VBW15]. In this case, the state $s_t$ consists of features of the current patient we are treating, and the action $a_t$ is the treatment the doctor chooses to give them (e.g., a new drug or a placebo). Our goal is to maximize expected reward, i.e., the expected number of people who get cured. (An alternative goal is to determine which treatment is best as quickly as possible, rather than maximizing expected reward; this variant is known as best-arm identification [ABM10].)
34.4.3 Exploration-exploitation tradeoff

The fundamental difficulty in solving bandit problems is known as the exploration-exploitation tradeoff. This refers to the fact that the agent needs to try multiple state/action combinations (this is known as exploration) in order to collect enough data so it can reliably learn the reward function $R(s,a)$; it can then exploit its knowledge by picking the predicted best action for each state. If the agent starts exploiting an incorrect model too early, it will collect suboptimal data, and will get stuck in a negative feedback loop, as illustrated in Figure 34.5. This is different from supervised learning, where the data is drawn iid from a fixed distribution (see e.g., [Jeu+19] for details).

[Figure 34.5 diagram labels: Historical Data, learning, CTR Model, inference, Predictions, Ad-ranking, Ad Winner(s), impression, feedback]

Figure 34.5: Illustration of the feedback problem in online advertising and recommendation systems. The click through rate (CTR) model is used to decide what ads to show, which affects what data is collected, which affects how the model learns. From Figure 1-2 of [Du+21]. Used with kind permission of Chao Du.

Figure 34.6: Illustration of sequential belief updating for a two-armed beta-Bernoulli bandit. The prior for the reward for action 1 is the (blue) uniform distribution Beta(1,1); the prior for the reward for action 2 is the (orange) unimodal distribution Beta(2,2). We update the parameters of the belief state based on the chosen action, and based on whether the observed reward is success (1) or failure (0).

We discuss some solutions to the exploration-exploitation problem below. 


34.4.4 The optimal solution 


In this section, we discuss the optimal solution to the exploration-exploitation tradeoff. Let us denote the posterior over the parameters of the reward function by $b_t = p(\theta|h_t)$, where $h_t = \{s_{1:t-1}, a_{1:t-1}, r_{1:t-1}\}$ is the history of observations; this is known as the belief state or information state. It is a finite sufficient statistic for the history $h_t$. The belief state can be updated deterministically using Bayes rule:

$$b_t = \text{BayesRule}(b_{t-1}, a_t, r_t) \tag{34.43}$$
For example, consider a context-free Bernoulli bandit, where $p_R(r|a) = \text{Ber}(r|\mu_a)$, and $\mu_a = \Pr(r = 1|a) = R(a)$ is the expected reward for taking action $a$. Suppose we use a factored beta prior

$$p_0(\theta) = \prod_a \text{Beta}(\mu_a|\alpha_0^a, \beta_0^a) \tag{34.44}$$

where $\theta = (\mu_1, \ldots, \mu_K)$. We can compute the posterior in closed form, as we discuss in Section 3.2.1. In particular, we find

$$p(\theta|\mathcal{D}_t) = \prod_a \text{Beta}\left(\mu_a|\alpha_0^a + N_t^1(a), \beta_0^a + N_t^0(a)\right) \tag{34.45}$$

where

$$N_t^r(a) = \sum_{s=1}^{t-1} \mathbb{I}(a_s = a, r_s = r) \tag{34.46}$$

This is illustrated in Figure 34.6 for a two-armed Bernoulli bandit.
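The count-based update in Equations (34.45) and (34.46) can be sketched in a few lines of plain Python. This is a minimal illustration assuming rewards in {0, 1} and a per-arm Beta prior given as (alpha, beta) pairs; the priors Beta(1,1) and Beta(2,2) mirror Figure 34.6, while the history of (action, reward) pairs is made up for illustration.

```python
def beta_bernoulli_posterior(prior, history):
    """Posterior Beta parameters for each arm of a context-free Bernoulli bandit.

    prior:   dict mapping arm -> (alpha_0, beta_0)
    history: list of (action, reward) pairs with reward in {0, 1}
    Returns a dict arm -> (alpha_0 + N^1(a), beta_0 + N^0(a)), as in Eq. (34.45).
    """
    posterior = dict(prior)
    for a, r in history:
        alpha, beta = posterior[a]
        # a success increments N^1(a), a failure increments N^0(a) (Eq. 34.46)
        posterior[a] = (alpha + r, beta + (1 - r))
    return posterior


def posterior_mean(alpha, beta):
    """E[mu_a | D_t] = alpha / (alpha + beta)."""
    return alpha / (alpha + beta)


prior = {1: (1, 1), 2: (2, 2)}      # Beta(1,1) for arm 1, Beta(2,2) for arm 2
history = [(1, 1), (2, 0), (1, 1)]  # illustrative (action, reward) pairs
post = beta_bernoulli_posterior(prior, history)
print(post)  # {1: (3, 1), 2: (2, 3)}
print({a: posterior_mean(*ab) for a, ab in post.items()})
```

For Bernoulli rewards the posterior mean is also the predictive probability of a success, so the two counts per arm are all the belief state needs to store.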
We can use a similar method for a Gaussian bandit, where $p_R(r|a) = \mathcal{N}(r|\mu_a, \sigma_a^2)$, using results from Section 3.2.3. In the case of contextual bandits, the problem becomes more complicated. If we assume a linear regression bandit, $p_R(r|s,a;\theta) = \mathcal{N}(r|\phi(s,a)^\top\theta, \sigma^2)$, we can use Bayesian linear regression to compute $p(\theta|\mathcal{D}_t)$ in closed form, as we discuss in Section 15.2. If we assume a logistic regression bandit, $p_R(r|s,a;\theta) = \text{Ber}(r|\sigma(\phi(s,a)^\top\theta))$, we can use Bayesian logistic regression to compute $p(\theta|\mathcal{D}_t)$, as we discuss in Section 15.3.4. If we have a neural bandit of the form $p_R(r|s,a;\theta) = \text{GLM}(r|f(s,a;\theta))$ for some nonlinear function $f$, then posterior inference becomes more challenging, as we discuss in Chapter 17. However, standard techniques, such as the extended Kalman filter (Section 17.5.2), can be applied. (For a way to scale this approach to large DNNs, see the “subspace neural bandit” approach of [DMKM22].)

Regardless of the algorithmic details, we can represent the belief state update as follows:

$$p(b_t|b_{t-1}, a_t, r_t) = \mathbb{I}\left(b_t = \text{BayesRule}(b_{t-1}, a_t, r_t)\right) \tag{34.47}$$

The observed reward at each step is then predicted to be

$$p(r_t|b_t) = \int p_R(r_t|s_t, a_t; \theta)\, p(\theta|b_t)\, d\theta \tag{34.48}$$

We see that this is a special form of a (controlled) Markov decision process (Section 34.5) known as a belief-state MDP.

In the special case of context-free bandits with a finite number of arms, the optimal policy of this belief state MDP can be computed using dynamic programming (cf. Section 34.6); the result can be represented as a table of action probabilities, $\pi_t(a_1, \ldots, a_K)$, for each step; this is known as the Gittins index [Git89]. However, computing the optimal policy for general contextual bandits is intractable [PTS87], so we have to resort to approximations, as we discuss below.

34.4.5 Upper confidence bounds (UCB)

The optimal solution to explore-exploit is intractable. However, an intuitively sensible approach is based on the principle known as “optimism in the face of uncertainty”. The principle selects actions greedily, but based on optimistic estimates of their rewards. The most important class of strategies based on this principle is collectively called upper confidence bound or UCB methods.
To use a UCB strategy, the agent maintains an optimistic reward function estimate $\tilde{R}_t$, so that $\tilde{R}_t(s_t, a) \geq R(s_t, a)$ for all $a$ with high probability, and then chooses the greedy action accordingly:

$$a_t = \operatorname*{argmax}_a \tilde{R}_t(s_t, a) \tag{34.49}$$


UCB can be viewed as a form of exploration bonus, where the optimistic estimate encourages exploration. Typically, the amount of optimism, $\tilde{R}_t - R$, decreases over time so that the agent gradually reduces exploration. With properly constructed optimistic reward estimates, the UCB strategy has been shown to achieve near-optimal regret in many variants of bandits [LS19]. (We discuss regret in Section 34.4.7.)

The optimistic function $\tilde{R}$ can be obtained in different ways, sometimes in closed form, as we discuss below.


34.4.5.1 Frequentist approach 


One approach is to use a concentration inequality [BLM16] to derive a high-probability upper bound of the estimation error: $|\hat{R}_t(s,a) - R(s,a)| \leq \delta_t(s,a)$, where $\hat{R}_t$ is a usual estimate of $R$ (often the MLE), and $\delta_t$ is a properly selected function. An optimistic reward is then obtained by setting $\tilde{R}_t(s,a) = \hat{R}_t(s,a) + \delta_t(s,a)$.

As an example, consider again the context-free Bernoulli bandit, $R(a) \sim \text{Ber}(\mu(a))$. The MLE $\hat{R}_t(a) = \hat{\mu}_t(a)$ is given by the empirical average of observed rewards whenever action $a$ was taken:

$$\hat{\mu}_t(a) = \frac{N_t^1(a)}{N_t(a)} = \frac{N_t^1(a)}{N_t^0(a) + N_t^1(a)} \tag{34.50}$$

where $N_t^r(a)$ is the number of times (up to step $t-1$) that action $a$ has been tried and the observed reward was $r$, and $N_t(a)$ is the total number of times action $a$ has been tried:

$$N_t(a) = \sum_{s=1}^{t-1} \mathbb{I}(a_s = a) \tag{34.51}$$

Then the Chernoff-Hoeffding inequality [BLM16] leads to $\delta_t(a) = c/\sqrt{N_t(a)}$ for some proper constant $c$, so

$$\tilde{R}_t(a) = \hat{\mu}_t(a) + \frac{c}{\sqrt{N_t(a)}} \tag{34.52}$$

34.4.5.2 Bayesian approach 


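We may also derive $\tilde{R}$ from Bayesian inference. If we use a beta prior, we can compute the posterior in closed form, as shown in Equation (34.45). The posterior mean is $\hat{\mu}_t(a) = \mathbb{E}[\mu(a)|h_t] = \frac{\alpha_t^a}{\alpha_t^a + \beta_t^a}$, where $\alpha_t^a$ and $\beta_t^a$ are the posterior parameters in Equation (34.45). From Equation (3.20), the posterior standard deviation is approximately

$$\hat{\sigma}_t(a) = \sqrt{\mathbb{V}[\mu(a)|h_t]} \approx \sqrt{\frac{\hat{\mu}_t(a)\left(1 - \hat{\mu}_t(a)\right)}{N_t(a)}} \tag{34.53}$$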
Figure 34.7: Illustration of the reward distribution Q(a) for 3 different actions, and the corresponding lower and upper confidence bounds. From [Sil18]. Used with kind permission of David Silver.


We can use similar techniques for a Gaussian bandit, where $p_R(R|a,\theta) = \mathcal{N}(R|\mu_a, \sigma_a^2)$, $\mu_a$ is the expected reward, and $\sigma_a^2$ the variance. If we use a conjugate prior, we can compute $p(\mu_a, \sigma_a|\mathcal{D}_t)$ in closed form (see Section 3.2.3). Using an uninformative version of the conjugate prior, we find $\mathbb{E}[\mu_a|h_t] = \hat{\mu}_t(a)$, which is just the empirical mean of rewards for action $a$. The uncertainty in this estimate is the standard error of the mean, given by Equation (3.66), i.e., $\sqrt{\mathbb{V}[\mu_a|h_t]} = \hat{\sigma}_t(a)/\sqrt{N_t(a)}$, where $\hat{\sigma}_t(a)$ is the empirical standard deviation of the rewards for action $a$.

This approach can also be extended to contextual bandits, modulo the difficulty of computing the belief state.

Once we have computed the mean and posterior standard deviation, we define the optimistic reward estimate as

$$\tilde{R}_t(a) = \hat{\mu}_t(a) + c\,\hat{\sigma}_t(a) \tag{34.54}$$

for some constant $c$ that controls how greedy the policy is. We see that this is similar to the frequentist method based on concentration inequalities, but is more general.
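A minimal sketch of the Bayesian UCB rule of Equation (34.54) for a Beta-Bernoulli bandit is shown below. As an assumption of this sketch, we use the exact standard deviation of the Beta posterior, $\sqrt{\alpha\beta/((\alpha+\beta)^2(\alpha+\beta+1))}$, instead of the approximation in Equation (34.53), which is undefined when $N_t(a) = 0$; the posterior parameters are made up for illustration.

```python
import math


def beta_mean_std(alpha, beta):
    """Mean and exact standard deviation of a Beta(alpha, beta) distribution."""
    n = alpha + beta
    mean = alpha / n
    std = math.sqrt(alpha * beta / (n * n * (n + 1)))
    return mean, std


def bayes_ucb_action(posteriors, c):
    """Return argmax_a mu_hat(a) + c * sigma_hat(a) together with all the scores.

    posteriors: list of (alpha_t^a, beta_t^a) pairs, one per arm.
    """
    scores = []
    for alpha, beta in posteriors:
        mean, std = beta_mean_std(alpha, beta)
        scores.append(mean + c * std)
    best = max(range(len(scores)), key=lambda a: scores[a])
    return best, scores


# Arm 0 has never been tried (uniform prior); arm 1 has 29 successes and 9 failures
posteriors = [(1, 1), (30, 10)]
print(bayes_ucb_action(posteriors, c=2.0)[0])  # -> 0 (optimistic choice)
print(bayes_ucb_action(posteriors, c=0.0)[0])  # -> 1 (purely greedy choice)
```

With c = 2 the untried arm wins because of its large posterior standard deviation; with c = 0 the rule reduces to greedy exploitation of the posterior mean, which shows how c controls how greedy the policy is.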
34.4.5.3 Example

Figure 34.7 illustrates the UCB principle for a Gaussian bandit. We assume there are 3 actions, and we represent $p(R(a)|\mathcal{D}_t)$ using a Gaussian. We show the posterior means $Q(a) = \hat{\mu}(a)$ with a vertical dotted line, and the scaled posterior standard deviations $c\,\hat{\sigma}(a)$ as a horizontal solid line.

36 


3734.4.6 Thompson sampling 


38 
7-A common alternative to UCB is to use Thompson sampling [Tho33], also called probability 
-~ matching [Scol0]. In this approach, we define the policy at step t to be m;(a|s:, ht) = pa, where pa 


is the probability that a is the optimal action. This can be computed using 

42 

43. Pa = Pr(a = as|s, hi) = fi (« = argmax R(s;, a’; 0)) p(O|h,)dO (34.55) 
4 a 


45 If the posterior is uncertain, the agent will sample many different actions, automatically resulting in 
46exploration. As the uncertainty decreases, it will start to exploit its knowledge. 
47 
Figure 34.8: Illustration of Thompson sampling applied to a linear-Gaussian contextual bandit. The context has the form $s_t = (1, t, t^2)$. (a) True reward for each arm vs time. (b) Cumulative reward per arm vs time. (c) Cumulative regret vs time. Generated by thompson_sampling_linear_gaussian.ipynb.


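To see how we can implement this method, note that we can compute the expression in Equation (34.55) by using a single Monte Carlo sample $\tilde{\theta}_t \sim p(\theta|h_t)$. We then plug this parameter into our reward model, and greedily pick the best action:

$$a_t = \operatorname*{argmax}_{a'} R(s_t, a'; \tilde{\theta}_t) \tag{34.56}$$

This sample-then-exploit approach will choose actions with exactly the desired probability, since

$$p_a = \int \mathbb{I}\left(a = \operatorname*{argmax}_{a'} R(s_t, a'; \tilde{\theta}_t)\right) p(\tilde{\theta}_t|h_t)\, d\tilde{\theta}_t = \Pr_{\tilde{\theta}_t \sim p(\theta|h_t)}\left(a = \operatorname*{argmax}_{a'} R(s_t, a'; \tilde{\theta}_t)\right) \tag{34.57}$$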


Despite its simplicity, this approach can be shown to achieve optimal (logarithmic) regret (see e.g., [Rus+18] for a survey). In addition, it is very easy to implement, and hence is widely used in practice [Gra+10; Sco10; CL11].
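For the context-free Beta-Bernoulli bandit, the sample-then-exploit recipe of Equations (34.56) and (34.57) becomes very short. The sketch below uses Python's `random.Random.betavariate` to draw one posterior sample per arm, acts greedily on that sample, and then applies the count update of Equation (34.45). The arm probabilities, the uniform Beta(1,1) priors, the horizon and the seed are illustrative assumptions; this is not the linear-Gaussian setup of Figure 34.8.

```python
import random


def thompson_bernoulli(true_probs, num_steps, seed=0):
    """Thompson sampling for a simulated Bernoulli bandit with Beta(1, 1) priors."""
    rng = random.Random(seed)
    K = len(true_probs)
    alpha = [1.0] * K  # alpha_0^a + N^1(a)
    beta = [1.0] * K   # beta_0^a + N^0(a)
    counts = [0] * K
    for _ in range(num_steps):
        # one Monte Carlo sample theta_t ~ p(theta | h_t), one entry per arm
        theta = [rng.betavariate(alpha[a], beta[a]) for a in range(K)]
        # act greedily with respect to the sampled parameters (Eq. 34.56)
        a = max(range(K), key=lambda k: theta[k])
        reward = 1 if rng.random() < true_probs[a] else 0
        alpha[a] += reward
        beta[a] += 1 - reward
        counts[a] += 1
    return counts


print(thompson_bernoulli([0.2, 0.5, 0.8], num_steps=2000))
```

Because each arm is chosen with the posterior probability that it is optimal, poorly explored arms keep a chance of being sampled, and that chance shrinks as their posteriors concentrate.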

In Figure 34.8, we give a simple example of Thompson sampling applied to a linear regression bandit. The context has the form $s_t = (1, t, t^2)$. The true reward function has the form $R(s_t, a) = \mathbf{w}_a^\top s_t$. The weights per arm are chosen as follows: $\mathbf{w}_0 = (-5, 2, 0.5)$, $\mathbf{w}_1 = (0, 0, 0)$, $\mathbf{w}_2 = (5, -1.5, -1)$. Thus we see that arm 0 is initially worse (large negative bias) but gets better over time (positive slope), arm 1 is useless, and arm 2 is initially better (large positive bias) but gets worse over time. The observation noise is the same for all arms, $\sigma^2 = 1$. See Figure 34.8(a) for a plot of the reward function.

We use a conjugate Gaussian-Gamma prior and perform exact Bayesian updating. Thompson 
sampling quickly discovers that arm 1 is useless. Initially it pulls arm 2 more, but it adapts to the 
non-stationary nature of the problem and switches over to arm 0, as shown in Figure 34.8(b). 


34.4.7 Regret 


We have discussed several methods for solving the exploration-exploitation tradeoff. It is useful to quantify the degree of suboptimality of these methods. A common approach is to compute the regret, which is defined as the difference between the expected reward under the oracle policy $\pi_*$, which knows the true reward function, and the expected reward under the agent’s policy. (Note that the oracle policy will in general be better than the Bayes optimal policy, which we discussed in Section 34.4.4.)

Specifically, let $\pi_t$ be the agent’s policy at time $t$. Then the per-step regret at $t$ is defined as

$$l_t \triangleq \mathbb{E}_{p(s_t)}\left[R(s_t, \pi_*(s_t))\right] - \mathbb{E}_{\pi_t(a_t|s_t)\,p(s_t)}\left[R(s_t, a_t)\right] \tag{34.58}$$


If we only care about the final performance of the best discovered arm, as in most optimization problems, it is enough to look at the simple regret at the last step, namely $l_T$. Optimizing simple regret results in a problem known as pure exploration [BMS11], since there is no need to exploit the information during the learning process. However, it is more common to focus on the cumulative regret, also called the total regret or just the regret, which is defined as

$$L_T \triangleq \mathbb{E}\left[\sum_{t=1}^{T} l_t\right] \tag{34.59}$$


Here the expectation is with respect to randomness in determining $\pi_t$, which depends on earlier states, actions and rewards, as well as other potential sources of randomness.

Under the typical assumption that rewards are bounded, $L_T$ is at most linear in $T$. If the agent’s policy converges to the optimal policy as $T$ increases, then the regret is sublinear: $L_T = o(T)$. In general, the slower $L_T$ grows, the more efficient the agent is in trading off exploration and exploitation.


To understand its growth rate, it is helpful to consider again a simple context-free bandit, where $R_* = \max_a R(a)$ is the optimal reward. The total regret in the first $T$ steps can be written as

$$L_T = \mathbb{E}\left[\sum_{t=1}^{T} \left(R_* - R(a_t)\right)\right] = \sum_{a \in \mathcal{A}} \mathbb{E}\left[N_{T+1}(a)\right]\left(R_* - R(a)\right) = \sum_{a \in \mathcal{A}} \mathbb{E}\left[N_{T+1}(a)\right]\Delta_a \tag{34.60}$$

where $N_{T+1}(a)$ is the total number of times the agent picks action $a$ up to step $T$, and $\Delta_a = R_* - R(a)$ is the reward gap. If the agent under-explores and converges to choosing a suboptimal action (say, $\hat{a}$), then a linear regret is suffered with a per-step regret of $\Delta_{\hat{a}}$. On the other hand, if the agent over-explores, then $N_t(a)$ will be too large for suboptimal actions, and the agent also suffers a linear regret.
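The identity in Equation (34.60) can be checked on a single simulated run: summing the per-step gap $R_* - R(a_t)$ over time gives the same number as weighting each gap $\Delta_a$ by the number of pulls of arm $a$. The sketch below computes the quantity inside the expectation for one action sequence; averaging it over many runs would estimate $L_T$. The true means and the action sequence are illustrative assumptions.

```python
def regret_per_step(true_means, actions):
    """Sum over t of (R_* - R(a_t)) for one action sequence."""
    best = max(true_means)  # R_* = max_a R(a)
    return sum(best - true_means[a] for a in actions)


def regret_by_gaps(true_means, actions):
    """Sum over a of N_{T+1}(a) * Delta_a for the same action sequence."""
    best = max(true_means)
    pulls = [0] * len(true_means)
    for a in actions:
        pulls[a] += 1
    return sum(n * (best - mu) for n, mu in zip(pulls, true_means))


means = [0.2, 0.5, 0.8]
actions = [0, 2, 1, 2, 2]
print(regret_per_step(means, actions), regret_by_gaps(means, actions))
```

Both functions give about 0.9 here: one pull of the arm with gap 0.6, one pull of the arm with gap 0.3, and three pulls of the optimal arm, which contribute nothing.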

Fortunately, it is possible to achieve sublinear regret, using some of the methods discussed above, such as UCB and Thompson sampling. For example, one can show that Thompson sampling has $O(\sqrt{KT \log T})$ regret [RR14]. This is shown empirically in Figure 34.8(c).

In fact, both UCB and Thompson sampling are optimal, in the sense that their regrets are essentially not improvable; that is, they match regret lower bounds. To establish such a lower bound, note that the agent needs to collect enough data to distinguish different reward distributions before identifying the optimal action. Typically, the deviation of the reward estimate from the true reward decays at the rate of $1/\sqrt{N}$, where $N$ is the sample size (see e.g., Equation (3.66)). Therefore, if two reward distributions are similar, distinguishing them becomes harder and requires more samples. (For example, consider the case of a bandit with Gaussian rewards with slightly different means and large variance, as shown in Figure 34.7.)

The following fundamental result is proved by [LR85] for the asymptotic regret (under certain mild assumptions not given here):

$$\liminf_{T \to \infty} \frac{L_T}{\log T} \geq \sum_{a: \Delta_a > 0} \frac{\Delta_a}{D_{\mathbb{KL}}\left(p_R(a) \,\|\, p_R(a_*)\right)} \tag{34.61}$$
34.5. MARKOV DECISION PROBLEMS


Figure 34.9: Illustration of an MDP as a finite state machine (FSM). The MDP has three discrete states (green circles), two discrete actions (orange circles), and two non-zero rewards (orange arrows). The numbers on the black edges represent state transition probabilities, e.g., $p(s' = s_0|a = a_0, s = s_0) = 0.7$; most state transitions are impossible (probability 0), so the graph is sparse. The numbers on the yellow wiggly edges represent expected rewards, e.g., $R(s = s_1, a = a_0, s' = s_0) = +5$; state transitions with zero reward are not annotated. From https://en.wikipedia.org/wiki/Markov_decision_process. Used with kind permission of Wikipedia author waldoalvarez.


Thus, we see that the best we can achieve is logarithmic growth in the total regret. Similar lower bounds have also been obtained for various bandit variants.
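To make Equation (34.61) concrete, the sketch below evaluates its right-hand-side constant for Bernoulli rewards, where the KL divergence between $\text{Ber}(p)$ and $\text{Ber}(q)$ has a closed form. The two example bandits are illustrative assumptions; the point is that the constant, and hence the minimal $\log T$ growth rate of the regret, is larger when a suboptimal arm's mean is close to the best one.

```python
import math


def bernoulli_kl(p, q):
    """KL(Ber(p) || Ber(q)) for 0 < p < 1 and 0 < q < 1."""
    return p * math.log(p / q) + (1 - p) * math.log((1 - p) / (1 - q))


def lai_robbins_constant(means):
    """Sum over suboptimal arms of Delta_a / KL(p_R(a) || p_R(a_*)) for Bernoulli arms."""
    best = max(means)
    total = 0.0
    for mu in means:
        gap = best - mu
        if gap > 0:
            total += gap / bernoulli_kl(mu, best)
    return total


print(lai_robbins_constant([0.5, 0.6]))  # hard problem: means are close
print(lai_robbins_constant([0.1, 0.6]))  # easy problem: means are far apart
```

The first constant is roughly 4.9 and the second roughly 0.9: although the gap is five times smaller in the first bandit, the KL divergence shrinks much faster, so far more samples are needed to tell the arms apart.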


34.5 Markov decision problems 


In this section, we generalize the discussion of contextual bandits by allowing the state of nature 
to change depending on the actions chosen by the agent. The resulting model is called a Markov 
decision process or MDP, as we explain in detail below. This model forms the foundation of 
reinforcement learning, which we discuss in Chapter 35. 


34.5.1 Basics 


A Markov decision process [Put94] can be used to model the interaction of an agent and an environment. It is often described by a tuple $\langle \mathcal{S}, \mathcal{A}, p_T, p_R, p_0 \rangle$, where $\mathcal{S}$ is a set of environment states, $\mathcal{A}$ a set of actions the agent can take, $p_T$ a transition model, $p_R$ a reward model, and $p_0$ the initial state distribution. The interaction starts at time $t = 0$, where the initial state $s_0 \sim p_0$. Then, at time $t \geq 0$, the agent observes the environment state $s_t \in \mathcal{S}$, and follows a policy $\pi$ to take an action $a_t \in \mathcal{A}$. In response, the environment emits a real-valued reward signal $r_t \in \mathbb{R}$ and enters a new state $s_{t+1} \in \mathcal{S}$. The policy is in general stochastic, with $\pi(a|s)$ being the probability of choosing action $a$ in state $s$. We use $\pi(s)$ to denote the conditional probability over $\mathcal{A}$ if the policy is stochastic, or the action it chooses if it is deterministic. The process at every step is called a transition; at time $t$, it consists of the tuple $(s_t, a_t, r_t, s_{t+1})$, where $a_t \sim \pi(s_t)$, $s_{t+1} \sim p_T(s_t, a_t)$, and $r_t \sim p_R(s_t, a_t, s_{t+1})$. Hence, under policy $\pi$, the probability of generating a trajectory $\tau$ of length $T$ can be written explicitly as

$$p(\tau) = p_0(s_0) \prod_{t=0}^{T-1} \pi(a_t|s_t)\, p_T(s_{t+1}|s_t, a_t)\, p_R(r_t|s_t, a_t, s_{t+1}) \tag{34.62}$$

It is useful to define the reward function from the reward model $p_R$, as the average immediate reward of taking action $a$ in state $s$, with the next state marginalized:

$$R(s, a) \triangleq \mathbb{E}_{p_T(s'|s,a)}\left[\mathbb{E}_{p_R(r|s,a,s')}[r]\right] \tag{34.63}$$

Eliminating the dependence on next states does not lead to loss of generality in the following discussions, as our subject of interest is the total (additive) expected reward along the trajectory. For this reason, we often use the tuple $\langle \mathcal{S}, \mathcal{A}, p_T, R, p_0 \rangle$ to describe an MDP.
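In the tabular case, Equation (34.63) is just a weighted sum over next states. The sketch below assumes the transition model is stored as nested dictionaries `p_T[s][a][s_next]` and the mean reward of each transition as `mean_r[s][a][s_next]`, with missing entries meaning zero reward; this layout and the toy numbers are illustrative assumptions, not the MDP of Figure 34.9.

```python
def expected_reward(p_T, mean_r, s, a):
    """R(s, a) = sum_{s'} p_T(s' | s, a) * E_{p_R(r | s, a, s')}[r]  (Eq. 34.63)."""
    return sum(prob * mean_r[s][a].get(s_next, 0.0)
               for s_next, prob in p_T[s][a].items())


# Toy MDP: from state 0, action 0 moves to state 1 w.p. 0.7 (mean reward +5)
# or stays in state 0 w.p. 0.3 (no reward)
p_T = {0: {0: {1: 0.7, 0: 0.3}}}
mean_r = {0: {0: {1: 5.0}}}
print(expected_reward(p_T, mean_r, 0, 0))  # approximately 3.5
```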

In general, the state and action sets of an MDP can be discrete or continuous. When both sets are finite, we can represent these functions as lookup tables; this is known as a tabular representation. In this case, we can represent the MDP as a finite state machine, which is a graph where nodes correspond to states, and edges correspond to actions and the resulting rewards and next states. Figure 34.9 gives a simple example of an MDP with 3 states and 2 actions.

The field of control theory, which is very closely related to RL, uses slightly different terminology. In particular, the environment is called the plant, and the agent is called the controller. States are denoted by $x_t \in \mathcal{X} \subseteq \mathbb{R}^D$, actions are denoted by $u_t \in \mathcal{U} \subseteq \mathbb{R}^K$, and rewards are denoted by costs $c_t \in \mathbb{R}$. Apart from this notational difference, the fields of RL and control theory are very similar (see e.g., [Son98; Rec19]), although control theory tends to focus on provably optimal methods (by making strong modeling assumptions), whereas RL tends to tackle harder problems with heuristic methods, for which optimality guarantees are often hard to obtain.


34.5.2 Partially observed MDPs

An important generalization of the MDP framework relaxes the assumption that the agent sees the hidden world state $s_t$ directly; instead we assume it only sees a potentially noisy observation generated from the hidden state, $x_t \sim p(\cdot|s_t, a_t)$. The resulting model is called a partially observable Markov decision process or POMDP (pronounced “pom-dee-pee”). Now the agent’s policy is a mapping from all the available data to actions, $a_t \sim \pi(\mathcal{D}_{1:t-1}, x_t)$, $\mathcal{D}_t = (x_t, a_t, r_t)$. See Figure 34.10 for an illustration. MDPs are a special case where $x_t = s_t$.

In general, POMDPs are much harder to solve than MDPs. A common approximation is to use the last several observed inputs, say $x_{t-h:t}$ for history of size $h$, as a proxy for the hidden state, and then to treat this as a fully observed MDP.
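The history-window approximation can be sketched with a fixed-length buffer that always holds the most recent $h$ observations. The class name `ObservationWindow` and the choice to pad the window with copies of the first observation at the start of an episode are assumptions of this sketch (one common convention, not the only one).

```python
from collections import deque


class ObservationWindow:
    """Keep the last h observations as a proxy for the hidden state s_t."""

    def __init__(self, h):
        self.buffer = deque(maxlen=h)

    def reset(self, x0):
        """Start a new episode, padding the window with the first observation."""
        self.buffer.clear()
        for _ in range(self.buffer.maxlen):
            self.buffer.append(x0)
        return tuple(self.buffer)

    def step(self, x):
        """Append a new observation; the oldest one drops out automatically."""
        self.buffer.append(x)
        return tuple(self.buffer)


window = ObservationWindow(h=3)
print(window.reset(0))  # (0, 0, 0)
print(window.step(1))   # (0, 0, 1)
print(window.step(2))   # (0, 1, 2)
print(window.step(3))   # (1, 2, 3)
```

The returned tuple can then be treated as the state of a fully observed MDP, for example as a dictionary key in a tabular method; it is only an approximation, since information older than $h$ steps is discarded.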
34.5.3 Episodes and returns

The Markov decision process describes how a trajectory $\tau = (s_0, a_0, r_0, s_1, a_1, r_1, \ldots)$ is stochastically generated. If the agent can potentially interact with the environment forever, we call it a continuing task. Alternatively, the agent is in an episodic task if its interaction terminates once the system enters a terminal state or absorbing state; $s$ is absorbing if the next state from $s$ is always $s$ with 0 reward. After entering a terminal state, we may start a new episode from a new initial state $s_0 \sim p_0$. The episode length is in general random. For example, the amount of time a robot takes to reach its goal may be quite variable, depending on the decisions it makes, and the randomness in the environment. Note that we can convert an episodic MDP to a continuing MDP by redefining the transition model in absorbing states to be the initial-state distribution $p_0$. Finally, if the trajectory length $T$ in an episodic task is fixed and known, it is called a finite horizon problem.

Figure 34.10: Illustration of a partially observable Markov decision process (POMDP) with hidden environment state $s_t$ which generates the observation $x_t$, controlled by an agent with internal belief state $b_t$ which generates the action $a_t$. The reward $r_t$ depends on $s_t$ and $a_t$. Nodes in this graph represent random variables (circles) and decision variables (squares).

Let $\tau$ be a trajectory of length $T$, where $T$ may be $\infty$ if the task is continuing. We define the return for the state at time $t$ to be the sum of rewards obtained going forward, where each reward is multiplied by a discount factor $\gamma \in [0, 1]$:

$$G_t \triangleq r_t + \gamma r_{t+1} + \gamma^2 r_{t+2} + \cdots + \gamma^{T-t-1} r_{T-1} \tag{34.64}$$

$$= \sum_{k=0}^{T-t-1} \gamma^k r_{t+k} = \sum_{j=t}^{T-1} \gamma^{j-t} r_j \tag{34.65}$$

$G_t$ is sometimes called the reward-to-go. For episodic tasks that terminate at time $T$, we define $G_t = 0$ for $t \geq T$. Clearly, the return satisfies the following recursive relationship:

$$G_t = r_t + \gamma\left(r_{t+1} + \gamma r_{t+2} + \cdots\right) = r_t + \gamma G_{t+1} \tag{34.66}$$

The discount factor $\gamma$ plays two roles. First, it ensures the return is finite even if $T = \infty$ (i.e., infinite horizon), provided we use $\gamma < 1$ and the rewards $r_t$ are bounded. Second, it puts more weight on short-term rewards, which generally has the effect of encouraging the agent to achieve its goals more quickly (see Section 34.5.5.1 for an example). However, if $\gamma$ is too small, the agent will become too greedy. In the extreme case where $\gamma = 0$, the agent is completely myopic, and only tries to maximize its immediate reward. In general, the discount factor reflects the assumption that there is a probability of $1 - \gamma$ that the interaction will end at the next step. For finite horizon problems, where $T$ is known, we can set $\gamma = 1$, since we know the lifetime of the agent a priori.³
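The recursion in Equation (34.66) gives an efficient way to compute the return for every time step of a finished episode in a single backward pass. The sketch below assumes the episode's rewards are given as a Python list and uses $G_T = 0$ after termination; the reward sequence, which mimics moving down from $s_1$ to the goal of the grid world in Section 34.5.5.1, is an illustrative assumption.

```python
def returns(rewards, gamma):
    """Compute G_t = r_t + gamma * G_{t+1} for t = T-1, ..., 0, with G_T = 0."""
    G = 0.0
    out = [0.0] * len(rewards)
    for t in reversed(range(len(rewards))):
        G = rewards[t] + gamma * G
        out[t] = G
    return out


def return_direct(rewards, gamma, t):
    """Eq. (34.65): G_t = sum_k gamma^k r_{t+k}, for cross-checking."""
    return sum(gamma ** k * r for k, r in enumerate(rewards[t:]))


rewards = [0.0, 0.0, 1.0]
print(returns(rewards, 0.9))  # approximately [0.81, 0.9, 1.0]
```

The backward pass costs $O(T)$, whereas evaluating Equation (34.65) separately for every $t$ costs $O(T^2)$. With $\gamma = 0.9$ the three returns coincide with the "down" values in Figure 34.11(d).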


3. We may also use $\gamma = 1$ for continuing tasks, targeting the (undiscounted) average reward criterion [Put94].

34.5.4 Value functions

Let $\pi$ be a given policy. We define the state-value function, or value function for short, as follows (with $\mathbb{E}_\pi[\cdot]$ indicating that actions are selected by $\pi$):

$$V_\pi(s) \triangleq \mathbb{E}_\pi\left[G_0|s_0 = s\right] = \mathbb{E}_\pi\left[\sum_{t=0}^{\infty} \gamma^t r_t \,\Big|\, s_0 = s\right] \tag{34.67}$$

This is the expected return obtained if we start in state $s$ and follow $\pi$ to choose actions in a continuing task (i.e., $T = \infty$).

Similarly, we define the action-value function, also known as the Q-function, as follows:

$$Q_\pi(s, a) \triangleq \mathbb{E}_\pi\left[G_0|s_0 = s, a_0 = a\right] = \mathbb{E}_\pi\left[\sum_{t=0}^{\infty} \gamma^t r_t \,\Big|\, s_0 = s, a_0 = a\right] \tag{34.68}$$

This quantity represents the expected return obtained if we start by taking action $a$ in state $s$, and then follow $\pi$ to choose actions thereafter.

Finally, we define the advantage function as follows:

$$A_\pi(s, a) \triangleq Q_\pi(s, a) - V_\pi(s) \tag{34.69}$$

This tells us the benefit of picking action $a$ in state $s$ then switching to policy $\pi$, relative to the baseline return of always following $\pi$. Note that $A_\pi(s, a)$ can be both positive and negative, and $\mathbb{E}_{\pi(a|s)}[A_\pi(s, a)] = 0$ due to a useful equality: $V_\pi(s) = \mathbb{E}_{\pi(a|s)}[Q_\pi(s, a)]$.
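The identity $\mathbb{E}_{\pi(a|s)}[A_\pi(s,a)] = 0$ can be checked numerically on a small tabular MDP. The sketch below performs iterative policy evaluation, repeatedly applying $V(s) \leftarrow \sum_a \pi(a|s)\left[R(s,a) + \gamma \sum_{s'} p_T(s'|s,a)V(s')\right]$, which is the expectation form of the Bellman equation for a fixed policy. The two-state MDP, the uniform policy and the tolerance are illustrative assumptions.

```python
def q_from_v(P, R, V, gamma):
    """Q(s, a) = R(s, a) + gamma * sum_{s'} p_T(s' | s, a) V(s')."""
    S, A = len(R), len(R[0])
    return [[R[s][a] + gamma * sum(P[s][a][s2] * V[s2] for s2 in range(S))
             for a in range(A)] for s in range(S)]


def evaluate_policy(P, R, pi, gamma, tol=1e-10):
    """Iterative policy evaluation; returns (V_pi, Q_pi) for a tabular MDP."""
    S, A = len(R), len(R[0])
    V = [0.0] * S
    while True:
        Q = q_from_v(P, R, V, gamma)
        V_new = [sum(pi[s][a] * Q[s][a] for a in range(A)) for s in range(S)]
        converged = max(abs(v - w) for v, w in zip(V, V_new)) < tol
        V = V_new
        if converged:
            # recompute Q from the final V so that Q and V are mutually consistent
            return V, q_from_v(P, R, V, gamma)


# Two states, two actions: P[s][a] is the distribution over next states
P = [[[0.9, 0.1], [0.2, 0.8]],
     [[0.5, 0.5], [0.0, 1.0]]]
R = [[1.0, 0.0],
     [0.0, 2.0]]
pi = [[0.5, 0.5], [0.5, 0.5]]  # uniform random policy
V, Q = evaluate_policy(P, R, pi, gamma=0.9)
for s in range(2):
    advantages = [Q[s][a] - V[s] for a in range(2)]
    # individual advantages are nonzero, but their expectation under pi is ~0
    print(s, advantages, sum(pi[s][a] * advantages[a] for a in range(2)))
```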
34.5.5 Optimal value functions and policies

Suppose $\pi_*$ is a policy such that $V_{\pi_*}(s) \geq V_\pi(s)$ for all $s \in \mathcal{S}$ and all policies $\pi$; then it is an optimal policy. There can be multiple optimal policies for the same MDP, but by definition their value functions must be the same, and are denoted by $V_*$ and $Q_*$, respectively. We call $V_*$ the optimal state-value function, and $Q_*$ the optimal action-value function. Furthermore, any finite MDP must have at least one deterministic optimal policy [Put94].

A fundamental result about the optimal value function is Bellman’s optimality equations:

$$V_*(s) = \max_a \left[R(s, a) + \gamma \mathbb{E}_{p_T(s'|s,a)}\left[V_*(s')\right]\right] \tag{34.70}$$

$$Q_*(s, a) = R(s, a) + \gamma \mathbb{E}_{p_T(s'|s,a)}\left[\max_{a'} Q_*(s', a')\right] \tag{34.71}$$

Conversely, the optimal value functions are the only solutions that satisfy the equations. In other words, although the value function is defined as the expectation of a sum of infinitely many rewards, it can be characterized by a recursive equation that involves only one-step transition and reward models of the MDP. Such a recursion plays a central role in many RL algorithms we will see later in this chapter. Given a value function ($V$ or $Q$), the discrepancy between the right- and left-hand sides of Equations (34.70) and (34.71) is called the Bellman error or Bellman residual.

Furthermore, given the optimal value function, we can derive an optimal policy using

$$\pi_*(s) = \operatorname*{argmax}_a Q_*(s, a) \tag{34.72}$$

$$= \operatorname*{argmax}_a \left[R(s, a) + \gamma \mathbb{E}_{p_T(s'|s,a)}\left[V_*(s')\right]\right] \tag{34.73}$$
34.6. PLANNING IN AN MDP


Following such an optimal policy ensures the agent achieves maximum expected return starting from any state. The problem of solving for $V_*$, $Q_*$ or $\pi_*$ is called policy optimization. In contrast, solving for $V_\pi$ or $Q_\pi$ for a given policy $\pi$ is called policy evaluation, which constitutes an important subclass of RL problems as will be discussed in later sections. For policy evaluation, we have similar Bellman equations, which simply replace $\max_a\{\cdot\}$ in Equations (34.70) and (34.71) with $\mathbb{E}_{\pi(a|s)}[\cdot]$.

In Equations (34.72) and (34.73), as in the Bellman optimality equations, we must take a maximum over all actions in $\mathcal{A}$, and the maximizing action is called the greedy action with respect to the value functions, $Q_*$ or $V_*$. Finding greedy actions is computationally easy if $\mathcal{A}$ is a small finite set. For high dimensional continuous spaces, we can treat $a$ as a sequence of actions, and optimize one dimension at a time [Met+17], or use gradient-free optimizers such as the cross-entropy method (Section 6.9.5), as used in the QT-Opt method [Kal+18a]. Recently, CAQL (continuous action Q-learning, [Ryu+20]) proposed to use mixed integer programming to solve the argmax problem, leveraging the ReLU structure of the Q-network. We can also amortize the cost of this optimization by training a policy $a_* = \pi_*(s)$ after learning the optimal Q-function.


34.5.5.1 Example 


In this section, we show a simple example, to make concepts like value functions more concrete. Consider the 1d grid world shown in Figure 34.11(a). There are 5 possible states, of which $S_{T1}$ and $S_{T2}$ are absorbing states, since the interaction ends once the agent enters them. There are 2 actions, ↑ and ↓. The reward function is zero everywhere except at the goal state, $S_{T2}$, which gives a reward of 1 upon entering. Thus the optimal action in every state is to move down.

Figure 34.11(b) shows the $Q_*$ function for $\gamma = 0$. Note that we only show the function for non-absorbing states, as the optimal Q-values are 0 in absorbing states by definition. We see that $Q_*(s_3, \downarrow) = 1.0$, since the agent will get a reward of 1.0 on the next step if it moves down from $s_3$; however, $Q_*(s, a) = 0$ for all other state-action pairs, since they do not provide nonzero immediate reward. This optimal Q-function reflects the fact that using $\gamma = 0$ is completely myopic, and ignores the future.

Figure 34.11(c) shows $Q_*$ when $\gamma = 1$. In this case, we care about all future rewards equally. Thus $Q_*(s, a) = 1$ for all state-action pairs except $(s_1, \uparrow)$, which leads to the absorbing state $S_{T1}$ with no reward; from every other state-action pair the agent can always reach the goal eventually. This is infinitely far-sighted. However, it does not give the agent any short-term guidance on how to behave. For example, in $s_2$, it is not clear if it should go up or down, since both actions will eventually reach the goal with identical $Q_*$-values.

Figure 34.11(d) shows $Q_*$ when $\gamma = 0.9$. This reflects a preference for near-term rewards, while also taking future reward into account. This encourages the agent to seek the shortest path to the goal, which is usually what we desire. A proper choice of $\gamma$ is up to the agent designer, just like the design of the reward function, and has to reflect the desired behavior of the agent.
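The Q-values in Figure 34.11 can be reproduced with a few sweeps of value iteration (Section 34.6.1). The sketch below assumes the layout $S_{T1}, s_1, s_2, s_3, S_{T2}$ from top to bottom, deterministic moves, and a reward of 1 only on entering $S_{T2}$; the integer state indexing is our own choice.

```python
import numpy as np

# State indices (our own convention): 0 = S_T1, 1 = s1, 2 = s2, 3 = s3, 4 = S_T2 (goal).
N_STATES, UP, DOWN = 5, 0, 1
ABSORBING = (0, 4)

def step(s, a):
    """Deterministic move; reward 1 only when entering the goal S_T2."""
    s_next = s - 1 if a == UP else s + 1
    return s_next, float(s_next == 4)

def optimal_q(gamma, n_iters=100):
    """Synchronous value iteration; returns Q_* rows for s1, s2, s3 (columns: up, down)."""
    Q = np.zeros((N_STATES, 2))     # rows of absorbing states stay 0 by definition
    for _ in range(n_iters):
        V = Q.max(axis=1)           # V_k(s) = max_a Q_k(s, a), which is 0 in absorbing states
        for s in range(N_STATES):
            if s in ABSORBING:
                continue
            for a in (UP, DOWN):
                s_next, r = step(s, a)
                Q[s, a] = r + gamma * V[s_next]
    return Q[1:4]

for gamma in (0.0, 1.0, 0.9):
    print(gamma, np.round(optimal_q(gamma), 2).tolist())
```

Rounding the $\gamma = 0.9$ result to two decimals gives the values in panel (d); for instance, $Q_*(s_2, \uparrow) = 0.9 \cdot 0.81 = 0.729$.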


34.6 Planning in an MDP 


In this section, we discuss how to compute an optimal policy when the MDP model is known. This 
problem is called planning, in contrast to the learning problem where the models are unknown, which is tackled using reinforcement learning (Chapter 35). The planning algorithms we discuss are
based on dynamic programming (DP) and linear programming (LP). 

For simplicity, in this section, we assume discrete state and action sets with $\gamma < 1$. However, exact calculation of optimal policies often depends polynomially on the sizes of $\mathcal{S}$ and $\mathcal{A}$, and is intractable, for example, when the state space is a Cartesian product of several finite sets. This challenge is known as the curse of dimensionality. Therefore, approximations are typically needed, such as using parametric or nonparametric representations of the value function or policy, both for computational tractability and for extending the methods to handle MDPs with general state and action sets. In this case, we have approximate dynamic programming (ADP) and approximate linear programming (ALP) algorithms (see e.g., [Ber19]).

[Figure 34.11 image: (a) the grid world with states $S_{T1}, s_1, s_2, s_3, S_{T2}$ and their rewards $R(s)$; (b)-(d) the optimal Q-values tabulated below.]

| $Q_*(s,a)$ | $\gamma=0$: Up | $\gamma=0$: Down | $\gamma=1$: Up | $\gamma=1$: Down | $\gamma=0.9$: Up | $\gamma=0.9$: Down |
|---|---|---|---|---|---|---|
| $s_1$ | 0 | 0 | 0 | 1.0 | 0 | 0.81 |
| $s_2$ | 0 | 0 | 1.0 | 1.0 | 0.73 | 0.9 |
| $s_3$ | 0 | 1.0 | 1.0 | 1.0 | 0.81 | 1.0 |

Figure 34.11: Left: illustration of a simple MDP corresponding to a 1d grid world of 3 non-absorbing states and 2 actions. Right: optimal Q-functions for different values of $\gamma$. Adapted from Figures 3.1, 3.2, 3.4 of [GK19].


34.6.1 Value iteration

A popular and effective DP method for solving an MDP is value iteration (VI). Starting from an initial value function estimate $V_0$, the algorithm iteratively updates the estimate by

$$V_{k+1}(s) = \max_a \Big[ R(s,a) + \gamma \sum_{s'} p(s'|s,a) V_k(s') \Big] \tag{34.74}$$

Note that the update rule, sometimes called a Bellman backup, is exactly the right-hand side of the Bellman optimality equation (Equation (34.70)), with the unknown $V_*$ replaced by the current estimate $V_k$. A fundamental property of Equation (34.74) is that the update is a contraction: it can be verified that

$$\max_s |V_{k+1}(s) - V_*(s)| \le \gamma \max_s |V_k(s) - V_*(s)| \tag{34.75}$$

In other words, every iteration shrinks the maximum value function error by at least a factor of $\gamma$. It follows immediately that $V_k$ will converge to $V_*$, after which an optimal policy can be extracted using Equation (34.73). In practice, we can often terminate VI when $V_k$ is close enough to $V_*$, since the resulting greedy policy wrt $V_k$ will be near optimal. Value iteration can be adapted to learn the optimal action-value function $Q_*$.
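The following NumPy sketch implements Equation (34.74) for a tabular MDP, extracts the greedy policy, and checks the contraction property (34.75) numerically. It assumes the model is given as arrays `P[s, a, s']` $= p(s'|s,a)$ and `R[s, a]`; the random MDP is purely illustrative.

```python
import numpy as np

def bellman_backup(V, P, R, gamma):
    """Q[s, a] = R(s, a) + gamma * sum_{s'} p(s'|s, a) V(s'); P has shape (S, A, S)."""
    return R + gamma * P @ V

def value_iteration(P, R, gamma, tol=1e-10, max_iters=100_000):
    """Iterate Eq. (34.74) until the sup-norm change is below tol; return V and a greedy policy."""
    V = np.zeros(P.shape[0])
    for _ in range(max_iters):
        V_new = bellman_backup(V, P, R, gamma).max(axis=1)
        done = np.max(np.abs(V_new - V)) < tol
        V = V_new
        if done:
            break
    pi = bellman_backup(V, P, R, gamma).argmax(axis=1)   # greedy policy wrt V
    return V, pi

# A random MDP with 6 states and 3 actions (illustrative only).
rng = np.random.default_rng(0)
S, A, gamma = 6, 3, 0.9
P = rng.random((S, A, S))
P /= P.sum(axis=2, keepdims=True)     # normalize into distributions over s'
R = rng.random((S, A))
V_star, pi_star = value_iteration(P, R, gamma)

# Empirical check of the contraction (34.75) from an arbitrary starting point.
V0 = rng.normal(size=S)
V1 = bellman_backup(V0, P, R, gamma).max(axis=1)
ratio = np.max(np.abs(V1 - V_star)) / np.max(np.abs(V0 - V_star))
print(ratio)   # should not exceed gamma, up to the VI tolerance
```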

In value iteration, we compute $V_*(s)$ and $\pi_*(s)$ for all possible states $s$, averaging over all possible next states $s'$ at each iteration, as illustrated in Figure 34.12(right). However, for some problems, we may only be interested in the value (and policy) for certain special starting states. This is the case, for example, in shortest path problems on graphs, where we are trying to find the shortest route from the current state to a goal state. This can be modeled as an episodic MDP by defining a transition matrix $p_T(s'|s,a)$ where taking edge $a$ from node $s$ leads to the neighboring node $s'$ with probability 1. The reward function is defined as $R(s,a) = -1$ for all states $s$ except the goal states, which are modeled as absorbing states.

In problems such as this, we can use a method known as real-time dynamic programming (RTDP) [BBS95] to efficiently compute an optimal partial policy, which only specifies what to do for the reachable states. RTDP maintains a value function estimate $V$. At each step, it performs a Bellman backup for the current state $s$ by $V(s) \leftarrow \max_a \mathbb{E}_{p_T(s'|s,a)}[R(s,a) + \gamma V(s')]$. It then picks an action $a$ (often with some exploration), reaches a next state $s'$, and repeats the process. This can be seen as a form of the more general asynchronous value iteration, which focuses its computational effort on parts of the state space that are more likely to be reachable from the current state, rather than synchronously updating all states at each iteration.
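A minimal RTDP sketch for the shortest-path setting just described: reward $-1$ per step, an absorbing goal, and $\gamma = 1$. It assumes a tabular model `P[s, a, s']` and uses $\epsilon$-greedy exploration; initializing $V$ to zero is an optimistic choice here only because every reward is nonpositive. The chain graph is a hypothetical example.

```python
import numpy as np

def rtdp(P, R, start, goals, gamma=1.0, n_trials=200, max_steps=100, eps=0.1, seed=0):
    """Real-time DP: back up only the states visited along simulated trajectories."""
    rng = np.random.default_rng(seed)
    n_states, n_actions, _ = P.shape
    V = np.zeros(n_states)                  # optimistic, since all true values are <= 0 here
    for _ in range(n_trials):
        s = start
        for _ in range(max_steps):
            if s in goals:
                break
            q = R[s] + gamma * P[s] @ V     # q[a] = E_{p(s'|s,a)}[R(s,a) + gamma V(s')]
            V[s] = q.max()                  # Bellman backup at the current state only
            # Epsilon-greedy action, then sample the next state from the model.
            a = rng.integers(n_actions) if rng.random() < eps else int(q.argmax())
            s = int(rng.choice(n_states, p=P[s, a]))
    return V

# Hypothetical chain graph 0 - 1 - 2 - 3 - 4 with goal 4; action 0 = left, 1 = right.
n = 5
P = np.zeros((n, 2, n))
for s in range(n - 1):
    P[s, 0, max(s - 1, 0)] = 1.0
    P[s, 1, s + 1] = 1.0
P[n - 1, :, n - 1] = 1.0                    # the goal is absorbing
R = -np.ones((n, 2))
R[n - 1] = 0.0
V = rtdp(P, R, start=0, goals={n - 1})
print(V)
```

On this chain the estimates for the states visited from the start converge to minus the number of steps to the goal.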


34.6.2 Policy iteration 


Another effective DP method for computing $\pi_*$ is policy iteration. It is an iterative algorithm that
searches in the space of deterministic policies until converging to an optimal policy. Each iteration 
consists of two steps, policy evaluation and policy improvement. 

The policy evaluation step, as mentioned earlier, computes the value function for the current policy. Let $\pi$ represent the current policy, $v(s) = V_\pi(s)$ represent the value function encoded as a vector indexed by states, $r(s) = \sum_a \pi(a|s) R(s,a)$ represent the reward vector, and $T(s'|s) = \sum_a \pi(a|s) p(s'|s,a)$ represent the state transition matrix. Bellman's equation for policy evaluation can be written in the matrix-vector form as


$$v = r + \gamma T v \tag{34.76}$$


This is a linear system of equations in $|\mathcal{S}|$ unknowns. We can solve it using matrix inversion: $v = (I - \gamma T)^{-1} r$. Alternatively, we can use value iteration by computing $v_{t+1} = r + \gamma T v_t$ until near convergence, or some form of asynchronous variant that is computationally more efficient.
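A small sketch of the two ways of solving Equation (34.76), assuming a tabular model `P[s, a, s']`, rewards `R[s, a]` and a stochastic policy stored as `pi[s, a]` $= \pi(a|s)$, all randomly generated here for illustration. In practice `np.linalg.solve` is preferable to forming $(I - \gamma T)^{-1}$ explicitly.

```python
import numpy as np

def policy_matrices(P, R, pi):
    """r(s) = sum_a pi(a|s) R(s,a) and T[s, s'] = sum_a pi(a|s) p(s'|s,a)."""
    r = np.sum(pi * R, axis=1)
    T = np.einsum("sa,sat->st", pi, P)
    return r, T

def evaluate_exact(r, T, gamma):
    """Solve (I - gamma T) v = r, i.e. Eq. (34.76), without an explicit inverse."""
    return np.linalg.solve(np.eye(len(r)) - gamma * T, r)

def evaluate_iterative(r, T, gamma, n_iters=1000):
    """Repeated backups v_{t+1} = r + gamma T v_t, starting from v_0 = 0."""
    v = np.zeros_like(r)
    for _ in range(n_iters):
        v = r + gamma * T @ v
    return v

rng = np.random.default_rng(1)
S, A, gamma = 5, 2, 0.9
P = rng.random((S, A, S))
P /= P.sum(axis=2, keepdims=True)
R = rng.random((S, A))
pi = rng.random((S, A))
pi /= pi.sum(axis=1, keepdims=True)      # a random stochastic policy
r, T = policy_matrices(P, R, pi)
v_exact = evaluate_exact(r, T, gamma)
v_iter = evaluate_iterative(r, T, gamma)
print(np.max(np.abs(v_exact - v_iter)))
```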

Once we have evaluated $V_\pi$ for the current policy $\pi$, we can use it to derive a better policy $\pi'$, thus the name policy improvement. To do this, we simply compute a deterministic policy $\pi'$ that acts greedily with respect to $V_\pi$ in every state; that is, $\pi'(s) = \arg\max_a \{R(s,a) + \gamma \mathbb{E}[V_\pi(s')]\}$. We can guarantee that $V_{\pi'} \ge V_\pi$. To see this, define $r'$, $T'$ and $v'$ as before, but for the new policy $\pi'$. The definition of $\pi'$ implies $r' + \gamma T' v \ge r + \gamma T v = v$, where the equality is due to Bellman's equation. Repeating the same inequality, we have


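$$v \le r' + \gamma T' v \le r' + \gamma T'(r' + \gamma T' v) \le r' + \gamma T'\big(r' + \gamma T'(r' + \gamma T' v)\big) \le \cdots \tag{34.77}$$
$$= (I + \gamma T' + \gamma^2 T'^2 + \cdots) r' = (I - \gamma T')^{-1} r' = v' \tag{34.78}$$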
Figure 34.12: Policy iteration vs value iteration represented as backup diagrams. Empty circles represent
states, solid (filled) circles represent states and actions. Adapted from Figure 8.6 of [SB18]. 


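Starting from an initial policy $\pi_0$, policy iteration alternates between policy evaluation (E) and improvement (I) steps, as illustrated below:

$$\pi_0 \xrightarrow{E} V_{\pi_0} \xrightarrow{I} \pi_1 \xrightarrow{E} V_{\pi_1} \cdots \xrightarrow{I} \pi_* \xrightarrow{E} V_* \tag{34.79}$$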


The algorithm stops at iteration $k$ if the policy $\pi_k$ is greedy with respect to its own value function $V_{\pi_k}$. In this case, the policy is optimal. Since there are at most $|\mathcal{A}|^{|\mathcal{S}|}$ deterministic policies, and every iteration strictly improves the policy, the algorithm must converge after a finite number of iterations.

In PI, we alternate between policy evaluation (which involves multiple iterations, until convergence of $V_\pi$), and policy improvement. In VI, we alternate between one iteration of policy evaluation followed by one iteration of policy improvement (the "max" operator in the update rule). In generalized policy improvement, we are free to intermix any number of these steps in any order. The process will converge once the policy is greedy wrt its own value function.
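A minimal policy iteration sketch in the same tabular setting, alternating exact evaluation (Equation (34.76)) with greedy improvement and stopping once the policy is greedy with respect to its own value function. The random MDP is illustrative, and the small tolerance in the stopping test is our own choice to avoid cycling between tied actions because of round-off.

```python
import numpy as np

def policy_iteration(P, R, gamma, max_iters=1000):
    """Policy iteration over deterministic policies; P[s, a, s'] = p(s'|s, a), R[s, a] = reward."""
    S = P.shape[0]
    states = np.arange(S)
    pi = np.zeros(S, dtype=int)                   # initial policy pi_0: action 0 everywhere
    for _ in range(max_iters):
        # Policy evaluation (E): solve v = r + gamma T v exactly.
        r, T = R[states, pi], P[states, pi]       # T[s, s'] = p(s'|s, pi(s))
        v = np.linalg.solve(np.eye(S) - gamma * T, r)
        # Policy improvement (I): act greedily with respect to v.
        q = R + gamma * P @ v
        # Stop when pi is already greedy wrt its own value function (up to round-off).
        if np.all(q[states, pi] >= q.max(axis=1) - 1e-10):
            return pi, v
        pi = q.argmax(axis=1)
    return pi, v

rng = np.random.default_rng(2)
S, A, gamma = 8, 3, 0.95
P = rng.random((S, A, S))
P /= P.sum(axis=2, keepdims=True)
R = rng.random((S, A))
pi_pi, v_pi = policy_iteration(P, R, gamma)

# Cross-check against plain value iteration run to (near) convergence.
v_vi = np.zeros(S)
for _ in range(2000):
    v_vi = (R + gamma * P @ v_vi).max(axis=1)
print(np.max(np.abs(v_pi - v_vi)))
```

Policy iteration typically needs only a handful of outer iterations, each paying for one linear solve, whereas value iteration does many cheap sweeps.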

Note that policy evaluation computes $V_\pi$, whereas value iteration computes $V_*$. This difference is illustrated in Figure 34.12, using a backup diagram. Here the root node represents any state $s$, nodes at the next level represent state-action combinations (solid circles), and nodes at the leaves represent the set of possible resulting next states $s'$ for each possible action. In the former case, we average over all actions according to the policy, whereas in the latter, we take the maximum over all actions.
34.6.3 Linear programming

While dynamic programming is effective and popular, linear programming (LP) provides an alternative that finds important uses, such as in off-policy RL (Section 35.5). The primal form of LP is given by

$$\min_V \sum_s p_0(s) V(s) \quad \text{s.t.} \quad V(s) \ge R(s,a) + \gamma \sum_{s'} p_T(s'|s,a) V(s'), \quad \forall (s,a) \in \mathcal{S} \times \mathcal{A} \tag{34.80}$$

where $p_0(s) > 0$ for all $s \in \mathcal{S}$, and can be interpreted as the initial state distribution. It can be verified that any $V$ satisfying the constraint in Equation (34.80) is optimistic [Put94], that is, $V \ge V_*$. When the objective is minimized, the solution $V$ will be "pushed" to the smallest possible, which is $V_*$. Once $V_*$ is found, any action $a$ that makes the constraint tight in state $s$ is optimal in that state.
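For small MDPs, the primal LP (34.80) can be handed to an off-the-shelf solver. This sketch assumes SciPy's `scipy.optimize.linprog` with the HiGHS backend, rewrites each constraint in the $A_{ub} V \le b_{ub}$ form that function expects, and reads off an optimal action as the one with the smallest slack. The random MDP and the uniform $p_0$ are illustrative.

```python
import numpy as np
from scipy.optimize import linprog

def solve_primal_lp(P, R, gamma, p0):
    """Solve Eq. (34.80) for a tabular MDP; P[s, a, s'] = p_T(s'|s, a), R[s, a] = reward."""
    S, A, _ = P.shape
    # Each constraint V(s) >= R(s,a) + gamma * sum_s' P[s,a,s'] V(s') is rewritten as
    # gamma * P[s, a] @ V - V(s) <= -R(s, a), one row per (s, a) pair (row index s*A + a).
    A_ub = gamma * P.reshape(S * A, S)
    A_ub[np.arange(S * A), np.repeat(np.arange(S), A)] -= 1.0
    b_ub = -R.reshape(S * A)
    # V is a free variable, so override linprog's default nonnegativity bounds.
    res = linprog(c=p0, A_ub=A_ub, b_ub=b_ub, bounds=(None, None), method="highs")
    if res.status != 0:
        raise RuntimeError(res.message)
    V = res.x
    # An action whose constraint is tight (zero slack) is optimal; pick the smallest slack.
    slack = V[:, None] - (R + gamma * P @ V)
    return V, slack.argmin(axis=1)

rng = np.random.default_rng(3)
S, A, gamma = 5, 3, 0.9
P = rng.random((S, A, S))
P /= P.sum(axis=2, keepdims=True)
R = rng.random((S, A))
p0 = np.full(S, 1.0 / S)      # any strictly positive weights give the same solution V_*
V_lp, pi_lp = solve_primal_lp(P, R, gamma, p0)
print(V_lp, pi_lp)
```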


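The dual LP form is sometimes more intuitive:

$$\max_{d \ge 0} \sum_{s,a} d(s,a) R(s,a) \quad \text{s.t.} \quad \sum_a d(s,a) = (1-\gamma) p_0(s) + \gamma \sum_{\bar{s},\bar{a}} p_T(s|\bar{s},\bar{a}) d(\bar{s},\bar{a}) \quad \forall s \in \mathcal{S} \tag{34.81}$$

Any nonnegative $d$ satisfying the constraint above is the normalized occupancy distribution of some corresponding policy $\pi_d(a|s) = d(s,a) / \sum_{a'} d(s,a')$:⁴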


$$d(s,a) = (1-\gamma) \sum_{t=0}^{\infty} \gamma^t \, p(s_t = s, a_t = a \mid s_0 \sim p_0, a_t \sim \pi_d(s_t)) \tag{34.82}$$


The constant $(1-\gamma)$ normalizes $d$ to be a valid distribution, so that it sums to unity. With this interpretation of $d$, the objective in Equation (34.81) is just the average per-step reward under the normalized occupancy distribution. Once an optimal solution $d_*$ is found, an optimal policy can be immediately obtained by $\pi_*(a|s) = d_*(s,a) / \sum_{a'} d_*(s,a')$.
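To make Equation (34.82) concrete, the sketch below computes the normalized occupancy distribution of a fixed stochastic policy in closed form. It uses the state marginal $d_{\mathcal{S}} = (1-\gamma)(I - \gamma T^\top)^{-1} p_0$, which follows from summing the geometric series over time, then checks the dual constraint of (34.81), the normalization, and the recovery of $\pi_d$. The MDP, policy and $p_0$ are random and purely illustrative.

```python
import numpy as np

def occupancy(P, pi, p0, gamma):
    """Normalized occupancy d(s, a) of Eq. (34.82) for a stochastic policy pi[s, a]."""
    T = np.einsum("sa,sat->st", pi, P)     # T[s, s'] = sum_a pi(a|s) p(s'|s, a)
    # d_S = (1 - gamma) * sum_t gamma^t (T^T)^t p0 = (1 - gamma) (I - gamma T^T)^{-1} p0
    d_state = (1 - gamma) * np.linalg.solve(np.eye(len(p0)) - gamma * T.T, p0)
    return d_state[:, None] * pi           # d(s, a) = d_S(s) pi(a|s)

rng = np.random.default_rng(4)
S, A, gamma = 4, 2, 0.8
P = rng.random((S, A, S))
P /= P.sum(axis=2, keepdims=True)
R = rng.random((S, A))
pi = rng.random((S, A))
pi /= pi.sum(axis=1, keepdims=True)
p0 = rng.random(S)
p0 /= p0.sum()

d = occupancy(P, pi, p0, gamma)
lhs = d.sum(axis=1)                                            # sum_a d(s, a)
rhs = (1 - gamma) * p0 + gamma * np.einsum("sat,sa->t", P, d)  # right-hand side of (34.81)
pi_d = d / d.sum(axis=1, keepdims=True)                        # policy recovered from d
print(np.allclose(lhs, rhs), np.isclose(d.sum(), 1.0), np.allclose(pi_d, pi))
```

One can also check that $\sum_{s,a} d(s,a) R(s,a) = (1-\gamma) \sum_s p_0(s) V_\pi(s)$, which is the sense in which the dual objective is an average per-step reward.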

A challenge in solving the primal or dual LPs for MDPs is the large number of constraints and 
variables. Approximations are needed, where the variables are parameterized (either linearly or 
nonlinearly), and the constraints are sampled or approximated (see e.g., [dV04; LBS17; CLW18]).


34.7 Active learning 


This section is coauthored with Zeel B Patel. 


In this section, we discuss active learning (AL), in which the agent gets to choose which data it 
wants to use so as to learn the underlying predictive function as quickly as possible, i.e., using the 
smallest amount of labeled data. This can be much more efficient than using randomly collected data, 
as illustrated in Figure 34.13. This is useful when labels are expensive to collect, e.g., for medical 
image classification [GIG17; Wal+20]. 

There are many approaches to AL, as reviewed in [Set12]. In this section, we just consider a few 
methods. 


34.7.1 Active learning scenarios 


One of the earliest AL methods is known as membership query synthesis [Ang88]. In this scenario the agent can generate an arbitrary query $x \sim p(x)$ and then ask the oracle for its label, $y = f(x)$. (An "oracle" is the term given to a system that knows the true answer to every possible question.) This scenario is mostly of theoretical interest, since it is hard to learn good generative models, and it is rarely possible to have access to an oracle on demand (although human-powered crowd computing platforms can be considered as oracles with high latency).

Another scenario is stream-based selective sampling [ACL89], where the agent receives a stream of inputs, $x_1, x_2, \ldots$, and at each step must decide whether to request the label or not. Again, this scenario is mostly of theoretical interest.

The last, and most widely used, setting for machine learning is pool-based sampling [LG94], where the pool of unlabeled samples $\mathcal{X}$ is available from the beginning. At each step we apply an acquisition function to each candidate in the pool, to decide which one to collect the label for. We then collect the label, update the model with the new data, and repeat the process until we exhaust the pool, run out of time, or reach some desired performance. In the subsequent sections, we will focus only on pool-based sampling.

4. If $\sum_{a'} d(s,a') = 0$ for some state $s$, then $\pi_d(s)$ may be defined arbitrarily, since $s$ is not visited under the policy.

[Figure 34.13 image: panels (a), (b), (c), each with horizontal axis "Feature 1".]

Figure 34.13: Decision boundaries for a logistic regression model applied to a 2-dimensional, 3-class dataset. (a) Results after fitting the model on the initial training data; the test accuracy is 0.818. (b) Results after further training on 11 randomly sampled points; accuracy is 0.848. (c) Results after further training on 11 points chosen with margin sampling (see Section 34.7.3); accuracy is 0.969. Generated by active_learning_visualization_class.ipynb.

| Problem | Goal | Action space |
|---|---|---|
| Active learning | $\arg\min_f \mathbb{E}_{p(x)}[\ell(f_*(x), f(x))]$ | choose $x$ at which to get $y = f_*(x)$ |
| Bayesian optimization | $\arg\max_{x \in \mathcal{X}} f_*(x)$ | choose $x$ at which to evaluate $f_*(x)$ |
| Contextual bandits | $\arg\max_\pi \mathbb{E}_{p(x)\pi(a \mid x)}[R_*(x,a)]$ | choose $a$ at which to evaluate $R_*(x,a)$ |

Table 34.1: Comparison among active learning, Bayesian optimization and contextual bandits in terms of goal and action space.
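The pool-based loop can be sketched in a few lines. The example makes several assumptions that are ours, not the text's: a 1d regression problem, a Gaussian process with a fixed squared-exponential kernel (lengthscale 0.3, noise variance 0.01), `np.sin` as a stand-in for the labeling oracle, and the posterior variance of the latent function as the acquisition function.

```python
import numpy as np

def rbf(a, b, lengthscale=0.3):
    """Squared-exponential kernel matrix between 1d input arrays a and b."""
    return np.exp(-0.5 * (a[:, None] - b[None, :]) ** 2 / lengthscale**2)

def gp_posterior(x_train, y_train, x_test, noise=1e-2):
    """Zero-mean GP regression: posterior mean and variance of f at x_test."""
    K = rbf(x_train, x_train) + noise * np.eye(len(x_train))
    K_s = rbf(x_train, x_test)
    mean = K_s.T @ np.linalg.solve(K, y_train)
    var = 1.0 - np.sum(K_s * np.linalg.solve(K, K_s), axis=0)   # k(x, x) = 1 for this kernel
    return mean, var

def pool_based_al(oracle, pool, n_init=2, n_queries=8, seed=0):
    """Repeatedly label the pool point with the largest posterior variance."""
    rng = np.random.default_rng(seed)
    labeled = [int(i) for i in rng.choice(len(pool), size=n_init, replace=False)]
    for _ in range(n_queries):
        x_lab = pool[labeled]
        y_lab = oracle(x_lab)                        # re-labeling is cheap here; cache in practice
        _, var = gp_posterior(x_lab, y_lab, pool)    # refit the model on all labels so far
        var[labeled] = -np.inf                       # never re-query an already labeled point
        labeled.append(int(np.argmax(var)))          # acquisition: largest posterior variance
    return labeled

pool = np.linspace(0.0, 2.0, 101)                    # unlabeled pool
queried = pool_based_al(np.sin, pool)                # np.sin stands in for the oracle
print(np.sort(pool[queried]))
```

With a variance-based acquisition the queries spread out over the whole input range, which is the "all over the place" behavior described in Section 34.7.2.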

32 


34.7.2 Relationship to other forms of sequential decision making 


(Pool-based) active learning is closely related to Bayesian optimization (BO, Section 6.8) and contextual bandit problems (Section 34.4). The connections are discussed at length in [Tou14], but in brief, the methods differ because they solve slightly different objective functions, as summarized in Table 34.1. In particular, in active learning, our goal is to identify a function $f: \mathcal{X} \to \mathcal{Y}$ that will incur minimum expected loss when applied to random inputs $x$; in BO, our goal is to identify an input point $x$ where the function output $f(x)$ is maximal; and in bandits, our goal is to identify a policy $\pi: \mathcal{X} \to \mathcal{A}$ that will give maximum expected reward when applied to random inputs (contexts) $x$. (We see that the goal in AL and bandits is similar, but in bandits the agent only gets to choose the action, not the state, so it only has partial control over where the (reward) function is evaluated.)

In all three problems, we want to find the optimum with as few actions as possible, so we have to solve the exploration-exploitation problem (Section 34.4.3). One approach is to represent our uncertainty about the function using a method such as a Gaussian process (Chapter 18), which lets us compute $p(f|\mathcal{D}_{1:t})$. We then define some acquisition function $\alpha(x)$ that evaluates how useful it would be to query the function at input location $x$, given the belief state $p(f|\mathcal{D}_{1:t})$, and we pick as our next query $x_{t+1} = \arg\max_x \alpha(x)$. (In the bandit setting, the agent does not get to choose the state $x$, but does get to choose the action $a$.) For example, in BO, it is common to use probability of improvement (Section 6.8.3.1), and for AL of a regression task, we can use the posterior predictive variance. The objective for AL will cause the agent to query "all over the place", whereas for BO, the agent will "zoom in" on the most promising regions, as shown in Figure 34.14. We discuss other acquisition functions for AL in Section 34.7.3.

[Figure 34.14 image: panels (a) Active learning and (b) Bayesian optimization.]

Figure 34.14: Active learning vs Bayesian optimization. Active learning tries to approximate the true function well. Bayesian optimization tries to find the maximum value of the true function. Initial and queried points are denoted as black and red dots respectively. Generated by bayes_opt_vs_active_learning.ipynb.
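To see why the two acquisition functions behave so differently, the sketch below scores a few candidate inputs under a Gaussian posterior. It assumes the posterior means, standard deviations and the incumbent value `f_best` are already available (the numbers are made up). Probability of improvement is computed as $\Phi((\mu(x) - f_{\text{best}} - \xi)/\sigma(x))$ with SciPy's `norm.cdf`; the margin $\xi$ is an optional parameter included for illustration.

```python
import numpy as np
from scipy.stats import norm

def variance_acquisition(mu, sigma):
    """AL for regression: prefer inputs with the largest posterior predictive variance."""
    return sigma ** 2

def probability_of_improvement(mu, sigma, f_best, xi=0.0):
    """BO: probability that f(x) exceeds f_best + xi under the posterior N(mu, sigma^2)."""
    return norm.cdf((mu - f_best - xi) / sigma)

# Made-up posterior over 5 candidate inputs: input 1 is promising and fairly certain,
# input 4 is very uncertain but has a low mean.
mu = np.array([0.0, 0.9, 0.5, 0.2, -1.0])
sigma = np.array([0.3, 0.2, 0.3, 0.4, 1.0])
f_best = 0.8                                  # best function value observed so far
x_al = int(np.argmax(variance_acquisition(mu, sigma)))
x_bo = int(np.argmax(probability_of_improvement(mu, sigma, f_best)))
print(x_al, x_bo)
```

Here active learning queries the most uncertain input (index 4), while BO goes for the promising one (index 1).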


34.7.3 Acquisition strategies 


In this section, we discuss some common AL heuristics for choosing which points to query. 


34.7.3.1 Uncertainty sampling 


An intuitive heuristic for choosing which example to label next is to pick the one for which the model 
is currently most uncertain. This is called uncertainty sampling. We already illustrated this in 
the case of regression in Figure 34.14, where we represented uncertainty in terms of the posterior 
variance. 

For classification problems, we can measure uncertainty in various ways. Let $p_n = [p(y = c|x_n)]_{c=1}^{C}$ be the vector of class probabilities for each unlabeled input $x_n$. Let $U_n = \alpha(p_n)$ be the uncertainty for example $n$, where $\alpha$ is an acquisition function. Some common choices for $\alpha$ are: entropy sampling [SW87a], which uses $\alpha(p) = -\sum_{c=1}^{C} p_c \log p_c$; margin sampling, which uses $\alpha(p) = p_2 - p_1$, where $p_1$ is the probability of the most probable class, and $p_2$ is the probability of the second most probable class (since $p_2 \le p_1$, this score is largest when the top two classes are nearly tied); and least confident sampling, which uses $\alpha(p) = 1 - p_{c^*}$, where $c^* = \arg\max_c p_c$. The difference between these strategies is shown in Figure 34.15. In practice it is often found that margin sampling works the best [Chu+19].
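The three scores can be written directly in NumPy. The probability vectors below are hypothetical, chosen so that each strategy picks a different input, which mirrors the situation in Figure 34.15 (the numbers are not those of the figure).

```python
import numpy as np

def entropy_score(p):
    """alpha(p) = -sum_c p_c log p_c (natural log), row-wise; 0 log 0 is treated as 0."""
    return -np.sum(p * np.log(np.clip(p, 1e-12, 1.0)), axis=-1)

def margin_score(p):
    """alpha(p) = p_2 - p_1 (second-highest minus highest); larger means more uncertain."""
    top2 = np.sort(p, axis=-1)[..., -2:]
    return top2[..., 0] - top2[..., 1]

def least_confident_score(p):
    """alpha(p) = 1 - max_c p_c."""
    return 1.0 - p.max(axis=-1)

# Hypothetical class probabilities for 3 unlabeled inputs (rows) over 4 classes.
probs = np.array([[0.40, 0.20, 0.20, 0.20],
                  [0.49, 0.48, 0.03, 0.00],
                  [0.35, 0.33, 0.32, 0.00]])
for score in (entropy_score, margin_score, least_confident_score):
    print(score.__name__, int(np.argmax(score(probs))))
```

Entropy favors the input whose mass is spread over many classes, margin favors the near tie between the top two classes, and least confident favors the input with the lowest top probability.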
Figure 34.15: Outputs of a logistic regression model fit on some training points, and then applied to 3 candidate query inputs. We show the predicted probabilities for each class label. The highlighted dark gray bar is the max probability, the light gray bar is the 2nd highest probability. The least confident scores for the 3 inputs are: $1-0.237=0.76$, $1-0.251=0.75$, and $1-0.469=0.53$, so we pick the first query. The entropy scores are: 1.63, 1.78 and 0.89, so we pick the second query. The margin scores are: $0.237-0.2067=0.0308$, $0.2513-0.2277=0.0236$, and $0.4689-0.4687=0.0002$, so we pick the third query. Generated by active_learning_comparison_mnist.ipynb.
34.7.3.2 Query by committee

In this section, we discuss how to apply uncertainty sampling to models, such as support vector machines (SVMs), that only return a point prediction rather than a probability distribution. The basic approach is to create an ensemble of diverse models, and to use disagreement between the model predictions as a form of uncertainty. (This can be useful even for probabilistic models, such as DNNs, since model uncertainty can often be larger than parametric uncertainty, as we discuss in the section on deep ensembles, Section 17.3.9.)

In more detail, suppose we have $K$ ensemble members, and let $c_n^k$ be the predicted class from member $k$ on input $x_n$. Let $v_{nc} = \sum_{k=1}^{K} \mathbb{I}(c_n^k = c)$ be the number of votes cast for class $c$, and $q_{nc} = v_{nc}/K$ be the induced distribution. (A similar method can be used for regression models, where we use the standard deviation of the prediction across the members.) We can then use margin sampling or entropy sampling with distribution $q_n$. This approach is called query by committee (QBC) [SOS92], and can often outperform vanilla uncertainty sampling with a single model, as we show in Figure 34.16.
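A sketch of QBC with entropy sampling on the vote distribution (sometimes called vote entropy), assuming each committee member returns hard class predictions stored in an array `member_preds[k, n]`. The committee outputs below are made up.

```python
import numpy as np

def vote_distribution(member_preds, n_classes):
    """member_preds[k, n] = class predicted by member k on input n.
    Returns q[n, c] = v_nc / K, the fraction of the K members voting for class c."""
    K = member_preds.shape[0]
    votes = np.stack([(member_preds == c).sum(axis=0) for c in range(n_classes)], axis=1)
    return votes / K

def vote_entropy(q):
    """Entropy sampling applied to the committee's vote distribution q_n (0 log 0 = 0)."""
    return -np.sum(q * np.log(np.where(q > 0, q, 1.0)), axis=1)

# Hypothetical committee of K = 4 point predictors (e.g. SVMs) on N = 3 pool inputs.
member_preds = np.array([[0, 1, 2],
                         [0, 2, 2],
                         [0, 1, 2],
                         [0, 0, 1]])
q = vote_distribution(member_preds, n_classes=3)
query = int(np.argmax(vote_entropy(q)))
print(q, query)
```

Input 0, on which all members agree, gets zero vote entropy, while input 1, on which the votes are split three ways, is queried.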
Figure 34.16: (a) Random Forest (RF) classifier applied to a 2-dimensional, 3-class data set. (b) Accuracy vs number of query points for margin sampling vs random sampling. We represent uncertainty either using a single RF (based on the predicted distribution over labels induced by the trees in the forest), or a committee containing an RF and a logistic regression model. Generated by active_learning_compare_class.ipynb.

34.7.3.3 Information theoretic methods

A natural acquisition strategy is to pick points whose labels will maximally reduce our uncertainty about the model parameters $w$. This is known as the information gain criterion, and was first proposed in [Lin56]. It is defined as follows:


$$\alpha(x) \triangleq \mathbb{H}(p(w|\mathcal{D})) - \mathbb{E}_{p(y|x,\mathcal{D})}[\mathbb{H}(p(w|\mathcal{D},x,y))] \tag{34.83}$$


(Note that the first term is a constant wrt $x$, but we include it for later convenience.) This is equivalent to the expected change in the posterior over the parameters, which is given by


$$\alpha'(x) = \mathbb{E}_{p(y|x,\mathcal{D})}\big[D_{\mathbb{KL}}(p(w|\mathcal{D},x,y) \,\|\, p(w|\mathcal{D}))\big] \tag{34.84}$$


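Using symmetry of the mutual information, we can rewrite Equation (34.83) as follows:

$$\alpha(x) = \mathbb{H}(w|\mathcal{D}) - \mathbb{E}_{p(y|x,\mathcal{D})}[\mathbb{H}(w|\mathcal{D},x,y)] \tag{34.85}$$
$$= \mathbb{I}(w; y|\mathcal{D}, x) \tag{34.86}$$
$$= \mathbb{H}(y|x,\mathcal{D}) - \mathbb{E}_{p(w|\mathcal{D})}[\mathbb{H}(y|x,w,\mathcal{D})] \tag{34.87}$$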


The advantage of this approach is that we now only have to reason about the uncertainty of the predictive distribution over outputs $y$, not over the parameters $w$. This approach is called Bayesian active learning by disagreement or BALD [Hou+12].

Equation (34.87) has an interesting interpretation. The first term prefers examples $x$ for which there is uncertainty in the predicted label. Just using this as a selection criterion is equivalent to uncertainty sampling, which we discussed above. However, this can have problems with examples which are inherently ambiguous or mislabeled. By adding the second term, we penalize such behavior, since we add a large negative weight to points whose predictive distribution is entropic even when we know the parameters. Thus we ignore aleatoric (intrinsic) uncertainty and focus on epistemic uncertainty.
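BALD is easy to estimate by Monte Carlo when we can draw samples $w_s \sim p(w|\mathcal{D})$ (for instance via MC dropout or an ensemble). The sketch assumes the per-sample class probabilities are stacked in an array `probs[s, n, c]`. The two toy inputs are hypothetical: both have the predictive distribution $[0.5, 0.5]$, but one is inherently ambiguous and the other is a point on which the samples confidently disagree.

```python
import numpy as np

def entropy(p, axis=-1):
    """Shannon entropy in nats along the given axis; 0 log 0 is treated as 0."""
    return -np.sum(p * np.log(np.clip(p, 1e-12, 1.0)), axis=axis)

def bald_scores(probs):
    """probs[s, n, c] = p(y = c | x_n, w_s) for samples w_s ~ p(w|D).
    Monte Carlo estimate of Eq. (34.87): H(y|x, D) - E_{p(w|D)}[H(y|x, w, D)]."""
    predictive_entropy = entropy(probs.mean(axis=0))   # entropy of the posterior predictive
    expected_entropy = entropy(probs).mean(axis=0)     # average entropy given each sample
    return predictive_entropy - expected_entropy

# Two posterior samples (rows) for two inputs:
# x_0 is ambiguous under every sample (aleatoric uncertainty);
# x_1 is one where the samples disagree confidently (epistemic uncertainty).
probs = np.array([[[0.5, 0.5], [0.99, 0.01]],
                  [[0.5, 0.5], [0.01, 0.99]]])
scores = bald_scores(probs)
print(scores)
```

Plain uncertainty sampling (the first term alone) would rate both inputs equally; BALD gives the ambiguous input a score of zero.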


34.7.4 Batch active learning 


In many applications, we need to select a batch of unlabeled examples at once, since training a model on single examples is too slow. This is called batch active learning. The key challenge is that we need to ensure the different queries that we request are diverse, so as to maximize the information gain. Various methods for this problem have been devised; here we focus on the BatchBALD method of [KAG19], which extends the BALD method of Section 34.7.3.3.


34.7.4.1 BatchBALD

The naive way to extend the BALD score to a batch of $B$ candidate query points is to define

$$\alpha_{\text{BALD}}(\{x_1, \ldots, x_B\}, p(w|\mathcal{D})) = \sum_{i=1}^{B} \mathbb{I}(y_i; w|x_i, \mathcal{D}) \tag{34.88}$$

However this may pick points that are quite similar in terms of their information content. In BatchBALD, we use the joint conditional mutual information between the set of labels and the parameters:

$$\alpha_{\text{BBALD}}(x_{1:B}, p(w|\mathcal{D})) = \mathbb{I}(y_{1:B}; w|x_{1:B}, \mathcal{D}) = \mathbb{H}(y_{1:B}|x_{1:B}, \mathcal{D}) - \mathbb{E}_{p(w|\mathcal{D})}[\mathbb{H}(y_{1:B}|x_{1:B}, w, \mathcal{D})] \tag{34.89}$$


To understand how this differs from BALD, we will use information diagrams for representing MI in terms of Venn diagrams, as explained in Section 5.3.2. In particular, [Yeu91a] showed that we can define a signed measure, $\mu^*$, for discrete random variables $x$ and $y$ such that $\mathbb{I}(x; y) = \mu^*(x \cap y)$, $\mathbb{H}(x, y) = \mu^*(x \cup y)$, $\mathbb{E}_{p(y)}[\mathbb{H}(x|y)] = \mu^*(x \setminus y)$, etc. Using this, we can interpret standard BALD as the sum of the individual intersections, $\sum_i \mu^*(y_i \cap w)$, which double counts overlaps between the $y_i$, as shown in Figure 34.17(a). By contrast, BatchBALD takes overlap into account by computing

$$\mathbb{I}(y_{1:B}; w|x_{1:B}, \mathcal{D}) = \mu^*\Big(\bigcup_i y_i \cap w\Big) = \mu^*\Big(\bigcup_i y_i\Big) - \mu^*\Big(\bigcup_i y_i \setminus w\Big) \tag{34.90}$$

This is illustrated in Figure 34.17(b). From this, we can see that $\alpha_{\text{BBALD}} \le \alpha_{\text{BALD}}$. Indeed, one can show⁵

$$\mathbb{I}(y_{1:B}; w|x_{1:B}, \mathcal{D}) = \sum_{i=1}^{B} \mathbb{I}(y_i; w|x_i, \mathcal{D}) - \mathrm{TC}(y_{1:B}|x_{1:B}, \mathcal{D}) \tag{34.91}$$

where TC is the total correlation (see Section 5.3.5.1), which is always nonnegative.


34.7.4.2 Optimizing BatchBALD

To avoid the combinatorial explosion that arises from jointly scoring subsets of points, we can use a greedy approximation for computing BatchBALD one point at a time. In particular, suppose at step $n-1$ we already have a partial batch $A_{n-1}$. The next point is chosen using

$$x_n = \arg\max_{x \in \mathcal{D}_{\text{pool}} \setminus A_{n-1}} \alpha_{\text{BBALD}}(A_{n-1} \cup \{x\}, p(w|\mathcal{D})) \tag{34.92}$$

We then add $x_n$ to $A_{n-1}$ to get $A_n$. Fortunately the BatchBALD acquisition function is submodular, as shown in [KAG19]. Hence this greedy algorithm is within $1 - 1/e \approx 0.63$ of optimal (see Section 6.11.4.1).

5. See http://blog.blackhc.net/2022/07/kbald/
[Figure 34.17 image: (a) BALD and (b) BatchBALD information diagrams over $y_1|x_1$, $y_2|x_2$, $y_3|x_3$ and $w$.]

Figure 34.17: Intuition behind BALD and BatchBALD. $\mathcal{D}_{\text{pool}}$ is an unlabelled dataset (from which $x_{1:n}$ are taken), $\mathcal{D}_{\text{train}}$ is the current training set, $w$ is the set of model parameters, and $p(y|x, w, \mathcal{D}_{\text{train}})$ are output predictions for data point $x$. BALD overestimates the joint mutual information whereas BatchBALD takes the overlap between variables into account. Areas contributing to the respective score are shown in grey, and areas that are double-counted in dark grey. From Figure 8 of [KAG19]. Used with kind permission of Andreas Kirsch.


34.7.4.3 Computing BatchBALD 


Computing the joint (conditional) mutual information is intractable, so in this section, we discuss 
how to approximate it. For brevity we drop the conditioning on x and D. With this new notation, 
the objective becomes 


$$\alpha_{\text{BBALD}}(x_{1:B}, p(w|\mathcal{D})) = \mathbb{H}(y_1, \ldots, y_B) - \mathbb{E}_{p(w)}[\mathbb{H}(y_1, \ldots, y_B|w)] \tag{34.93}$$


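Note that the $y_i$ are conditionally independent given $w$, so $\mathbb{H}(y_1, \ldots, y_B|w) = \sum_i \mathbb{H}(y_i|w)$. Hence we can approximate the second term with Monte Carlo:

$$\mathbb{E}_{p(w)}[\mathbb{H}(y_1, \ldots, y_B|w)] \approx \frac{1}{S} \sum_{s=1}^{S} \sum_{i=1}^{B} \mathbb{H}(y_i|\hat{w}_s) \tag{34.94}$$

where $\hat{w}_s \sim p(w|\mathcal{D})$.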

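The first term, $\mathbb{H}(y_1, \ldots, y_B)$, is a joint entropy, so is harder to compute. [KAG19] propose the following approximation, summing over all possible label sequences in the batch, and leveraging the fact that $p(y) = \mathbb{E}_{p(w)}[p(y|w)]$:

$$\mathbb{H}(y_{1:B}) = \mathbb{E}_{p(w)p(y_{1:B}|w)}[-\log p(y_{1:B})] \tag{34.95}$$
$$\approx -\sum_{\hat{y}_{1:B}} \Big(\frac{1}{S}\sum_{s=1}^{S} p(\hat{y}_{1:B}|\hat{w}_s)\Big) \log \Big(\frac{1}{S}\sum_{s=1}^{S} p(\hat{y}_{1:B}|\hat{w}_s)\Big) \tag{34.96}$$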


The sum over all possible label sequences can be made more efficient by noting that $p(y_{1:n}|w) = p(y_n|w) p(y_{1:n-1}|w)$, so when we implement the greedy algorithm, we can incrementally update the
probabilities, reusing previous computations. See [KAG19] for the details. 
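The following sketch combines the greedy rule (34.92) with the Monte Carlo estimates (34.94) and (34.96), extending the joint label probabilities one point at a time via $p(y_{1:n}|w) = p(y_n|w)p(y_{1:n-1}|w)$. It is a simplified illustration, not the implementation of [KAG19]: it enumerates all $C^n$ label configurations exactly, so it only suits small batches and few classes. The toy pool is hypothetical: $S = 4$ posterior samples, an informative point that appears twice (inputs 0 and 1), a second informative point (input 2) and an uninformative one (input 3).

```python
import numpy as np

def entropy(p, axis=-1):
    """Shannon entropy in nats along the given axis; 0 log 0 is treated as 0."""
    return -np.sum(p * np.log(np.where(p > 0, p, 1.0)), axis=axis)

def greedy_batchbald(probs, batch_size):
    """Greedy BatchBALD from MC samples probs[s, n, c] = p(y_n = c | x_n, w_s)."""
    S, N, C = probs.shape
    cond_entropy = entropy(probs)           # H(y_n | w_s), shape (S, N)
    joint = np.ones((S, 1))                 # p(y_{A_{n-1}} | w_s), one column per configuration
    batch, cond_sum = [], np.zeros(S)
    for _ in range(batch_size):
        best, best_score, best_joint = None, -np.inf, None
        for n in range(N):
            if n in batch:
                continue
            # Extend every configuration of the partial batch by each possible label y_n.
            cand = (joint[:, :, None] * probs[:, n, None, :]).reshape(S, -1)
            h_joint = entropy(cand.mean(axis=0))              # Eq. (34.96)
            h_cond = np.mean(cond_sum + cond_entropy[:, n])   # Eq. (34.94)
            score = h_joint - h_cond
            if score > best_score:
                best, best_score, best_joint = n, score, cand
        batch.append(best)
        joint = best_joint
        cond_sum = cond_sum + cond_entropy[:, best]
    return batch

a = np.array([[0.99, 0.01], [0.01, 0.99], [0.99, 0.01], [0.01, 0.99]])  # informative
b = np.array([[0.95, 0.05], [0.95, 0.05], [0.05, 0.95], [0.05, 0.95]])  # informative, other split
c = np.full((4, 2), 0.5)                                               # uninformative
probs = np.stack([a, a, b, c], axis=1)      # shape (S=4, N=4, C=2); inputs 0 and 1 are duplicates

bald = entropy(probs.mean(axis=0)) - entropy(probs).mean(axis=0)
print("BALD scores:", np.round(bald, 3))
print("greedy BatchBALD batch:", greedy_batchbald(probs, batch_size=2))
```

Here the top-2 BALD scores are the two duplicates, whereas the greedy BatchBALD batch is [0, 2]: once input 0 is in the batch, its duplicate adds little joint information.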
Figure 34.18: Three batches (each of size 4) queried from the MNIST pool by (a) BALD and (b) BatchBALD. (c) Plot of accuracy vs number of points queried. BALD may select replicas of a single informative datapoint while BatchBALD selects diverse points, thus increasing data efficiency. Generated by batch_bald_mnist.ipynb.


34.7.4.4 Experimental comparison of BALD vs BatchBALD on MNIST

In this section, we show some experimental results applying BALD and BatchBALD to train a CNN on the standard MNIST dataset. We use a batch size of 4, and approximate the posterior over parameters $p(w|\mathcal{D})$ using MC dropout (Section 17.3.1). In Figure 34.18(a), we see that BALD selects examples that are very similar to each other, whereas in Figure 34.18(b), we see that BatchBALD selects a greater diversity of points. In Figure 34.18(c), we see that BatchBALD results in more efficient learning than BALD, which in turn is more efficient than randomly sampling data.
35 Reinforcement learning


This chapter was co-authored with Lihong Li. 


35.1 Introduction 


Reinforcement learning or RL is a paradigm of learning where an agent sequentially interacts with an initially unknown environment. The interaction typically results in a trajectory, or multiple trajectories. Let $\tau = (s_0, a_0, r_0, s_1, a_1, r_1, s_2, \ldots, s_T)$ be a trajectory of length $T$, consisting of a sequence of states $s_t$, actions $a_t$, and rewards $r_t$.¹ The goal of the agent is to optimize her action-selection policy, so that the discounted cumulative reward, $G_0 \triangleq \sum_{t=0}^{T-1} \gamma^t r_t$, is maximized for some given discount factor $\gamma \in [0, 1]$.
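The discounted return is easiest to compute backward in time, using the recursion $G_t = r_t + \gamma G_{t+1}$ with $G_T = 0$. Here $G_t$ denotes the return from step $t$ onward, a quantity introduced for this sketch (the text above only needs $G_0$), and the reward sequence is hypothetical.

```python
import numpy as np

def discounted_returns(rewards, gamma):
    """Return G_t = sum_{k >= t} gamma^(k - t) r_k for every t, computed backward in time."""
    G = np.zeros(len(rewards))
    running = 0.0                          # G_T = 0 after the last reward
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        G[t] = running
    return G

rewards = [0.0, 0.0, 1.0]                  # hypothetical r_0, r_1, r_2 (reward only at the end)
G = discounted_returns(rewards, gamma=0.9)
G0 = G[0]                                  # the discounted cumulative reward of the trajectory
print(G, G0)
```

The backward pass costs $O(T)$ and yields every $G_t$ at once, rather than recomputing each sum from scratch.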

In general, $G_0$ is a random variable. We will focus on maximizing its expectation, inspired by the maximum expected utility principle (Section 34.1.1), but note other possibilities such as conditional value at risk² that can be more appropriate in risk-sensitive applications.

We will focus on the Markov decision process, where the generative model for the trajectory $\tau$ can be factored into single-step models. When these model parameters are known, solving for an optimal policy is called planning (see Section 34.6); otherwise, RL algorithms may be used to obtain an optimal policy from trajectories, a process called learning.

In model-free RL, we try to learn the policy without explicitly representing and learning the 
models, but directly from the trajectories. In model-based RL, we first learn a model from the 
trajectories, and then use a planning algorithm on the learned model to solve for the policy. See 
Figure 35.1 for an overview. This chapter will introduce some of the key concepts and techniques, 
and will mostly follow the notation from [SB18]. More details can be found in textbooks such as [Sze10; SB18; Ber19; Aga+21a; Mey22; Aga+22], and reviews such as [WO12; Aru+17; FL+18; Li18].


35.1.1 Overview of methods 


In this section, we give a brief overview of how to compute optimal policies when the MDP model is not known. Instead, the agent interacts with the environment and learns from the observed trajectories. This is the core focus of RL. We will go into more detail in later sections, but first provide this roadmap.

1. Note that the time starts at 0 here, while it starts at 1 when we discuss bandits (Section 34.4). Our choices of notation are meant to be consistent with conventions in the respective literature.

2. The conditional value at risk, or CVaR, is the expected reward conditioned on being in the worst 5% (say) of samples. See [Cho+15] for an example application in RL.

[Figure 35.1 image, recoverable labels: Markov Decision Process $P(s', s, a)$; Policy Iteration $\pi(s, a)$; Value Iteration $V(s)$; Dynamic programming & Bellman optimality; Nonlinear Dynamics $dx/dt = f(x(t), u(t), t)$; Optimal Control & HJB.]

Figure 35.1: Overview of RL methods. Abbreviations: DQN = Deep Q network (Section 35.2.6); MPC = Model Predictive Control (Section 35.4); HJB = Hamilton Jacobi Bellman equation; TD = temporal difference learning (Section 35.2.2). Adapted from a slide by Steve Brunton.

| Method | Functions learned | On/Off | Section |
|---|---|---|---|
| SARSA | $Q(s,a)$ | On | Section 35.2.4 |
| Q-learning | $Q(s,a)$ | Off | Section 35.2.5 |
| REINFORCE | $\pi(a \mid s)$ | On | Section 35.3.2 |
| A2C | $\pi(a \mid s)$, $V(s)$ | On | Section 35.3.3.1 |
| TRPO / PPO | $\pi(a \mid s)$, $A(s,a)$ | On | Section 35.3.4 |
| DDPG | $a = \pi(s)$, $Q(s,a)$ | Off | Section 35.3.5 |
| Soft actor-critic | $\pi(a \mid s)$, $Q(s,a)$ | Off | Section 35.6.1 |
| Model-based RL | $p(s' \mid s,a)$ | Off | Section 35.4 |

Table 35.1: Summary of some popular methods for RL. On/off refers to on-policy vs off-policy methods.
We may categorize RL methods by the quantity the agent represents and learns (value function, policy, or model), or by how actions are selected: on-policy (actions must be selected by the agent's current policy) or off-policy. Table 35.1 lists a few representative examples. More details are given in the subsequent sections. We will also discuss in greater depth two important topics, off-policy learning and inference-based control, in Sections 35.5 and 35.6.


35.1.2 Value based methods 


In a value based method, we often try to learn the optimal Q-function from experience, and then derive a policy from it using Equation (34.72). Typically, a function approximator (e.g., a neural network), $Q_w$, is used to represent the Q-function, which is trained iteratively. Given a transition $(s, a, r, s')$, we define the temporal difference (also called the TD error) as

$$r + \gamma \max_{a'} Q_w(s', a') - Q_w(s, a)$$

Clearly, the expected TD error is the Bellman error evaluated at $(s, a)$. Therefore, if $Q_w = Q_*$, the TD error is 0 on average by Bellman's optimality equation. Otherwise, the error provides a signal for the agent to change $w$ to make $Q_w(s,a)$ closer to $R(s,a) + \gamma \max_{a'} Q_w(s', a')$. The update on $Q_w$ is based on a target that is computed using $Q_w$. This kind of update is known as bootstrapping in RL, and should not be confused with the statistical bootstrap (Section 13.2.3.1). Value based methods such as Q-learning and SARSA are discussed in Section 35.2.


35.1.3 Policy search methods 


In policy search, we try to directly maximize J(π_θ) wrt the policy parameters θ. If J(π_θ) is differentiable wrt θ, we can use stochastic gradient ascent to optimize θ; this is known as policy gradient, as described in Section 35.3.1. The basic idea is to perform Monte Carlo rollouts, in which we sample trajectories by interacting with the environment, and then use the score function estimator (Section 6.5.3) to estimate ∇_θ J(π_θ). Here, J(π_θ) is defined as an expectation whose distribution depends on θ, so it is invalid to swap ∇ and E in computing the gradient; the score function estimator can be used instead. An example of policy gradient is REINFORCE.
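To make the score function estimator concrete, here is a minimal NumPy sketch for the simplest case we could think of, assumed purely for illustration: a single state with three actions, a softmax policy π_θ(a) ∝ exp(θ_a), and known deterministic rewards r(a), so that J(π_θ) = Σ_a π_θ(a) r(a). The Monte Carlo average of ∇_θ log π_θ(a) r(a) should match the exact gradient π_θ ⊙ (r − J). The function names are our own.

```python
import numpy as np

def softmax(theta):
    z = np.exp(theta - theta.max())  # shift for numerical stability
    return z / z.sum()

def exact_grad(theta, r):
    # J(theta) = sum_a pi(a) r(a), so dJ/dtheta_b = pi(b) * (r(b) - J)
    pi = softmax(theta)
    return pi * (r - pi @ r)

def score_function_grad(theta, r, num_samples, rng):
    # Monte Carlo estimate of E_{a ~ pi}[grad_theta log pi(a) * r(a)]
    pi = softmax(theta)
    actions = rng.choice(len(pi), size=num_samples, p=pi)
    scores = np.eye(len(pi))[actions] - pi  # grad log pi(a) = onehot(a) - pi for softmax
    return (scores * r[actions][:, None]).mean(axis=0)

rng = np.random.default_rng(0)
theta = np.array([0.5, -0.2, 0.1])
r = np.array([1.0, 0.0, 2.0])          # hypothetical rewards for the three actions
g_exact = exact_grad(theta, r)
g_mc = score_function_grad(theta, r, num_samples=100_000, rng=rng)
print(g_exact, g_mc)  # should agree up to Monte Carlo error
```

Note that only log π_θ(a) is differentiated; the sampling of a itself is never differentiated, which is exactly why the estimator sidesteps the invalid swap of ∇ and E.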

Policy gradient methods have the advantage that they provably converge to a local optimum for many common policy classes, whereas Q-learning may diverge when approximation is used (Section 35.5.3). In addition, policy gradient methods can easily be applied to continuous action spaces, since they do not need to compute argmax_a Q(s, a). Unfortunately, the score function estimator for ∇_θ J(π_θ) can have very high variance, so the resulting method can converge slowly.

One way to reduce the variance is to learn an approximate value function, V_w(s), and to use it as a baseline in the score function estimator. We can learn V_w(s) using a value function method similar to Q-learning. Alternatively, we can learn an advantage function, A_w(s, a), and use it to estimate the gradient. These policy gradient variants are called actor critic methods, where the actor refers to the policy π_θ and the critic refers to V_w or A_w. See Section 35.3.3 for details.


35.1.4 Model-based RL 


Value-based methods, such as Q-learning, and policy search methods, such as policy gradient, can be 
very sample inefficient, which means they may need to interact with the environment many times 
before finding a good policy. If an agent has prior knowledge of the MDP model, it can be more 
sample efficient to first learn the model, and then compute an optimal (or near-optimal) policy of 
the model without having to interact with the environment any more. 

This approach is called model-based RL. The first step is to learn the MDP model, including the p_T(s'|s, a) and R(s, a) functions, e.g., using DNNs. Given a collection of (s, a, r, s') tuples, such a model can be learned using standard supervised learning methods. The second step can be done by running an RL algorithm on synthetic experiences generated from the model, or by running a planning algorithm on the model directly (Section 34.6). In practice, we often interleave the model learning and planning phases, so we can use the partially learned policy to decide what data to collect. We discuss model-based RL in more detail in Section 35.4.


35.1.5 Exploration-exploitation tradeoff 




A fundamental problem in RL with unknown transition and reward models is to decide between choosing actions that the agent knows will yield high reward, and choosing actions whose reward is uncertain, but which may yield information that helps the agent get to parts of state-action space with even higher reward. This is called the exploration-exploitation tradeoff, which has been discussed in the simpler contextual bandit setting in Section 34.4. The literature on efficient exploration is huge. In this section, we briefly describe several representative techniques.

35.1.5.1 ε-greedy

A common heuristic is to use an ε-greedy policy π_ε, parameterized by ε ∈ [0, 1]. In this case, we pick the greedy action wrt the current model, a_t = argmax_a R̂_t(s_t, a), with probability 1 − ε, and a random action with probability ε. This rule ensures the agent's continual exploration of all state-action combinations. Unfortunately, this heuristic can be shown to be suboptimal, since it explores every action with at least a constant probability ε/|A|.

35.1.5.2 Boltzmann exploration


A source of inefficiency in the ε-greedy rule is that exploration occurs uniformly over all actions. The Boltzmann policy can be more efficient, by assigning higher exploration probabilities to more promising actions:

$$\pi_\tau(a|s) = \frac{\exp(\hat{R}_t(s, a)/\tau)}{\sum_{a'} \exp(\hat{R}_t(s, a')/\tau)} \tag{35.1}$$

where τ > 0 is a temperature parameter that controls how entropic the distribution is. As τ gets close to 0, π_τ becomes close to a greedy policy. On the other hand, higher values of τ make π_τ(a|s) more uniform, and encourage more exploration. Its action selection probabilities can be much "smoother" with respect to changes in the reward estimates than those of ε-greedy, as illustrated in Table 35.2.
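The NumPy sketch below computes the action probabilities of both rules for the reward estimates listed in Table 35.2 (ε = 0.1, τ = 1), so the columns can be compared directly; it reproduces the table up to rounding in the last digit. The function names are our own, and ties in the ε-greedy argmax go to the first action, an arbitrary choice.

```python
import numpy as np

def epsilon_greedy_probs(r_hat, eps):
    # Every action gets eps / |A|; the greedy action also gets the remaining 1 - eps.
    probs = np.full(len(r_hat), eps / len(r_hat))
    probs[np.argmax(r_hat)] += 1.0 - eps
    return probs

def boltzmann_probs(r_hat, tau):
    # Softmax of r_hat / tau (Equation 35.1); subtracting the max avoids overflow.
    z = np.exp((r_hat - r_hat.max()) / tau)
    return z / z.sum()

def sample_action(probs, rng):
    return rng.choice(len(probs), p=probs)

# Reward estimates (R(s, a1), R(s, a2)) for the six states of Table 35.2.
rows = np.array([[1.00, 9.00], [4.00, 6.00], [4.90, 5.10],
                 [5.05, 4.95], [7.00, 3.00], [8.00, 2.00]])
for r_hat in rows:
    print(r_hat, epsilon_greedy_probs(r_hat, eps=0.1).round(2),
          boltzmann_probs(r_hat, tau=1.0).round(2))

a = sample_action(boltzmann_probs(rows[1], tau=1.0), np.random.default_rng(0))
```

The ε-greedy probabilities jump from (0.05, 0.95) to (0.95, 0.05) as soon as the ranking of the two estimates flips, while the Boltzmann probabilities change gradually with the gap between them.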
R(s,a1)   R(s,a2)   π_ε(a1|s)   π_ε(a2|s)   π_τ(a1|s)   π_τ(a2|s)
1.00      9.00      0.05        0.95        0.00        1.00
4.00      6.00      0.05        0.95        0.12        0.88
4.90      5.10      0.05        0.95        0.45        0.55
5.05      4.95      0.95        0.05        0.53        0.48
7.00      3.00      0.95        0.05        0.98        0.02
8.00      2.00      0.95        0.05        1.00        0.00

Table 35.2: Comparison of ε-greedy policy (with ε = 0.1) and Boltzmann policy (with τ = 1) for a simple MDP with 6 states and 2 actions. Adapted from Table 4.1 of [GK19].

35.1.5.3 Upper confidence bounds and Thompson sampling

The upper confidence bound (UCB) (Section 34.4.5) and Thompson sampling (Section 34.4.6) approaches may also be extended to MDPs. In contrast to the contextual bandit case, where the only uncertainty is in the reward function, here we must also take into account uncertainty in the transition probabilities.

As in the bandit case, the UCB approach requires us to estimate an upper confidence bound for all actions' Q-values in the current state, and then take the action with the highest UCB score. One way to obtain UCBs of the Q-values is to use count-based exploration, where we learn the optimal Q-function with an exploration bonus added to the reward in a transition (s, a, r, s'):

$$\tilde{r} = r + \alpha / \sqrt{N_{sa}} \tag{35.2}$$


where N_sa is the number of times action a has been taken in state s, and α > 0 is a weighting term that controls the degree of exploration. This is the approach taken by the MBIE-EB method [SL08] for finite-state MDPs, and by its generalization to continuous-state MDPs through the use of hashing [Bel+16]. Other approaches also explicitly maintain uncertainty in state transition probabilities, and use that information to obtain UCBs. Examples are MBIE [SL08], UCRL2 [JOA10], and UCBVI [AOM17], among many others.
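A minimal sketch of the count-based bonus in Equation (35.2) for a tabular problem is shown below. We assume the current visit is counted before the bonus is computed, so N_sa ≥ 1 and there is never a division by zero; the class name `CountBonus` is hypothetical. The shaped reward would then be passed to an ordinary learner such as Q-learning in place of r.

```python
import math
from collections import defaultdict

class CountBonus:
    """Hypothetical helper: return r + alpha / sqrt(N(s, a)) per transition (Equation 35.2)."""

    def __init__(self, alpha):
        self.alpha = alpha
        self.counts = defaultdict(int)  # N(s, a), keyed by the (state, action) pair

    def __call__(self, s, a, r):
        self.counts[(s, a)] += 1  # count this visit first (assumption), so N(s, a) >= 1
        return r + self.alpha / math.sqrt(self.counts[(s, a)])

bonus = CountBonus(alpha=0.5)
shaped = [bonus("s0", "left", 0.0) for _ in range(4)]
print(shaped)  # the bonus decays like 0.5 / sqrt(n) with the visit count n
```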

Thompson sampling can be similarly adapted, by maintaining the posterior distribution of the reward and transition model parameters. In finite-state MDPs, for example, the transition model is a categorical distribution conditioned on the state and action. We may use the conjugate Dirichlet prior (Section 3.2) for the transition model, so that the posterior distribution can be conveniently computed and sampled from. More details on this approach are found in [Rus+18].
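As a sketch of this idea for a finite MDP, the code below keeps Dirichlet counts for each (s, a), draws one transition model from the posterior, and plans in the sampled model with value iteration. Acting greedily with respect to the resulting Q-values for an episode, updating the counts, and repeating gives a posterior-sampling scheme. We assume a symmetric Dirichlet prior and a known reward table R(s, a) only to keep the example short; the function names are our own.

```python
import numpy as np

def sample_transition_model(counts, prior, rng):
    """Draw T[s, a, :] ~ Dir(counts[s, a, :] + prior) independently for every (s, a)."""
    num_s, num_a, _ = counts.shape
    T = np.empty(counts.shape)
    for s in range(num_s):
        for a in range(num_a):
            T[s, a] = rng.dirichlet(counts[s, a] + prior)
    return T

def value_iteration(T, R, gamma, num_iters=500):
    # Q(s, a) = R(s, a) + gamma * sum_s' T(s'|s, a) * max_a' Q(s', a')
    Q = np.zeros(R.shape)
    for _ in range(num_iters):
        Q = R + gamma * T @ Q.max(axis=1)  # (S, A, S) @ (S,) -> (S, A)
    return Q

rng = np.random.default_rng(0)
counts = np.zeros((2, 2, 2))             # counts[s, a, s'] of observed transitions
counts[0, 1] = [2, 20]                   # action 1 in state 0 has mostly led to state 1
R = np.array([[0.0, 0.0], [1.0, 1.0]])   # assumed known rewards: being in state 1 pays off
T = sample_transition_model(counts, prior=1.0, rng=rng)
policy = value_iteration(T, R, gamma=0.9).argmax(axis=1)  # act greedily in the sampled MDP
# After acting, increment counts[s, a, s'] for each observed transition and resample.
```

Pairs with few observations (here, action 0 in state 0) have broad posteriors, so their sampled transition probabilities vary a lot from draw to draw; this randomness is what drives exploration.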

Both UCB and Thompson sampling methods have been shown to yield efficient exploration with 
provably strong regret bounds (Section 34.4.7) [JOA10], or related PAC bounds [SLL09; DLB17], 
often under necessary assumptions such as finiteness of the MDPs. In practice, these methods may 
be combined with function approximation like neural networks and implemented approximately. 


35.1.5.4 Optimal solution using Bayes-adaptive MDPs 


The Bayes optimal solution to the exploration-exploitation tradeoff can be computed by formulating 
the problem as a special kind of POMDP known as a Bayes-adaptive MDP or BAMDP [Duf02]. 
This extends the Gittins index approach in Section 34.4.4 to the MDP setting. 

In particular, a BAMDP has a belief state space, B, representing uncertainty about the reward model p_R(r|s, a, s') and transition model p_T(s'|s, a). The transition model on this augmented MDP can be written as follows:

$$T^+(s_{t+1}, b_{t+1} \mid s_t, b_t, a_t, r_t) = T^+(s_{t+1} \mid s_t, a_t, b_t)\, T^+(b_{t+1} \mid s_t, a_t, r_t, s_{t+1}) \tag{35.3}$$
$$= \mathbb{E}_{b_t}\left[T(s_{t+1} \mid s_t, a_t)\right] \times \mathbb{I}\left(b_{t+1} = p(R, T \mid h_{t+1})\right) \tag{35.4}$$

where E_{b_t}[T(s_{t+1}|s_t, a_t)] is the posterior predictive distribution over next states, and p(R, T|h_{t+1}) is the new belief state given h_{t+1} = (s_{1:t+1}, a_{1:t+1}, r_{1:t+1}), which can be computed using Bayes rule.


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 




Similarly, the reward function for the augmented MDP is given by

$$R^+(r \mid s_t, b_t, a_t, s_{t+1}, b_{t+1}) = \mathbb{E}_{b_{t+1}}\left[R(s_t, a_t, s_{t+1})\right] \tag{35.5}$$


For small problems, we can solve the resulting augmented MDP optimally. However, in general this is computationally intractable. [Gha+15] surveys many methods to solve it more efficiently. For example, [KN09] develop an algorithm that behaves similarly to Bayes optimal policies, except in a provably small number of steps; [GSD13] propose an approximate method based on Monte Carlo rollouts. More recently, [Zin+20] propose an approximate method based on meta-learning (Section 19.6.4), in which they train a (model-free) policy for multiple related tasks. Each task is represented by a task embedding vector m, which is inferred from h_t using a VAE (Section 21.2). The posterior p(m|h_t) is used as a proxy for the belief state b_t, and the policy is trained to perform well given s_t and b_t. At test time, the policy is applied to the incrementally computed belief state; this allows the method to infer what kind of task this is, and then to use a pre-trained policy to quickly solve it.
35.2 Value-based RL

In this section, we assume the agent has access to samples from p_T and p_R by interacting with the environment. We will show how to use these samples to learn optimal Q-functions from which we can derive optimal policies.

35.2.1 Monte Carlo RL


Recall that Q_π(s, a) = E[G_t | s_t = s, a_t = a] for any t. A simple way to estimate this is to take action a, then sample the rest of the trajectory according to π, and then compute the average sum of discounted rewards. The trajectory ends when we reach a terminal state, if the task is episodic, or when the discount factor γ^t becomes negligibly small, whichever occurs first. This is the Monte Carlo estimation of the value function.
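A minimal sketch of Monte Carlo estimation of Q_π(s, a) from complete episodes: returns are accumulated backwards with G_t = r_t + γG_{t+1}, and each visited (s, a) pair is credited with the return that followed it. We use the every-visit variant (averaging over all occurrences of a pair), which is one of several reasonable choices; the episode format, a list of (s, a, r) triples, is our own assumption.

```python
import numpy as np
from collections import defaultdict

def discounted_returns(rewards, gamma):
    # G_t = r_t + gamma * G_{t+1}, computed backwards from the end of the episode
    G = np.zeros(len(rewards))
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        G[t] = running
    return G

def mc_q_estimate(episodes, gamma):
    """Every-visit Monte Carlo estimate of Q(s, a).

    Each episode is a list of (s, a, r) triples generated by the policy being evaluated.
    """
    totals, counts = defaultdict(float), defaultdict(int)
    for episode in episodes:
        G = discounted_returns([r for (_, _, r) in episode], gamma)
        for (s, a, _), g in zip(episode, G):
            totals[(s, a)] += g
            counts[(s, a)] += 1
    return {sa: float(totals[sa] / counts[sa]) for sa in totals}

episodes = [[("s0", 0, 1.0), ("s1", 1, 0.0)], [("s0", 0, 3.0)]]
Q = mc_q_estimate(episodes, gamma=1.0)  # Q[("s0", 0)] averages the returns 1.0 and 3.0
```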

We can use this technique together with policy iteration (Section 34.6.2) to learn an optimal policy. Specifically, at iteration k, we compute a new, improved policy using π_{k+1}(s) = argmax_a Q_k(s, a), where Q_k is approximated using MC estimation. This update can be applied to all the states visited on the sampled trajectory. This overall technique is called Monte Carlo control.

To ensure this method converges to the optimal policy, we need to collect data for every (state, action) pair, at least in the tabular case, since there is no generalization across different values of Q(s, a). One way to achieve this is to use an ε-greedy policy. Since this is an on-policy algorithm, the resulting method will converge to the optimal ε-soft policy, as opposed to the optimal policy. It is possible to use importance sampling to estimate the value function for the optimal policy, even if actions are chosen according to the ε-greedy policy. However, it is simpler to just gradually reduce ε.

35.2.2 Temporal difference (TD) learning


The Monte Carlo (MC) method in Section 35.2.1 results in an estimator for Q_π(s, a) with very high variance, since it has to unroll many trajectories, whose returns are a sum of many random rewards generated by stochastic state transitions. In addition, it is limited to episodic tasks (or finite horizon truncation of continuing tasks), since it must unroll to the end of the episode before each update step, to ensure it reliably estimates the long-term return.

35.2. VALUE-BASED RL

[Figure 35.2 panels. Monte-Carlo: V(S_t) ← V(S_t) + α(G_t − V(S_t)). Temporal-Difference: V(S_t) ← V(S_t) + α(R_{t+1} + γV(S_{t+1}) − V(S_t)). Dynamic Programming: V(S_t) ← E_π[R_{t+1} + γV(S_{t+1})].]
Figure 35.2: Backup diagrams of V(s_t) for Monte Carlo, temporal difference, and dynamic programming updates of the state-value function. Used with kind permission of Andy Barto.

In this section, we discuss a more efficient technique called temporal difference or TD learning [Sut88]. The basic idea is to incrementally reduce the Bellman error for sampled states or state-action pairs, based on individual transitions instead of a long trajectory. More precisely, suppose we want to learn the value function V_π for a fixed policy π. Given a state transition (s, a, r, s') where a ∼ π(s), we change the estimate V(s) so that it moves toward the bootstrapping target (Section 35.1.2)

$$V(s_t) \leftarrow V(s_t) + \eta\left[r_t + \gamma V(s_{t+1}) - V(s_t)\right] \tag{35.6}$$

where η is the learning rate. The term multiplied by η above is known as the TD error. A more general form of TD update for parametric value function representations is

$$\mathbf{w} \leftarrow \mathbf{w} + \eta\left[r_t + \gamma V_{\mathbf{w}}(s_{t+1}) - V_{\mathbf{w}}(s_t)\right] \nabla_{\mathbf{w}} V_{\mathbf{w}}(s_t) \tag{35.7}$$

of which Equation (35.6) is a special case. The TD update rule for learning Q_π is similar.
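A minimal sketch of tabular TD(0), Equation (35.6), evaluating the uniform random policy on a small random walk chosen by us for illustration: five non-terminal states, each episode starts in the middle, exiting on the right pays 1 and exiting on the left pays 0, with γ = 1. For this walk the true values are 1/6, 2/6, ..., 5/6, which gives something to compare against; the step size and episode count are arbitrary.

```python
import numpy as np

def td0_random_walk(num_episodes, eta, rng, num_states=5):
    """Tabular TD(0) on a hypothetical random walk (gamma = 1).

    Non-terminal states are 0..num_states-1; each step moves left or right
    with probability 1/2. Exiting on the right gives reward 1, on the left 0.
    """
    V = np.full(num_states, 0.5)
    for _ in range(num_episodes):
        s = num_states // 2
        done = False
        while not done:
            s_next = s + (1 if rng.random() < 0.5 else -1)
            done = s_next < 0 or s_next >= num_states
            r = 1.0 if s_next >= num_states else 0.0
            v_next = 0.0 if done else V[s_next]   # V(terminal) = 0
            V[s] += eta * (r + v_next - V[s])     # move V(s) toward the bootstrap target
            s = s_next
    return V

V = td0_random_walk(num_episodes=5000, eta=0.01, rng=np.random.default_rng(0))
true_V = np.arange(1, 6) / 6   # exact values for this walk
print(np.round(V, 3), np.round(true_V, 3))
```

Each update uses only one transition, so learning starts before any episode has finished; with a constant step size the estimates keep fluctuating slightly around the true values.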

It can be shown that TD learning in the tabular case, Equation (35.6), converges to the correct 
value function, under proper conditions [Ber19]. However, it may diverge when approximation is 
used (Equation (35.7)), an issue we will discuss further in Section 35.5.3. 

The potential divergence of TD is also consistent with the fact that Equation (35.7) is not SGD (Section 6.3) on any objective function, despite its very similar form. Instead, it is an example of bootstrapping, in which the estimate, V_w(s_t), is updated to approach a target, r_t + γV_w(s_{t+1}), which is defined by the value function estimate itself. This idea is shared by DP methods like value iteration, although they rely on the complete MDP model to compute an exact Bellman backup. In contrast, TD learning can be viewed as using sampled transitions to approximate such backups. An example of a non-bootstrapping approach is the Monte Carlo estimation in the previous section. It samples a complete trajectory, rather than individual transitions, to perform an update, and is often much less efficient. Figure 35.2 illustrates the difference between MC, TD, and DP.


35.2.3 TD learning with eligibility traces 


Figure 35.3: The backup diagram for TD(λ). Standard TD learning corresponds to λ = 0, and standard MC learning corresponds to λ = 1. From Figure 12.1 of [SB18]. Used with kind permission of Richard Sutton.

A key difference between TD and MC is the way they estimate returns. Given a trajectory τ = (s_0, a_0, r_0, s_1, ..., s_T), TD estimates the return from state s_t by one-step lookahead, G_{t:t+1} = r_t + γV(s_{t+1}), where the return from time t + 1 is replaced by its value function estimate. In contrast, MC waits until the end of the episode or until T is large enough, then uses the estimate G_{t:T} = r_t + γr_{t+1} + ⋯ + γ^{T−t−1} r_{T−1}. It is possible to interpolate between these by performing an n-step rollout, and then using the value function to approximate the return for the rest of the trajectory, similar to heuristic search (Section 35.4.1.1). That is, we can use the n-step estimate


$$G_{t:t+n} = r_t + \gamma r_{t+1} + \cdots + \gamma^{n-1} r_{t+n-1} + \gamma^n V(s_{t+n}) \tag{35.8}$$

The corresponding n-step version of the TD update becomes

$$V(s_t) \leftarrow V(s_t) + \eta\left[G_{t:t+n} - V(s_t)\right] \tag{35.9}$$
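A small sketch of Equations (35.8) and (35.9) for a recorded episode. We assume rewards[k] holds r_k and values[k] holds the current estimate V(s_k), with values[T] = 0 for the terminal state. When t + n runs past the end of the episode, the sum is truncated and no bootstrap term is added, so a large enough n recovers the Monte Carlo return. The array layout and function name are our own.

```python
import numpy as np

def n_step_return(rewards, values, t, n, gamma):
    """G_{t:t+n} from Equation (35.8); values[T] must be 0 for a terminal s_T."""
    T = len(rewards)
    G = sum(gamma ** (k - t) * rewards[k] for k in range(t, min(t + n, T)))
    if t + n < T:
        G += gamma ** n * values[t + n]  # bootstrap from the value estimate
    return G

rewards = [1.0, 2.0, 3.0]
values = np.array([10.0, 20.0, 30.0, 0.0])
G = n_step_return(rewards, values, t=0, n=2, gamma=0.5)  # 1 + 0.5*2 + 0.25*30 = 9.5
eta = 0.1
values[0] += eta * (G - values[0])  # n-step TD update, Equation (35.9)
```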
Rather than picking a specific lookahead value, n, we can take a weighted average of all possible values, with a single parameter λ ∈ [0, 1], by using

$$G_t^{\lambda} \triangleq (1 - \lambda) \sum_{n=1}^{\infty} \lambda^{n-1} G_{t:t+n} \tag{35.10}$$

This is called the λ-return. The coefficient 1 − λ = (1 + λ + λ^2 + ⋯)^{-1} in front ensures this is a convex combination of n-step returns. See Figure 35.3 for an illustration.
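For a finite episode, every n-step return with t + n ≥ T equals the Monte Carlo return, so the infinite sum in Equation (35.10) collapses to a finite one whose last weight is λ^{T−t−1}. The sketch below, assuming the array conventions rewards[k] = r_k and values[k] = V(s_k) with values[T] = 0, computes the λ-return that way and checks it against the equivalent backward recursion G^λ_t = r_t + γ[(1 − λ)V(s_{t+1}) + λG^λ_{t+1}], which follows by splitting off r_t from each n-step return. This is an offline illustration, not the eligibility-trace implementation discussed next.

```python
import numpy as np

def n_step_return(rewards, values, t, n, gamma):
    T = len(rewards)
    G = sum(gamma ** (k - t) * rewards[k] for k in range(t, min(t + n, T)))
    if t + n < T:
        G += gamma ** n * values[t + n]
    return G

def lambda_return(rewards, values, t, lam, gamma):
    # Equation (35.10) for an episode ending at T: n-step returns with
    # n >= T - t all equal the MC return, and their weights sum to lam**(T - t - 1).
    T = len(rewards)
    head = sum((1 - lam) * lam ** (n - 1) * n_step_return(rewards, values, t, n, gamma)
               for n in range(1, T - t))
    return head + lam ** (T - t - 1) * n_step_return(rewards, values, t, T - t, gamma)

def lambda_returns_recursive(rewards, values, lam, gamma):
    # G^lam_t = r_t + gamma * ((1 - lam) * V(s_{t+1}) + lam * G^lam_{t+1}), with G^lam_T = 0
    T = len(rewards)
    G = np.zeros(T + 1)
    for t in reversed(range(T)):
        G[t] = rewards[t] + gamma * ((1 - lam) * values[t + 1] + lam * G[t + 1])
    return G[:T]

rewards = [1.0, 0.0, 2.0, -1.0]
values = np.array([0.5, 0.2, -0.3, 0.8, 0.0])  # values[T] = 0 for the terminal state
direct = [lambda_return(rewards, values, t, lam=0.7, gamma=0.9) for t in range(4)]
print(np.allclose(direct, lambda_returns_recursive(rewards, values, lam=0.7, gamma=0.9)))
```

Setting λ = 0 recovers the one-step TD target and λ = 1 the Monte Carlo return, matching the two extremes in Figure 35.3.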

An important benefit of using the geometric weighting in Equation (35.10) is that the corresponding TD learning update can be efficiently implemented, through the use of eligibility traces, even though G_t^λ is a sum of infinitely many terms. The method is called TD(λ), and can be combined with many algorithms to be studied in the rest of the chapter. See [SB18] for a detailed discussion.


35.2.4 SARSA: on-policy TD control 


TD learning is for policy evaluation, as it estimates the value function for a fixed policy. In order to find an optimal policy, we may use the algorithm as a building block inside generalized policy iteration (Section 34.6.2). In this case, it is more convenient to work with the action-value function, Q, and a policy π that is greedy with respect to Q. The agent follows π in every step to choose actions, and upon a transition (s, a, r, s') the TD update rule is

$$Q(s, a) \leftarrow Q(s, a) + \eta\left[r + \gamma Q(s', a') - Q(s, a)\right] \tag{35.11}$$

where a' ∼ π(s') is the action the agent will take in state s'. After Q is updated (for policy evaluation), π also changes accordingly, as it is greedy with respect to Q (for policy improvement). This algorithm, first proposed by [RN94], was further studied and renamed to SARSA by [Sut96]; the name comes from its update rule, which involves an augmented transition (s, a, r, s', a').

In order for SARSA to converge to Q_*, every state-action pair must be visited infinitely often, at least in the tabular case, since the algorithm only updates Q(s, a) for the (s, a) pairs it visits. One way to ensure this condition is to use a "greedy in the limit with infinite exploration" (GLIE) policy. An example is the ε-greedy policy, with ε vanishing to 0 gradually. It can be shown that SARSA with a GLIE policy will converge to Q_* and π_* [Sin+00].
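A minimal tabular sketch of SARSA with a GLIE schedule is given below. It uses a hypothetical 5-state corridor introduced only for illustration: action 1 moves right, action 0 moves left (state 0 is a wall), and reaching state 4 pays 1 and ends the episode. The schedule ε_k = 1/(k + 1) for episode k decays to 0 while its sum diverges, which is one common way to meet the GLIE conditions. The step size, episode count, and per-episode step cap are arbitrary choices.

```python
import numpy as np

N_STATES, GAMMA = 5, 0.9  # hypothetical corridor; state 4 is the goal

def step(s, a):
    # Action 1 moves right, action 0 moves left (state 0 is a wall).
    s_next = min(s + 1, N_STATES - 1) if a == 1 else max(s - 1, 0)
    done = s_next == N_STATES - 1
    return s_next, (1.0 if done else 0.0), done

def eps_greedy(q, eps, rng):
    if rng.random() < eps:
        return int(rng.integers(len(q)))
    return int(rng.choice(np.flatnonzero(q == q.max())))  # greedy, random tie-breaking

def sarsa(num_episodes, eta, rng, max_steps=200):
    Q = np.zeros((N_STATES, 2))
    for k in range(num_episodes):
        eps = 1.0 / (k + 1)                              # GLIE schedule
        s = 0
        a = eps_greedy(Q[s], eps, rng)
        for _ in range(max_steps):
            s_next, r, done = step(s, a)
            a_next = eps_greedy(Q[s_next], eps, rng)     # the action actually taken next
            target = r if done else r + GAMMA * Q[s_next, a_next]
            Q[s, a] += eta * (target - Q[s, a])          # Equation (35.11)
            if done:
                break
            s, a = s_next, a_next
    return Q

Q = sarsa(num_episodes=500, eta=0.5, rng=np.random.default_rng(0))
print(Q[:-1].argmax(axis=1))  # greedy action in each non-goal state
```

The target uses Q(s', a') for the a' that the current policy actually samples, which is what makes the method on-policy.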


35.2.5 Q-learning: off-policy TD control 


SARSA is an on-policy algorithm, which means it learns the Q-function for the policy it is currently using, which is typically not the optimal policy (except in the limit for a GLIE policy). However, with a simple modification, we can convert this to an off-policy algorithm that learns Q_*, even if a suboptimal policy is used to choose actions.

The idea is to replace the sampled next action a' ∼ π(s') in Equation (35.11) with a greedy action in s': a' = argmax_b Q(s', b). This results in the following update when a transition (s, a, r, s') happens

$$Q(s, a) \leftarrow Q(s, a) + \eta\left[r + \gamma \max_b Q(s', b) - Q(s, a)\right] \tag{35.12}$$

This is the update rule of Q-learning for the tabular case [WD92]. The extension to function approximation can be done in a way similar to Equation (35.7). Since it is off-policy, the method can use (s, a, r, s') tuples coming from any data source, such as older versions of the policy, or log data from an existing (non-RL) system. If every state-action pair is visited infinitely often, the algorithm provably converges to Q_* in the tabular case, with properly decayed learning rates [Ber19]. Algorithm 46 gives a vanilla implementation of Q-learning with ε-greedy exploration.
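The sketch below is a tabular version of Algorithm 46: with one-hot features, the update of w reduces to updating a single table entry. It uses a hypothetical 5-state corridor (action 1 moves right, action 0 moves left with state 0 acting as a wall, and reaching state 4 pays 1 and ends the episode). Because Q-learning is off-policy, the behavior policy keeps a fixed ε = 0.1, yet the table approaches Q_*, which for this corridor satisfies Q_*(s, right) = γ^{3−s}. The environment and hyperparameters are assumptions made for the example.

```python
import numpy as np

N_STATES, GAMMA = 5, 0.9  # hypothetical corridor; state 4 is the goal

def step(s, a):
    # Action 1 moves right, action 0 moves left (state 0 is a wall).
    s_next = min(s + 1, N_STATES - 1) if a == 1 else max(s - 1, 0)
    done = s_next == N_STATES - 1
    return s_next, (1.0 if done else 0.0), done

def q_learning(num_episodes, eta, eps, rng, max_steps=200):
    Q = np.zeros((N_STATES, 2))
    for _ in range(num_episodes):
        s = 0
        for _ in range(max_steps):
            if rng.random() < eps:
                a = int(rng.integers(2))                                 # explore
            else:
                a = int(rng.choice(np.flatnonzero(Q[s] == Q[s].max())))  # greedy, random ties
            s_next, r, done = step(s, a)
            # TD error with the greedy (max) backup, Equation (35.12)
            delta = (r if done else r + GAMMA * Q[s_next].max()) - Q[s, a]
            Q[s, a] += eta * delta
            if done:
                break
            s = s_next
    return Q

Q = q_learning(num_episodes=500, eta=0.5, eps=0.1, rng=np.random.default_rng(0))
print(np.round(Q[:-1, 1], 3))  # should approach gamma ** (3 - s) for s = 0..3
```

Replacing `Q[s_next].max()` with the value of the action actually taken next would turn this into SARSA, whose estimates would then reflect the exploratory behavior policy rather than the greedy one.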


35.2.5.1 Example 


Figure 35.4 gives an example of Q-learning applied to the simple 1d grid world from Figure 34.11, using γ = 0.9. We show the Q-function at the start and end of each episode, after performing actions chosen by an ε-greedy policy. We initialize Q(s, a) = 0 for all entries, and use a step size of η = 1, so the update becomes Q(s, a) ← r + γQ(s', a_*), where a_* = ↓ for all states.


35.2.5.2 Double Q-learning 


Standard Q-learning suffers from a problem known as the optimizer's curse [SW06], or the maximization bias. The problem refers to the simple statistical inequality, E[max_a X_a] ≥ max_a E[X_a],
Algorithm 46: Q-learning with ε-greedy exploration
Initialize value function parameters w
repeat
    Sample starting state s of new episode
    repeat
        Sample action a = argmax_b Q_w(s, b) with probability 1 − ε, or a random action with probability ε
        Observe state s', reward r
        Compute the TD error: δ = r + γ max_{a'} Q_w(s', a') − Q_w(s, a)
        w ← w + η δ ∇_w Q_w(s, a)
        s ← s'
    until state s is terminal
until converged


scores {X_a}, we might pick a wrong action just because random noise makes it appealing.
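The inequality is easy to see numerically. In the sketch below (all numbers are illustrative assumptions), ten actions share the same true value −0.1 and each value is estimated from five noisy N(−0.1, 1) samples. Taking the maximum of the estimates is biased upward, whereas choosing the action with one set of estimates and evaluating it with an independent second set, the idea behind double Q-learning, is unbiased.

```python
import numpy as np

rng = np.random.default_rng(0)
num_trials, num_actions, num_samples = 20_000, 10, 5
true_mean = -0.1  # every action has the same true value, so max_a E[X_a] = -0.1

# samples[i, a, k]: k-th noisy reward for action a in trial i, drawn from N(-0.1, 1)
samples = rng.normal(true_mean, 1.0, size=(num_trials, num_actions, 2 * num_samples))
est1 = samples[:, :, :num_samples].mean(axis=2)   # first independent estimate of each value
est2 = samples[:, :, num_samples:].mean(axis=2)   # second independent estimate

# Single estimator: select and evaluate with the same noisy estimates.
single = est1.max(axis=1).mean()
# Double estimator: select with est1, evaluate the selected action with est2.
chosen = est1.argmax(axis=1)
double = est2[np.arange(num_trials), chosen].mean()
print(f"single: {single:.3f}, double: {double:.3f}, true: {true_mean}")
```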

Figure 35.5 gives a simple example of how this can happen in an MDP. The start state is A. The right action gives a reward of 0 and terminates the episode. The left action also gives a reward of 0, but then enters state B, from which there are many possible actions, with rewards drawn from N(−0.1, 1.0). Thus the expected return for any trajectory starting with the left action is −0.1, making it suboptimal. Nevertheless, the RL algorithm may pick the left action due to the maximization bias making B appear to have a positive value.

One solution to avoid the maximization bias is to use two separate Q-functions, Q_1 and Q_2, one for selecting the greedy action, and the other for estimating the corresponding Q-value. In particular, upon seeing a transition (s, a, r, s'), we perform the following update

$$Q_1(s, a) \leftarrow Q_1(s, a) + \eta\left[r + \gamma Q_2\left(s', \arg\max_{a'} Q_1(s', a')\right) - Q_1(s, a)\right] \tag{35.13}$$

and may repeat the same update but with the roles of Q_1 and Q_2 swapped. This technique is called double Q-learning [Has10]. Figure 35.5 shows the benefits of the algorithm over standard Q-learning in a toy problem.
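The following sketch runs tabular Q-learning and double Q-learning on the MDP of Figure 35.5 and reports how often each chooses left in state A. Details not given in the text are our assumptions: B has 10 actions, η = 0.1, γ = 1, ε = 0.1, and the behavior policy of double Q-learning is ε-greedy with respect to Q_1 + Q_2. Swapping the two tables at random before each update implements "the same update with the roles of Q_1 and Q_2 swapped" with probability 1/2.

```python
import numpy as np

N_B_ACTIONS = 10  # assumed number of actions in state B

def greedy(q, rng):
    return int(rng.choice(np.flatnonzero(q == q.max())))  # argmax with random tie-breaking

def run(double, num_episodes, rng, eps=0.1, eta=0.1):
    """Returns a 0/1 array: did the agent choose left in state A in each episode?"""
    # State 0 is A (action 0 = right, action 1 = left); state 1 is B. Gamma = 1.
    Q1 = [np.zeros(2), np.zeros(N_B_ACTIONS)]
    Q2 = [np.zeros(2), np.zeros(N_B_ACTIONS)]
    went_left = np.zeros(num_episodes)
    for ep in range(num_episodes):
        s = 0
        while s is not None:                      # None marks the terminal state
            q = Q1[s] + Q2[s] if double else Q1[s]
            a = int(rng.integers(len(q))) if rng.random() < eps else greedy(q, rng)
            if s == 0:
                went_left[ep] = float(a == 1)
                r, s_next = 0.0, (1 if a == 1 else None)
            else:
                r, s_next = rng.normal(-0.1, 1.0), None
            if double and rng.random() < 0.5:
                Q1, Q2 = Q2, Q1                   # update the other table instead
            if s_next is None:
                target = r
            elif double:
                target = r + Q2[s_next][greedy(Q1[s_next], rng)]   # Equation (35.13)
            else:
                target = r + Q1[s_next].max()                       # Equation (35.12)
            Q1[s][a] += eta * (target - Q1[s][a])
            s = s_next
    return went_left

rng = np.random.default_rng(0)
left_q = np.mean([run(False, 100, rng) for _ in range(200)])
left_dq = np.mean([run(True, 100, rng) for _ in range(200)])
print(f"P(left in A): Q-learning {left_q:.2f}, double Q-learning {left_dq:.2f}")
```

Under ε-greedy with ε = 0.1, an agent that has learned the optimal policy still goes left with probability 0.05, so that is the floor both methods can approach.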
35.2.6 Deep Q-network (DQN)


When function approximation is used, Q-learning may be hard to use in practice due to instability problems. Here, we describe two important heuristics, popularized by the deep Q-network or DQN work [Mni+15], which was able to train agents to outperform humans at playing Atari games, using CNN-structured Q-networks.

The first technique, originally proposed in [Lin92], is to leverage an experience replay buffer, which stores the most recent (s, a, r, s') transition tuples. In contrast to standard Q-learning, which updates the Q-function when a new transition occurs, the DQN agent also performs additional updates using transitions sampled from the buffer. This modification has two advantages. First, it improves data efficiency, as every transition can be used multiple times. Second, it improves stability in training, by reducing the correlation of the data samples that the network is trained on.

Figure 35.4: Illustration of Q-learning for the 1d grid world in Figure 34.11 using ε-greedy exploration. At the end of episode 1, we make a transition from S3 to S_T2 and get a reward of r = 1, so we estimate Q(S3, ↓) = 1. In episode 2, we make a transition from S2 to S3, so S2 gets incremented by γQ(S3, ↓) = 0.9. Adapted from Figure 3.3 of [GK19].


The second idea to improve stability is to regress the Q-network to a "frozen" target network computed at an earlier iteration, rather than trying to chase a constantly moving target. Specifically, we maintain an extra, frozen copy of the Q-network, Q_{w^-}, of the same structure as Q_w. This new Q-network is used to compute bootstrapping targets for training Q_w, in which the loss function is

$$\mathcal{L}(\mathbf{w}) = \mathbb{E}_{(s, a, r, s') \sim U(\mathcal{D})}\left[\left(r + \gamma \max_{a'} Q_{\mathbf{w}^-}(s', a') - Q_{\mathbf{w}}(s, a)\right)^2\right] \tag{35.14}$$
where U(D) is a uniform distribution over the replay buffer D. We then periodically set w^- ← w, usually after a few episodes. This approach is an instance of fitted value iteration [SB18].

Figure 35.5: Comparison of Q-learning and double Q-learning on a simple episodic MDP using ε-greedy action selection with ε = 0.1. The initial state is A, and squares denote absorbing states. The data are averaged over 10,000 runs. From Figure 6.5 of [SB18]. Used with kind permission of Rich Sutton.
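Below is a NumPy sketch of the two DQN ingredients under simplifying assumptions of our own: a linear Q-function Q_w(s, a) = w_a^T s in place of a CNN, random synthetic transitions in place of environment interaction, and a hand-derived gradient. The replay buffer samples minibatches uniformly, as in U(D); the loss is a minibatch estimate of Equation (35.14), with targets computed from the frozen weights W_target and treated as constants when differentiating; and W_target is refreshed periodically. Class and function names are hypothetical.

```python
import numpy as np
from collections import deque

class ReplayBuffer:
    """Hypothetical fixed-capacity FIFO store of (s, a, r, s_next, done) tuples."""

    def __init__(self, capacity, rng):
        self.data = deque(maxlen=capacity)  # oldest transitions are dropped first
        self.rng = rng

    def add(self, transition):
        self.data.append(transition)

    def sample(self, batch_size):
        idx = self.rng.integers(len(self.data), size=batch_size)  # uniform, with replacement
        return [self.data[int(i)] for i in idx]

def dqn_loss_and_grad(W, W_target, batch, gamma):
    """Minibatch estimate of Equation (35.14) for Q_W(s, a) = W[a] @ s, and its gradient."""
    loss, grad = 0.0, np.zeros_like(W)
    for s, a, r, s_next, done in batch:
        target = r if done else r + gamma * np.max(W_target @ s_next)  # frozen network
        err = target - W[a] @ s
        loss += err ** 2
        grad[a] -= 2.0 * err * s  # derivative of err**2 wrt W[a], target held fixed
    return loss / len(batch), grad / len(batch)

rng = np.random.default_rng(0)
num_actions, dim, gamma = 3, 4, 0.99
buffer = ReplayBuffer(capacity=1000, rng=rng)
for _ in range(200):  # synthetic transitions stand in for interaction with an environment
    buffer.add((rng.normal(size=dim), int(rng.integers(num_actions)), rng.normal(),
                rng.normal(size=dim), bool(rng.random() < 0.1)))

W = rng.normal(scale=0.1, size=(num_actions, dim))
W_target = W.copy()
for step in range(1, 501):
    loss, grad = dqn_loss_and_grad(W, W_target, buffer.sample(32), gamma)
    W -= 0.01 * grad
    if step % 100 == 0:
        W_target = W.copy()  # periodically set w^- <- w

# Sanity check: finite-difference derivative of the loss wrt one weight.
batch = buffer.sample(8)
l0, g = dqn_loss_and_grad(W, W_target, batch, gamma)
W_pert = W.copy()
W_pert[1, 2] += 1e-6
fd = (dqn_loss_and_grad(W_pert, W_target, batch, gamma)[0] - l0) / 1e-6
```

Because the target depends only on W_target, each phase between refreshes is an ordinary least-squares regression problem, which is the sense in which the frozen network avoids "chasing a moving target".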
Various improvements to DQN have been proposed. One is double DQN [HGS16], which uses the double learning technique (Section 35.2.5.2) to remove the maximization bias. The second is to replace the uniform distribution in Equation (35.14) with one that favors more important transition tuples, resulting in the use of prioritized experience replay [Sch+16a]. For example, we can sample transitions from D with probability p(s, a, r, s') ∝ (|δ| + ε)^η, where δ is the corresponding TD error (under the current Q-function), ε > 0 is a hyperparameter that ensures every experience is chosen with nonzero probability, and η ≥ 0 controls the "inverse temperature" of the distribution (so η = 0 corresponds to uniform sampling). The third is to learn a value function V_w and an advantage function A_w, with shared parameters w, instead of learning Q_w. The resulting dueling DQN [Wan+16] is shown to be more sample efficient, especially when there are many actions with similar Q-values.
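A sketch of the prioritized sampling rule described above; the default values of ε and η are illustrative assumptions, not recommendations from the text. Because the rule uses |δ|, large negative TD errors are prioritized as much as large positive ones. Practical implementations typically use a sum-tree so that sampling does not require renormalizing over the whole buffer, and [Sch+16a] also corrects the resulting bias with importance weights; neither refinement is shown here.

```python
import numpy as np

def priority_probs(td_errors, eps=0.01, eta=0.6):
    """Sampling probabilities p_i proportional to (|delta_i| + eps) ** eta."""
    p = (np.abs(td_errors) + eps) ** eta
    return p / p.sum()

rng = np.random.default_rng(0)
td_errors = np.array([0.0, 0.5, -2.0, 0.1])   # hypothetical TD errors under the current Q-function
probs = priority_probs(td_errors)
batch_idx = rng.choice(len(td_errors), size=3, p=probs)  # indices into the replay buffer
print(probs.round(3), batch_idx)
```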

The rainbow method [Hes+18] combines all three improvements, as well as others, including multi-step returns (Section 35.2.3), distributional RL [BDM17] (which predicts the distribution of returns, not just the expected return), and noisy nets [For+18b] (which add random noise to the network weights to encourage exploration). It produces state-of-the-art results on the Atari benchmark.

35.3 Policy-based RL

In the previous section, we considered methods that estimate the action-value function, Q(s, a), from which we derive a policy, which may be greedy or softmax. However, these methods have three main disadvantages: (1) they can be difficult to apply to continuous action spaces; (2) they may diverge if function approximation is used; and (3) the training of Q, often based on TD-style updates, is not directly related to the expected return garnered by the learned policy.

In this section, we discuss policy search methods, which directly optimize the parameters of the policy so as to maximize its expected return. However, we will see that these methods often benefit from estimating a value or advantage function to reduce the variance in the policy search process.

35.3. POLICY-BASED RL


35.3.1 The policy gradient theorem 


We start by defining the objective function for policy learning, and then derive its gradient. We 
consider the episodic case. A similar result can be derived for the continuing case with the average 
reward criterion [SB18, Sec 13.6]. 

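We define the objective to be the expected return of a policy, which we aim to maximize:

$$J(\pi) \triangleq \mathbb{E}_{p_0, \pi}\left[G_0\right] = \mathbb{E}_{p_0(s_0)}\left[V_\pi(s_0)\right] = \mathbb{E}_{p_0(s_0)\pi(a_0|s_0)}\left[Q_\pi(s_0, a_0)\right] \tag{35.15}$$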


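We consider policies π_θ parameterized by θ, and compute the gradient of Equation (35.15) wrt θ:

$$\nabla_\theta J(\pi_\theta) = \mathbb{E}_{p_0(s_0)}\left[\nabla_\theta\left(\sum_{a_0} \pi_\theta(a_0|s_0) Q_{\pi_\theta}(s_0, a_0)\right)\right] \tag{35.16}$$
$$= \mathbb{E}_{p_0(s_0)}\left[\sum_{a_0} \nabla_\theta \pi_\theta(a_0|s_0) Q_{\pi_\theta}(s_0, a_0)\right] + \mathbb{E}_{p_0(s_0)\pi_\theta(a_0|s_0)}\left[\nabla_\theta Q_{\pi_\theta}(s_0, a_0)\right] \tag{35.17}$$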


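Now we calculate the term ∇_θ Q_{π_θ}(s_0, a_0):

$$\nabla_\theta Q_{\pi_\theta}(s_0, a_0) = \nabla_\theta\left[R(s_0, a_0) + \gamma\, \mathbb{E}_{p_T(s_1|s_0, a_0)}\left[V_{\pi_\theta}(s_1)\right]\right] = \gamma \nabla_\theta \mathbb{E}_{p_T(s_1|s_0, a_0)}\left[V_{\pi_\theta}(s_1)\right] \tag{35.18}$$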


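The right-hand side above is in a form similar to ∇_θ J(π_θ). Repeating the same steps as before gives

$$\nabla_\theta J(\pi_\theta) = \sum_{t=0}^{\infty} \gamma^t\, \mathbb{E}_{p_t(s)}\left[\sum_a \nabla_\theta \pi_\theta(a|s) Q_{\pi_\theta}(s, a)\right] \tag{35.19}$$
$$= \frac{1}{1-\gamma}\, \mathbb{E}_{p^\gamma_{\pi_\theta}(s)}\left[\sum_a \nabla_\theta \pi_\theta(a|s) Q_{\pi_\theta}(s, a)\right] \tag{35.20}$$
$$= \frac{1}{1-\gamma}\, \mathbb{E}_{p^\gamma_{\pi_\theta}(s)\pi_\theta(a|s)}\left[\nabla_\theta \log \pi_\theta(a|s)\, Q_{\pi_\theta}(s, a)\right] \tag{35.21}$$

where p_t(s) is the probability of visiting s at time t if we start with s_0 ∼ p_0 and follow π_θ, and p^γ_{π_θ}(s) = (1 − γ) Σ_{t=0}^∞ γ^t p_t(s) is the normalized discounted state visitation distribution. Equation (35.21) is known as the policy gradient theorem [Sut+99].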

In practice, estimating the policy gradient using Equation (35.21) can have high variance. A baseline b(s) can be used for variance reduction (Section 6.5.3.1):

$$\nabla_\theta J(\pi_\theta) = \frac{1}{1-\gamma}\, \mathbb{E}_{p^\gamma_{\pi_\theta}(s)\pi_\theta(a|s)}\left[\nabla_\theta \log \pi_\theta(a|s)\left(Q_{\pi_\theta}(s, a) - b(s)\right)\right] \tag{35.22}$$

A common choice for the baseline is b(s) = V_{π_θ}(s). We will discuss how to estimate it below.
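A quick numerical check of why the baseline helps, using a single state, a uniform softmax policy over three actions, and made-up Q-values that share a large common offset (all assumptions for illustration). Subtracting b(s) = V(s) leaves the mean of the estimator unchanged, because E_π[∇_θ log π_θ(a|s)] = 0, but it shrinks the per-sample variance substantially.

```python
import numpy as np

rng = np.random.default_rng(0)
pi = np.full(3, 1 / 3)                 # softmax policy with theta = 0
q = np.array([10.0, 11.0, 12.0])       # hypothetical Q(s, a) with a large common offset
v = pi @ q                             # baseline b(s) = V(s) = E_pi[Q(s, a)] = 11

actions = rng.choice(3, size=200_000, p=pi)
scores = np.eye(3)[actions] - pi                 # grad_theta log pi(a|s) for a softmax policy
g_plain = scores * q[actions][:, None]           # per-sample estimates, no baseline
g_base = scores * (q[actions] - v)[:, None]      # per-sample estimates, baseline V(s)

exact = pi * (q - v)                   # exact gradient of sum_a pi(a) Q(s, a)
print(g_plain.mean(axis=0), g_base.mean(axis=0), exact)
print(g_plain.var(axis=0).sum(), g_base.var(axis=0).sum())
```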
35.3.2 REINFORCE

One way to apply the policy gradient theorem to optimize a policy is to use stochastic gradient ascent. Suppose τ = (s_0, a_0, r_0, s_1, ..., s_T) is a trajectory with s_0 ∼ p_0 and a_t ∼ π_θ(·|s_t). Then,

$$\nabla_\theta J(\pi_\theta) = \frac{1}{1-\gamma}\, \mathbb{E}_{p^\gamma_{\pi_\theta}(s)\pi_\theta(a|s)}\left[\nabla_\theta \log \pi_\theta(a|s)\, Q_{\pi_\theta}(s, a)\right] \tag{35.23}$$
$$\approx \sum_{t=0}^{T-1} \gamma^t G_t \nabla_\theta \log \pi_\theta(a_t|s_t) \tag{35.24}$$

where the return G_t is defined in Equation (34.64), and the factor γ^t is due to the definition of p^γ_{π_θ}, where the state at time t is discounted.

We can use a baseline in the gradient estimate to get the following update rule:

$$\theta \leftarrow \theta + \eta \sum_{t=0}^{T-1} \gamma^t \left(G_t - b(s_t)\right) \nabla_\theta \log \pi_\theta(a_t|s_t) \tag{35.25}$$


This is called the REINFORCE algorithm [Wil92].³ The update equation can be interpreted as follows: we compute the sum of discounted future rewards induced by a trajectory and compare it to a baseline; if the difference is positive, we change θ so as to make this trajectory more likely, otherwise we change θ to make it less likely. Thus, we reinforce good behaviors, and reduce the chances of generating bad ones.

We can use a constant (state-independent) baseline, or we can use a state-dependent baseline, b(s_t), to further lower the variance. A natural choice is to use an estimated value function, V_w(s), which can be learned, e.g., with MC. Algorithm 47 gives the pseudocode, where stochastic gradient updates are used with separate learning rates.

Algorithm 47: REINFORCE with value function baseline
Initialize policy parameters θ, baseline parameters w
repeat
    Sample an episode τ = (s_0, a_0, r_0, s_1, ..., s_T) using π_θ
    Compute G_t for all t ∈ {0, 1, ..., T − 1} using Equation (34.64)
    for t = 0, 1, ..., T − 1 do
        δ = G_t − V_w(s_t)  // scalar error
        w ← w + η_w δ ∇_w V_w(s_t)
        θ ← θ + η_θ γ^t δ ∇_θ log π_θ(a_t|s_t)
until converged
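Here is a minimal tabular rendering of Algorithm 47. We assume a hypothetical 3-state corridor (action 1 moves right and pays 1 when leaving the last state; action 0 quits with reward 0), a softmax policy with one logit per state-action pair, and a tabular baseline V_w(s) = w[s], so ∇_w V_w(s_t) is simply an indicator of s_t. The learning rates and episode count are arbitrary choices for the example.

```python
import numpy as np

N, GAMMA = 3, 0.9  # hypothetical corridor length and discount factor

def step(s, a):
    """Action 1 moves right (reward 1 on leaving the last state); action 0 quits."""
    if a == 0:
        return None, 0.0
    return (s + 1, 0.0) if s + 1 < N else (None, 1.0)

def policy(theta, s):
    z = np.exp(theta[s] - theta[s].max())  # softmax over the two actions in state s
    return z / z.sum()

def reinforce(num_episodes, eta_theta, eta_w, rng):
    theta, w = np.zeros((N, 2)), np.zeros(N)  # policy logits, baseline V_w(s) = w[s]
    for _ in range(num_episodes):
        # Sample an episode with the current policy; None marks the terminal state.
        s, traj = 0, []
        while s is not None:
            a = rng.choice(2, p=policy(theta, s))
            s_next, r = step(s, a)
            traj.append((s, a, r))
            s = s_next
        # Discounted returns G_t, accumulated backwards.
        G, returns = 0.0, []
        for _, _, r in reversed(traj):
            G = r + GAMMA * G
            returns.append(G)
        returns.reverse()
        for t, ((s, a, _), G_t) in enumerate(zip(traj, returns)):
            delta = G_t - w[s]                    # scalar error
            w[s] += eta_w * delta                 # grad_w V_w(s) is an indicator of s
            grad_log_pi = -policy(theta, s)
            grad_log_pi[a] += 1.0                 # onehot(a) - pi
            theta[s] += eta_theta * GAMMA ** t * delta * grad_log_pi
    return theta, w

theta, w = reinforce(num_episodes=2000, eta_theta=0.1, eta_w=0.1,
                     rng=np.random.default_rng(0))
print([round(float(policy(theta, s)[1]), 3) for s in range(N)])  # P(move right) per state
```

No parameter is touched until the episode ends, because every G_t depends on all later rewards; this is the delay that the actor-critic methods below remove.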

35.3.3 Actor-critic methods


An actor-critic method [BSA83] uses the policy gradient method, but where the expected return is estimated using temporal difference learning of a value function instead of MC rollouts. The term "actor" refers to the policy, and the term "critic" refers to the value function. The use of bootstrapping in TD updates allows more efficient learning of the value function compared to MC. In addition, it allows us to develop a fully online, incremental algorithm that does not need to wait until the end of the trajectory before updating the parameters (as in Algorithm 47).

³ The term "REINFORCE" is an acronym for "REward Increment = Nonnegative Factor × Offset Reinforcement × Characteristic Eligibility". The phrase "characteristic eligibility" refers to the ∇ log π_θ(a_t|s_t) term; the phrase "offset reinforcement" refers to the G_t − b(s_t) term; and the phrase "nonnegative factor" refers to the learning rate η of SGD.

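Concretely, consider the use of the one-step TD(0) method to estimate the return in the episodic case, i.e., we replace G_t with G_{t:t+1} = r_t + γV_w(s_{t+1}). If we use V_w(s_t) as a baseline, the REINFORCE update in Equation (35.25) becomes

$$\theta \leftarrow \theta + \eta \sum_{t=0}^{T-1} \gamma^t \left(G_{t:t+1} - V_{\mathbf{w}}(s_t)\right) \nabla_\theta \log \pi_\theta(a_t|s_t) \tag{35.26}$$
$$= \theta + \eta \sum_{t=0}^{T-1} \gamma^t \left(r_t + \gamma V_{\mathbf{w}}(s_{t+1}) - V_{\mathbf{w}}(s_t)\right) \nabla_\theta \log \pi_\theta(a_t|s_t) \tag{35.27}$$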


35.3.3.1 A2C and A3C 


Note that r_t + γV_w(s_{t+1}) − V_w(s_t) is a single-sample approximation to the advantage function A(s_t, a_t) = Q(s_t, a_t) − V(s_t). This method is therefore called advantage actor critic or A2C (Algorithm 48). If we run the actors in parallel and asynchronously update their shared parameters, the method is called asynchronous advantage actor critic or A3C [Mni+16].


Algorithm 48: Advantage actor critic (A2C) algorithm

1 Initialize actor parameters $\theta$, critic parameters $w$
2 repeat
3   Sample starting state $s_0$ of a new episode
4   for $t = 0, 1, 2, \ldots$ do
5     Sample action $a_t \sim \pi_\theta(\cdot|s_t)$
6     Observe next state $s_{t+1}$ and reward $r_t$
7     $\delta = r_t + \gamma V_w(s_{t+1}) - V_w(s_t)$
8     $w \leftarrow w + \eta_w \delta \nabla_w V_w(s_t)$
9     $\theta \leftarrow \theta + \eta_\theta \gamma^t \delta \nabla_\theta \log \pi_\theta(a_t|s_t)$
10 until converged
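To make lines 7 to 9 of Algorithm 48 concrete, here is a minimal sketch of a single A2C update for a small tabular problem. It assumes a softmax actor with one logit per state-action pair and a table of state values as the critic, so $\nabla_w V_w(s_t)$ is a one-hot vector. The function name `a2c_step`, the step sizes, and the `done` flag (which sets $V_w(s_{t+1}) = 0$ at the end of an episode, a detail the pseudocode leaves implicit) are illustrative assumptions, not part of the original algorithm.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())            # subtract the max for numerical stability
    return e / e.sum()

def a2c_step(theta, w, s, a, r, s_next, done, t,
             gamma=0.99, eta_theta=0.1, eta_w=0.1):
    """One pass through lines 7-9 of Algorithm 48 for a tabular problem.

    theta: (num_states, num_actions) logits of the softmax actor pi_theta(a|s)
    w:     (num_states,) critic table, V_w(s) = w[s]
    """
    v_next = 0.0 if done else w[s_next]        # assumed: V_w = 0 after termination
    delta = r + gamma * v_next - w[s]          # line 7: TD error
    w = w.copy()
    w[s] += eta_w * delta                      # line 8: grad_w V_w(s) is one-hot at s
    grad_log_pi = -softmax(theta[s])           # grad of log pi(a|s) wrt the logits of s
    grad_log_pi[a] += 1.0                      # ... equals onehot(a) - pi(.|s)
    theta = theta.copy()
    theta[s] += eta_theta * gamma ** t * delta * grad_log_pi   # line 9
    return theta, w, delta

theta0, w0 = np.zeros((2, 3)), np.zeros(2)
theta1, w1, delta = a2c_step(theta0, w0, s=0, a=1, r=1.0, s_next=1, done=False, t=0)
print(delta, w1, softmax(theta1[0]))   # positive TD error raises pi(a=1 | s=0)
```

With a single-sample TD error, the actor step moves probability toward the chosen action whenever $\delta > 0$ and away from it otherwise.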


35.3.3.2 Eligibility traces 


In A2C, we use a single step rollout, and then use the value function in order to approximate the 
expected return for the trajectory. More generally, we can use the n-step estimate 


$$G_{t:t+n} = r_t + \gamma r_{t+1} + \gamma^2 r_{t+2} + \cdots + \gamma^{n-1} r_{t+n-1} + \gamma^n V_w(s_{t+n}) \tag{35.28}$$
and obtain an n-step advantage estimate as follows:
$$A^{(n)}(s_t, a_t) = G_{t:t+n} - V_w(s_t) \tag{35.29}$$

The n steps of actual rewards are an unbiased sample, but have high variance. By contrast,
$V_w(s_{t+n})$ has lower variance, but is biased. By changing n, we can control the bias-variance
tradeoff. Instead of using a single value of n, we can take a weighted average, with weight
proportional to $\lambda^n$ for $A^{(n)}(s_t, a_t)$, as in TD($\lambda$). The average can be shown to be equivalent to

$$A^{(\lambda)}(s_t, a_t) \triangleq \sum_{\ell=0}^{\infty} (\gamma\lambda)^\ell \delta_{t+\ell} \tag{35.30}$$


where $\delta_t = r_t + \gamma V_w(s_{t+1}) - V_w(s_t)$ is the TD error at time t. Here, $\lambda \in [0,1]$ is a parameter
that controls the bias-variance tradeoff: larger values decrease the bias but increase the variance,
as in TD($\lambda$). We can implement Equation (35.30) efficiently using eligibility traces, as shown in
Algorithm 49, as an example of generalized advantage estimation (GAE) [Sch+16b]. See [SB18,
Ch. 12] for further details.
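Equation (35.30) has a convenient backward recursion, $A^{(\lambda)}_t = \delta_t + \gamma\lambda A^{(\lambda)}_{t+1}$, which is how GAE is usually computed over a finite trajectory. The sketch below assumes the infinite sum is truncated at the end of the trajectory, with the final value $V_w(s_T)$ supplied by the caller (0 for a terminal state); the name `gae` and the example numbers are illustrative.

```python
import numpy as np

def gae(rewards, values, gamma=0.99, lam=0.95):
    """Generalized advantage estimates for one trajectory (Equation 35.30).

    rewards: r_0, ..., r_{T-1}
    values:  V_w(s_0), ..., V_w(s_T); pass V_w(s_T) = 0 if s_T is terminal.
    """
    rewards = np.asarray(rewards, dtype=float)
    values = np.asarray(values, dtype=float)
    deltas = rewards + gamma * values[1:] - values[:-1]   # TD errors delta_t
    adv = np.zeros(len(rewards))
    running = 0.0
    # A_t = delta_t + gamma * lam * A_{t+1}, i.e. sum_l (gamma * lam)^l delta_{t+l}
    for t in reversed(range(len(rewards))):
        running = deltas[t] + gamma * lam * running
        adv[t] = running
    return adv

r, v = [1.0, 0.0, 2.0], [0.5, 0.2, 0.1, 0.0]
print(gae(r, v, gamma=0.9, lam=0.0))   # one-step TD errors
print(gae(r, v, gamma=0.9, lam=1.0))   # Monte Carlo return minus V_w(s_t)
```

Setting `lam=0` recovers the one-step TD errors used by A2C, while `lam=1` gives the discounted return minus the baseline, matching the two ends of the bias-variance tradeoff described in the text.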


Algorithm 49: Actor critic with eligibility traces
1 Initialize actor parameters $\theta$, critic parameters $w$
2 repeat
3   Initialize eligibility trace vectors: $z_\theta \leftarrow 0$, $z_w \leftarrow 0$
4   Sample starting state $s_0$ of a new episode
5   for $t = 0, 1, 2, \ldots$ do
6     Sample action $a_t \sim \pi_\theta(\cdot|s_t)$
7     Observe state $s_{t+1}$ and reward $r_t$
8     Compute the TD error: $\delta = r_t + \gamma V_w(s_{t+1}) - V_w(s_t)$
9     $z_w \leftarrow \gamma\lambda_w z_w + \nabla_w V_w(s_t)$
10    $z_\theta \leftarrow \gamma\lambda_\theta z_\theta + \gamma^t \nabla_\theta \log \pi_\theta(a_t|s_t)$
11    $w \leftarrow w + \eta_w \delta z_w$
12    $\theta \leftarrow \theta + \eta_\theta \delta z_\theta$
13 until converged
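For a tabular actor and critic, one pass through lines 8 to 12 of Algorithm 49 can be sketched as follows. This is an illustration under assumed choices (softmax logits per state-action pair, a value table, the hypothetical name `ac_trace_step` and its default step sizes), not a reference implementation; episode termination is not handled, and each trace has the same shape as the parameters it accumulates gradients for.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def ac_trace_step(theta, w, z_theta, z_w, s, a, r, s_next, t,
                  gamma=0.99, lam_theta=0.9, lam_w=0.9, eta_theta=0.1, eta_w=0.1):
    """Lines 8-12 of Algorithm 49 with a tabular softmax actor and tabular critic."""
    delta = r + gamma * w[s_next] - w[s]                     # line 8: TD error
    grad_v = np.zeros_like(w)
    grad_v[s] = 1.0                                          # grad_w V_w(s_t) for a table
    z_w = gamma * lam_w * z_w + grad_v                       # line 9
    grad_log_pi = np.zeros_like(theta)
    grad_log_pi[s] = -softmax(theta[s])
    grad_log_pi[s, a] += 1.0                                 # grad_theta log pi(a_t|s_t)
    z_theta = gamma * lam_theta * z_theta + gamma ** t * grad_log_pi   # line 10
    w = w + eta_w * delta * z_w                              # line 11
    theta = theta + eta_theta * delta * z_theta              # line 12
    return theta, w, z_theta, z_w, delta

# Two steps of a hypothetical 2-state, 2-action episode.
theta, w = np.zeros((2, 2)), np.zeros(2)
z_theta, z_w = np.zeros_like(theta), np.zeros_like(w)
theta, w, z_theta, z_w, d1 = ac_trace_step(theta, w, z_theta, z_w, s=0, a=0, r=1.0, s_next=1, t=0)
theta, w, z_theta, z_w, d2 = ac_trace_step(theta, w, z_theta, z_w, s=1, a=1, r=0.0, s_next=0, t=1)
print(z_w, w)   # the second TD error also adjusts w[0] through the trace
```

Because $z_w$ still remembers the earlier state, a TD error observed at a later step also updates the value of that earlier state, which is how the traces realize the weighted average of Equation (35.30) online.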


35.3.4 Bound optimization methods


In policy gradient methods, the objective $J(\theta)$ does not necessarily increase monotonically; rather,
it can collapse, especially if the learning rate is not small enough. We now describe methods that
guarantee monotonic improvement, similar to bound optimization algorithms (Section 6.6).

We start with a useful fact that relates the policy values of two arbitrary policies [KL02]:


$$J(\pi') - J(\pi) = \frac{1}{1-\gamma}\, \mathbb{E}_{p^{\pi'}_\gamma(s)}\left[\mathbb{E}_{\pi'(a|s)}\left[A_\pi(s,a)\right]\right] \tag{35.31}$$


where $\pi$ can be interpreted as the current policy during policy optimization, and $\pi'$ a candidate new
policy (such as the greedy policy wrt $Q_\pi$). As in the policy improvement theorem (Section 34.6.2),
if $\mathbb{E}_{\pi'(a|s)}[A_\pi(s,a)] \ge 0$ for all s, then $J(\pi') \ge J(\pi)$. However, we cannot ensure this condition
holds when function approximation is used, as such a uniformly improving policy $\pi'$ may not be
representable by our parametric family, $\{\pi_\theta\}_{\theta \in \Theta}$. Therefore, nonnegativity of Equation (35.31) is
not easy to ensure when we do not have a direct way to sample states from $p^{\pi'}_\gamma$.
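Equation (35.31) can be checked numerically on a small tabular MDP, where $V_\pi$, $Q_\pi$ and the discounted state distribution $p^{\pi'}_\gamma(s) = (1-\gamma)\sum_t \gamma^t \Pr(s_t = s)$ all have closed forms. The sketch below is an illustration, not part of the original text: the random MDP, the initial distribution $p_0$, the assumption $J(\pi) = \mathbb{E}_{p_0}[V_\pi(s_0)]$, and the helper names `policy_values` and `discounted_state_dist` are all hypothetical.

```python
import numpy as np

def policy_values(P, R, pi, gamma):
    """Exact V_pi(s) and Q_pi(s, a) for a tabular MDP.

    P: (S, A, S) transition probabilities, R: (S, A) rewards, pi: (S, A) policy.
    """
    S = P.shape[0]
    P_pi = np.einsum("sa,sat->st", pi, P)        # state-to-state transitions under pi
    r_pi = (pi * R).sum(axis=1)                  # expected one-step reward under pi
    V = np.linalg.solve(np.eye(S) - gamma * P_pi, r_pi)
    Q = R + gamma * P @ V                        # (S, A, S) @ (S,) -> (S, A)
    return V, Q

def discounted_state_dist(P, pi, p0, gamma):
    """p_gamma^pi(s) = (1 - gamma) * sum_t gamma^t Pr(s_t = s | s_0 ~ p0)."""
    S = P.shape[0]
    P_pi = np.einsum("sa,sat->st", pi, P)
    return (1 - gamma) * np.linalg.solve((np.eye(S) - gamma * P_pi).T, p0)

rng = np.random.default_rng(0)
S, A, gamma = 4, 3, 0.9
P = rng.random((S, A, S))
P /= P.sum(axis=2, keepdims=True)                # normalize next-state distributions
R = rng.random((S, A))
p0 = np.full(S, 1.0 / S)

def random_policy():
    logits = rng.normal(size=(S, A))
    e = np.exp(logits)
    return e / e.sum(axis=1, keepdims=True)

pi, pi_new = random_policy(), random_policy()    # current pi and candidate pi'
V, Q = policy_values(P, R, pi, gamma)
V_new, _ = policy_values(P, R, pi_new, gamma)
adv = Q - V[:, None]                             # A_pi(s, a)

lhs = p0 @ V_new - p0 @ V                        # J(pi') - J(pi)
d_new = discounted_state_dist(P, pi_new, p0, gamma)
rhs = (d_new * (pi_new * adv).sum(axis=1)).sum() / (1 - gamma)
print(lhs, rhs)                                  # equal up to round-off
```

Replacing `d_new` by the state distribution of the current policy $\pi$ gives the surrogate $L(\pi')$ of Equation (35.32), which in general no longer matches `lhs` exactly.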

One way to ensure monotonic improvement of J is to improve the policy conservatively. Define
$\pi_\theta = \theta\pi' + (1-\theta)\pi$ for $\theta \in [0,1]$. It follows from the policy gradient theorem (Equation (35.21), with
$\theta = [\theta]$) that $J(\pi_\theta) - J(\pi) = \theta L(\pi') + O(\theta^2)$, where


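$$L(\pi') = \frac{1}{1-\gamma}\, \mathbb{E}_{p^{\pi}_\gamma(s)}\left[\mathbb{E}_{\pi'(a|s)}\left[A_\pi(s,a)\right]\right] = \frac{1}{1-\gamma}\, \mathbb{E}_{p^{\pi}_\gamma(s)\pi(a|s)}\left[\frac{\pi'(a|s)}{\pi(a|s)} A_\pi(s,a)\right] \tag{35.32}$$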
In the above, we have switched the state distribution from $p^{\pi'}_\gamma$ in Equation (35.31) to $p^{\pi}_\gamma$, while at
the same time introducing a higher order residual term of $O(\theta^2)$. The linear term, $L(\pi')$, can be
estimated and optimized based on episodes sampled by $\pi$. The higher order term can be bounded in
various ways, resulting in different lower bounds of $J(\pi_\theta) - J(\pi)$. We can then optimize $\theta$ to make
sure this lower bound is positive, which would imply $J(\pi_\theta) - J(\pi) > 0$. In conservative policy
iteration [KL02], the following (slightly simplified) lower bound is used


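$$J^{\text{CPI}}(\pi_\theta) \triangleq J(\pi) + \theta L(\pi') - \frac{2\epsilon\gamma}{(1-\gamma)^2}\,\theta^2 \tag{35.33}$$
where $\epsilon = \max_s |\mathbb{E}_{\pi'(a|s)}[A_\pi(s,a)]|$.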

This idea can be generalized to policies beyond those in the form of $\pi_\theta$, where the condition
of a small enough $\theta$ is replaced by a small enough divergence between $\pi'$ and $\pi$. In safe policy
iteration [Pir+13], the divergence is the maximum total variation, while in trust region policy
optimization (TRPO) [Sch+15b], the divergence is the maximum KL-divergence. In the latter
case, $\pi'$ may be found by optimizing the following lower bound


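$$J^{\text{TRPO}}(\pi') \triangleq J(\pi) + L(\pi') - \frac{2\epsilon\gamma}{(1-\gamma)^2} \max_s D_{\mathbb{KL}}\left(\pi(s) \,\|\, \pi'(s)\right) \tag{35.34}$$
where $\epsilon = \max_{s,a} |A_\pi(s,a)|$.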

In practice, the above update rule can be overly conservative, and approximations are used.
[Sch+15b] propose a version that implements two ideas: one is to replace the point-wise maximum
KL-divergence by some average KL-divergence (usually averaged over $p^{\pi}_\gamma$); the second is to maximize
the first two terms in Equation (35.34), with $\pi'$ lying in a KL-ball centered at $\pi$. That is, we solve


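$$\operatorname*{argmax}_{\pi'} L(\pi') \quad \text{s.t.} \quad \mathbb{E}_{p^{\pi}_\gamma(s)}\left[D_{\mathbb{KL}}\left(\pi(s) \,\|\, \pi'(s)\right)\right] \le \delta \tag{35.35}$$

for some threshold $\delta > 0$.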

In Section 6.4.2.1, we show that the trust region method, using a KL penalty at each step, is
equivalent to natural gradient descent (see e.g., [Kak02; PS08b]). This is important, because a step
of size $\eta$ in parameter space does not always correspond to a step of size $\eta$ in the policy space:


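$$d_\theta(\theta_1, \theta_2) = d_\theta(\theta_2, \theta_3) \not\Rightarrow d_\pi(\pi_{\theta_1}, \pi_{\theta_2}) = d_\pi(\pi_{\theta_2}, \pi_{\theta_3}) \tag{35.36}$$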


where $d_\theta(\theta_1, \theta_2) = \|\theta_1 - \theta_2\|$ is the Euclidean distance, and $d_\pi(\pi_1, \pi_2) = D_{\mathbb{KL}}(\pi_1 \,\|\, \pi_2)$ is the KL
distance. In other words, the effect on the policy of any given change to the parameters depends
on where we are in parameter space. This is taken into account by the natural gradient method,
resulting in faster and more robust optimization. The natural policy gradient can be approximated
using the KFAC method (Section 6.4.4), as done in [Wu+17].
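Equation (35.36) is easy to see in a one-parameter example. The sketch below assumes a hypothetical two-action policy $\pi_\theta(a=1) = \sigma(\theta)$ (a sigmoid of the logit $\theta$): two parameter steps of the same Euclidean length produce very different KL divergences, depending on whether the policy is near uniform or nearly deterministic. For this parameterization the Fisher information is $\sigma(\theta)(1-\sigma(\theta))$, so a natural gradient step simply divides the ordinary gradient by it; the function names are illustrative.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def kl_bernoulli(theta1, theta2):
    """D_KL(pi_theta1 || pi_theta2) for the two-action policy pi_theta(a=1) = sigmoid(theta)."""
    p, q = sigmoid(theta1), sigmoid(theta2)
    return p * np.log(p / q) + (1 - p) * np.log((1 - p) / (1 - q))

def fisher(theta):
    """Fisher information of the logit parameterization."""
    p = sigmoid(theta)
    return p * (1 - p)

def natural_gradient_step(theta, grad, eta):
    """Precondition the ordinary gradient by the inverse Fisher information."""
    return theta + eta * grad / fisher(theta)

# Two parameter steps with the same Euclidean length of 1.
kl_near_uniform = kl_bernoulli(0.0, 1.0)
kl_near_deterministic = kl_bernoulli(4.0, 5.0)
print(kl_near_uniform, kl_near_deterministic)   # the first is far larger
```

For small steps, $D_{\mathbb{KL}}(\pi_\theta \,\|\, \pi_{\theta+d}) \approx \tfrac{1}{2} F(\theta) d^2$, which is why dividing by the Fisher information makes the step size roughly uniform in policy space rather than in parameter space.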

Other than TRPO, another approach inspired by Equation (35.34) is to use the KL-divergence as
a penalty term, replacing the factor $2\epsilon\gamma/(1-\gamma)^2$ by a tuning parameter. However, it is often simpler,
and works better, to use the following clipped objective, which results in the proximal policy
optimization or PPO method [Sch+17]:


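$$J^{\text{PPO}}(\pi') \triangleq \frac{1}{1-\gamma}\, \mathbb{E}_{p^{\pi}_\gamma(s)\pi(a|s)}\left[\kappa_\epsilon\left(\frac{\pi'(a|s)}{\pi(a|s)}\right) A_\pi(s,a)\right] \tag{35.37}$$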


where $\kappa_\epsilon(x) \triangleq \text{clip}(x, 1-\epsilon, 1+\epsilon)$ ensures $|\kappa_\epsilon(x) - 1| \le \epsilon$. This method can be modified to ensure
monotonic improvement as discussed in [WHT19], making it a true bound optimization method.
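The clipping in Equation (35.37) is short to express in code. The sketch below assumes NumPy arrays of probability ratios $\pi'(a|s)/\pi(a|s)$ and advantage estimates for samples collected with the old policy $\pi$, and drops the constant $1/(1-\gamma)$. It follows the form used in the PPO paper [Sch+17], which additionally takes the elementwise minimum of the clipped and unclipped terms so that the surrogate is a pessimistic bound; the function name `ppo_clip_objective` is hypothetical.

```python
import numpy as np

def ppo_clip_objective(ratio, adv, eps=0.2):
    """Batch-averaged PPO surrogate to be maximized with respect to pi'.

    ratio: pi'(a|s) / pi(a|s) for samples (s, a) collected with the old policy pi
    adv:   advantage estimates A_pi(s, a)
    """
    ratio = np.asarray(ratio, dtype=float)
    adv = np.asarray(adv, dtype=float)
    clipped = np.clip(ratio, 1.0 - eps, 1.0 + eps)          # kappa_eps(ratio)
    # Pessimistic choice: the smaller of the unclipped and clipped terms.
    return np.mean(np.minimum(ratio * adv, clipped * adv))

print(ppo_clip_objective([1.5], [2.0]))    # gain capped at 1.2 * 2.0
print(ppo_clip_objective([0.5], [-1.0]))   # loss uses the clipped ratio 0.8
```

Once the ratio leaves $[1-\epsilon, 1+\epsilon]$ in the direction that would increase the objective, the clipped term is constant, so its gradient with respect to $\pi'$ vanishes and there is no incentive to move further from $\pi$.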




35.3.5 Deterministic policy gradient methods 


In this section, we consider the case of a deterministic policy, which predicts a unique action for each
state, so $a_t = \mu_\theta(s_t)$, rather than $a_t \sim \pi_\theta(s_t)$. We assume the states and actions are continuous, and
define the objective as


$$J(\mu_\theta) \triangleq \frac{1}{1-\gamma}\, \mathbb{E}_{p^{\mu_\theta}_\gamma(s)}\left[R(s, \mu_\theta(s))\right] \tag{35.38}$$


The deterministic policy gradient theorem [Sil+14] provides a way to compute the gradient: 


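$$\nabla_\theta J(\mu_\theta) = \frac{1}{1-\gamma}\, \mathbb{E}_{p^{\mu_\theta}_\gamma(s)}\left[\nabla_\theta Q_{\mu_\theta}(s, \mu_\theta(s))\right] \tag{35.39}$$

$$= \frac{1}{1-\gamma}\, \mathbb{E}_{p^{\mu_\theta}_\gamma(s)}\left[\nabla_\theta \mu_\theta(s)\, \nabla_a Q_{\mu_\theta}(s, a)\big|_{a=\mu_\theta(s)}\right] \tag{35.40}$$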


where $\nabla_\theta \mu_\theta(s)$ is the $M \times N$ Jacobian matrix, and M and N are the dimensions of $\mathcal{A}$ and $\theta$,
respectively. For stochastic policies of the form $\pi_\theta(a|s) = \mu_\theta(s) + \text{noise}$, the standard policy gradient
theorem reduces to the above form as the noise level goes to zero.

Note that the gradient estimate in Equation (35.40) integrates over the states but not over the
actions, which helps reduce the variance in gradient estimation from sampled trajectories. However,
since the deterministic policy does not do any exploration, we need to use an off-policy method, which
collects data from a stochastic behavior policy $\beta$, whose stationary state distribution is $p^\beta_\gamma$. The
original objective, $J(\mu_\theta)$, is approximated by the following:


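$$J_\beta(\mu_\theta) \triangleq \mathbb{E}_{p^\beta_\gamma(s)}\left[V_{\mu_\theta}(s)\right] = \mathbb{E}_{p^\beta_\gamma(s)}\left[Q_{\mu_\theta}(s, \mu_\theta(s))\right] \tag{35.41}$$

with the off-policy deterministic policy gradient approximated by (see also Section 35.5.1.2)

$$\nabla_\theta J_\beta(\mu_\theta) \approx \mathbb{E}_{p^\beta_\gamma(s)}\left[\nabla_\theta \left[Q_{\mu_\theta}(s, \mu_\theta(s))\right]\right] = \mathbb{E}_{p^\beta_\gamma(s)}\left[\nabla_\theta \mu_\theta(s)\, \nabla_a Q_{\mu_\theta}(s, a)\big|_{a=\mu_\theta(s)}\right] \tag{35.42}$$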


where we have dropped a term that depends on $\nabla_\theta Q_{\mu_\theta}(s, a)$ and is hard to estimate [Sil+14].
To apply Equation (35.42), we may learn $Q_w \approx Q_{\mu_\theta}$ with TD, giving rise to the following updates:


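$$\delta = r_t + \gamma Q_w(s_{t+1}, \mu_\theta(s_{t+1})) - Q_w(s_t, a_t) \tag{35.43}$$

$$w_{t+1} \leftarrow w_t + \eta_w \delta \nabla_w Q_w(s_t, a_t) \tag{35.44}$$

$$\theta_{t+1} \leftarrow \theta_t + \eta_\theta \nabla_\theta \mu_\theta(s_t) \nabla_a Q_w(s_t, a)\big|_{a=\mu_\theta(s_t)} \tag{35.45}$$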
This avoids importance sampling in the actor update because of the deterministic policy gradient,
and avoids importance sampling in the critic update because of the use of Q-learning.
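The actor update in Equation (35.45) only needs the action-gradient of the critic. The sketch below illustrates it in a deliberately simple, hypothetical setting: a one-step problem ($\gamma = 0$, so the TD target in Equation (35.43) reduces to the reward $r_t$), a scalar linear actor $\mu_\theta(s) = \theta s$, a Gaussian behavior policy $\beta$ centered on $\mu_\theta$, and a critic $Q_w$ that is linear in hand-picked features and is fitted by batch least squares instead of the incremental update of Equation (35.44). The reward function and features are assumptions chosen so that the optimal action is $a = 2s$.

```python
import numpy as np

rng = np.random.default_rng(0)

def reward(s, a):
    # Hypothetical one-step task whose best action is a = 2 s.
    return -(a - 2.0 * s) ** 2

def features(s, a):
    # Critic features: Q_w(s, a) = w . [s^2, s a, a^2] can represent this reward exactly.
    return np.stack([s ** 2, s * a, a ** 2], axis=-1)

theta, eta_theta = 0.0, 0.5          # linear deterministic actor mu_theta(s) = theta * s
for _ in range(50):
    # Off-policy data: behavior policy beta adds Gaussian exploration noise to mu_theta.
    s = rng.uniform(-1.0, 1.0, size=256)
    a = theta * s + rng.normal(scale=0.5, size=256)
    r = reward(s, a)
    # Critic: with gamma = 0 the TD target is r, so fit Q_w by least squares.
    w, *_ = np.linalg.lstsq(features(s, a), r, rcond=None)
    # Actor, Equation (35.45): grad_theta mu_theta(s) * grad_a Q_w(s, a) at a = mu_theta(s),
    # averaged over states visited by the behavior policy.
    mu = theta * s
    dq_da = w[1] * s + 2.0 * w[2] * mu
    theta += eta_theta * np.mean(s * dq_da)

print(theta, w)                      # theta approaches 2.0; w approaches [-4, 4, -1]
```

The actor never differentiates the reward directly; it only follows $\nabla_a Q_w$ evaluated at its own action, while the data come from the noisier behavior policy.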

If $Q_w$ is linear in $w$, and uses features of the form $\phi(s, a) = \mathbf{a}^\top \nabla_\theta \mu_\theta(s)$, where $\mathbf{a}$ is the vector
representation of $a$, then we say the function approximator for the critic is compatible with the
actor; in this case, one can show that the above approximation does not bias the overall gradient.
The method has been extended in various ways. The DDPG algorithm of [Lil+16], which stands
for “deep deterministic policy gradient”, uses the DQN method (Section 35.2.6) to update Q that
is represented by deep neural networks. The TD3 algorithm [FHM18], which stands for “twin
delayed DDPG”, extends DDPG by using double DQN (Section 35.2.5.2) and other heuristics to
further improve performance. Finally, the D4PG algorithm [BM+18], which stands for “distributed
distributional DDPG”, extends DDPG to handle distributed training, and to handle distributional
RL (i.e., working with distributions of rewards instead of expected rewards [BDM17]).


35.3.6 Gradient-free methods 


The policy gradient estimator computes a “zeroth order” gradient, which essentially evaluates the
function with randomly sampled trajectories. Sometimes it can be more efficient to use a derivative-free
optimizer (Section 6.9), which does not even attempt to estimate the gradient. For example,
[MGR18] obtain good results by training linear policies with random search, and [Sal+17b] show
how to use evolutionary strategies to optimize the policy of a robotic controller.
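As an illustration of derivative-free policy search, here is a minimal sketch of basic random search over the gain $k$ of a linear policy $a = ks$, in the spirit of [MGR18] but not their exact algorithm. The scalar system $s' = s + a$, the quadratic cost, and all hyperparameters are hypothetical; the update evaluates antithetic perturbations $k \pm \nu\delta$ and scales the step by the standard deviation of the collected returns.

```python
import numpy as np

def rollout_return(k, horizon=20, s0=1.0):
    """Return of the linear policy a = k * s on the hypothetical system s' = s + a."""
    s, total = s0, 0.0
    for _ in range(horizon):
        a = k * s
        total -= s ** 2 + a ** 2          # reward is the negative quadratic cost
        s = s + a
    return total

def basic_random_search(k=0.0, iters=200, num_dirs=8, nu=0.05, alpha=0.02, seed=0):
    rng = np.random.default_rng(seed)
    for _ in range(iters):
        deltas = rng.normal(size=num_dirs)                       # random directions
        r_plus = np.array([rollout_return(k + nu * d) for d in deltas])
        r_minus = np.array([rollout_return(k - nu * d) for d in deltas])
        sigma_r = np.concatenate([r_plus, r_minus]).std() + 1e-8  # spread of returns
        k += alpha / (num_dirs * sigma_r) * np.sum((r_plus - r_minus) * deltas)
    return k

k = basic_random_search()
print(k, rollout_return(k))
```

Only returns of whole rollouts are used; no gradient of the return with respect to $k$ is ever computed. For this system the optimal infinite-horizon gain is about $-0.618$, and the search ends up nearby.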


35.4 Model-based RL 


Model-free approaches to RL typically need a lot of interactions with the environment to achieve
good performance. For example, state-of-the-art methods for the Atari benchmark, such as rainbow
(Section 35.2.6), use millions of frames, equivalent to many days of playing at the standard frame rate.
By contrast, humans can achieve the same performance in minutes [Tsi+17]. Similarly, OpenAI’s
robot hand controller [And+20] learns to manipulate a cube using 100 years of simulated data.

One promising approach to greater sample efficiency is model-based RL (MBRL). In this
approach, we first learn the transition model and reward function, $p_T(s'|s,a)$ and $R(s,a)$, then
use them to compute a near-optimal policy. This approach can significantly reduce the amount of
real-world data that the agent needs to collect, since it can “try things out” in its imagination (i.e.,
the models), rather than having to try them out empirically.

There are several ways we can use a model, and many different kinds of models we can create. Some
of the algorithms mentioned earlier, such as MBIE and UCRL2 for provably efficient exploration
(Section 35.1.5.3), are examples of model-based methods. MBRL also provides a natural connection
between RL and planning (Section 34.6) [Sut90]. We discuss some examples in the sections below,
and refer to [MBJ20; PKP21; MH20] for more detailed reviews.


35.4.1 Model predictive control (MPC) 


So far in this chapter, we have focused on trying to learn an optimal policy $\pi_*(s)$, which can then be
used at run time to quickly pick the best action for any given state s. However, we can also avoid
performing all this work in advance, and wait until we know what state we are in, call it $s_t$, and
then use a model to predict future states and rewards that might follow for each possible sequence of
Figure 35.6: Illustration of heuristic search. In this figure, the subtrees are ordered according to a depth-first
search procedure. From Figure 8.9 of [SB18]. Used with kind permission of Richard Sutton. 


future actions we might pursue. We then take the action that looks most promising, and repeat the 
process at the next step. More precisely, we compute 


$$a^*_{t:t+H-1} = \operatorname*{argmax}_{a_{t:t+H-1}} \mathbb{E}\left[\sum_{h=0}^{H-1} R(s_{t+h}, a_{t+h}) + \hat{V}(s_{t+H})\right] \tag{35.46}$$


where the expectation is over state sequences that might result from executing $a_{t:t+H-1}$ from state
$s_t$. Here, H is called the planning horizon, and $\hat{V}(s_{t+H})$ is an estimate of the reward-to-go at the
end of this H-step look-ahead process. This is known as receding horizon control or model
predictive control (MPC) [MM90; CA13]. We discuss some special cases of this below.
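Equation (35.46) can be solved approximately by random shooting: sample many candidate action sequences, roll each out through the model, and keep the best. The sketch below is a minimal illustration that assumes a known scalar model $s' = s + a$ (standing in for a learned one), the reward $R(s,a) = -(s^2 + 0.1a^2)$, actions bounded in $[-1, 1]$, and a terminal value $\hat{V} = 0$; all names and constants are hypothetical.

```python
import numpy as np

def model(s, a):
    # Hypothetical (learned) dynamics model; here simply s' = s + a.
    return s + a

def reward(s, a):
    return -(s ** 2 + 0.1 * a ** 2)

def mpc_random_shooting(s, rng, horizon=5, num_seqs=500, a_max=1.0):
    """Approximately solve Equation (35.46) by random shooting; return the first action."""
    seqs = rng.uniform(-a_max, a_max, size=(num_seqs, horizon))   # candidate a_{t:t+H-1}
    returns = np.zeros(num_seqs)
    states = np.full(num_seqs, float(s))
    for h in range(horizon):
        returns += reward(states, seqs[:, h])
        states = model(states, seqs[:, h])
    # The terminal value estimate V(s_{t+H}) is taken to be 0 in this sketch.
    return seqs[np.argmax(returns), 0]

rng = np.random.default_rng(0)
s = 5.0
for t in range(20):
    a = mpc_random_shooting(s, rng)   # replan at every step (receding horizon)
    s = model(s, a)                   # the "real" system coincides with the model here
print(s)
```

Only the first action of the best sequence is executed, and the optimization is repeated from the next state, which is what makes this a receding horizon controller.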

35.4.1.1 Heuristic search


If the state and action spaces are finite, we can solve Equation (35.46) exactly, although the time
complexity will typically be exponential in H. However, in many situations, we can prune off
unpromising trajectories, thus making the approach feasible in large scale problems.

In particular, consider a discrete, deterministic MDP where reward maximization corresponds to
finding a shortest path to a goal state. We can expand the successors of the current state according
to all possible actions, trying to find the goal state. Since the search tree grows exponentially with
depth, we can use a heuristic function to prioritize which nodes to expand; this is called best-first
search, as illustrated in Figure 35.6.

If the heuristic function is an optimistic lower bound on the true distance to the goal, it is called
admissible; if we aim to maximize total rewards, admissibility means the heuristic function is an
upper bound of the true value function. Admissibility ensures we will never incorrectly prune off
parts of the search space. In this case, the resulting algorithm is known as A* search, and is optimal.
For more details on classical AI heuristic search methods, see [Pea84; RN19].
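A compact A* search on a 4-connected grid illustrates these ideas. The sketch assumes unit-cost moves and uses the Manhattan distance as the heuristic, which never overestimates the remaining distance and is therefore admissible; the grid encoding and the function name `astar` are illustrative.

```python
import heapq

def astar(grid, start, goal):
    """A* search on a 4-connected grid where grid[r][c] == 1 marks a wall.

    Returns the length of a shortest path from start to goal, or None.
    """
    def h(cell):
        # Manhattan distance: never overestimates the number of unit moves left.
        return abs(cell[0] - goal[0]) + abs(cell[1] - goal[1])

    rows, cols = len(grid), len(grid[0])
    frontier = [(h(start), 0, start)]            # priority queue of (g + h, g, cell)
    best_g = {start: 0}                          # cheapest known cost-to-reach
    while frontier:
        _, g, cell = heapq.heappop(frontier)     # expand the most promising node first
        if cell == goal:
            return g
        if g > best_g[cell]:
            continue                             # stale entry; a cheaper path was found
        r, c = cell
        for nxt in ((r + 1, c), (r - 1, c), (r, c + 1), (r, c - 1)):
            nr, nc = nxt
            if 0 <= nr < rows and 0 <= nc < cols and grid[nr][nc] == 0:
                ng = g + 1
                if ng < best_g.get(nxt, float("inf")):
                    best_g[nxt] = ng
                    heapq.heappush(frontier, (ng + h(nxt), ng, nxt))
    return None

grid = [[0, 0, 0, 0],
        [1, 1, 1, 0],
        [0, 0, 0, 0]]
print(astar(grid, (0, 0), (2, 0)))   # 8: around the wall via the right-hand column
```

Because the heuristic is admissible (and, on a unit grid, consistent), the first time the goal is popped from the queue its cost is optimal, even though many nodes are never expanded.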


35.4.1.2 Monte-Carlo tree search (MCTS)


Monte-Carlo tree search or MCTS is similar to heuristic search, but learns a value function for
each encountered state, rather than relying on a manually designed heuristic (see e.g., [Mun14] for
35.4. MODEL-BASED RL


details). MCTS is inspired by UCB for bandits (Section 34.4.5), but applies to general sequential
decision making problems including MDPs [KS06].

The MCTS method forms the basis of the famous AlphaGo and AlphaZero programs [Sil+16;
Sil+18], which can play expert-level Go, chess and shogi (Japanese chess), using a known model
of the environment. The MuZero method of [Sch+20] and the Stochastic MuZero method of
[Ant+22] extend this to the case where the world model is also learned. The action-value functions
for the intermediate nodes in the search tree are represented by deep neural networks, and updated
using temporal difference methods that we discuss in Section 35.2. MCTS can also be applied to
many other kinds of sequential decision problems, such as experiment design for sequentially creating
molecules [SPW+18].
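The connection to UCB can be made concrete with the child-selection rule of UCT-style MCTS, which applies a UCB1-style score at every node of the tree. The sketch below is illustrative: the exploration constant `c` and the function name are assumptions, and a full MCTS also needs expansion, rollout (or learned value estimation), and backup steps that are omitted here.

```python
import math

def uct_select(child_values, child_visits, c=1.4):
    """Index of the child to descend into, using a UCB1-style score (UCT).

    child_values[i]: average return observed below child i
    child_visits[i]: number of visits to child i
    """
    total = sum(child_visits)
    best, best_score = None, -math.inf
    for i, (q, n) in enumerate(zip(child_values, child_visits)):
        if n == 0:
            return i                                    # try every child once first
        score = q + c * math.sqrt(math.log(total) / n)  # mean value + exploration bonus
        if score > best_score:
            best, best_score = i, score
    return best

print(uct_select([0.5, 0.45], [1000, 2]))   # 1: the rarely tried child gets a large bonus
```

As visit counts grow, the bonus shrinks and the search concentrates on the children with the highest estimated values, exactly as UCB does for the arms of a bandit.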


35.4.1.3 Trajectory optimization for continuous actions 


For continuous actions, we cannot enumerate all possible branches in the search tree. Instead,
Equation (35.46) can be viewed as a nonlinear program, where $a_{t:t+H-1}$ are the real-valued variables
to be optimized. If the system dynamics are linear and the reward function corresponds to negative
quadratic cost, the optimal action sequence can be computed analytically, as in the linear-quadratic-Gaussian
(LQG) controller (see e.g., [AM89; HR17]). However, this problem is hard in general and
often solved by numerical methods such as shooting and collocation [Die+07; Rao10; Kal+11].
Many of them work in an iterative fashion, starting with an initial action sequence followed by a step
to improve it. This process repeats until convergence of the cost.
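For the linear-quadratic case, the optimal controller comes from a backward Riccati recursion. The sketch below covers the deterministic (LQR) special case $s_{t+1} = A s_t + B a_t$ with cost $\sum_t (s_t^\top Q s_t + a_t^\top R a_t)$ plus a terminal cost $s_H^\top Q s_H$; with full state observation, adding zero-mean Gaussian noise to the dynamics (the LQG setting) leaves the feedback gains unchanged. Here $Q$ and $R$ denote cost matrices, not the Q-function or the reward, and the function name is illustrative.

```python
import numpy as np

def lqr_gains(A, B, Q, R, horizon):
    """Finite-horizon LQR via the backward Riccati recursion.

    Dynamics s_{t+1} = A s_t + B a_t; cost sum_t (s_t' Q s_t + a_t' R a_t) + s_H' Q s_H.
    Returns the gains K_0, ..., K_{H-1} (optimal action a_t = -K_t s_t) and P_0,
    so that the optimal cost from s_0 is s_0' P_0 s_0.
    """
    P = Q.copy()                                   # cost-to-go at the final step
    gains = []
    for _ in range(horizon):                       # iterate backward in time
        K = np.linalg.solve(R + B.T @ P @ B, B.T @ P @ A)
        P = Q + A.T @ P @ (A - B @ K)
        gains.append(K)
    return gains[::-1], P

# Scalar example: s' = s + a with unit state and action costs.
one = np.eye(1)
gains, P = lqr_gains(one, one, one, one, horizon=100)
print(P[0, 0], gains[0][0, 0])   # approaches the golden ratio 1.618... and 0.618...
```

For long horizons the early gains converge to the stationary (infinite-horizon) solution; for this scalar problem $P$ solves $P^2 - P - 1 = 0$.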

An example is differential dynamic programming (DDP) [JM70; TL05]. In each iteration, 
DDP starts with a reference trajectory, and linearizes the system dynamics around states on the 
trajectory to form a locally quadratic approximation of the reward function. This system can be 
solved using LQG, whose optimal solution results in a new trajectory. The algorithm then moves to 
the next iteration, with the new trajectory as the reference trajectory. 

Other alternatives are possible, including black-box (gradient-free) optimization methods like the
cross entropy method (see Section 6.9.5).
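Here is a minimal sketch of the cross entropy method applied to an open-loop action sequence, as in trajectory optimization. It assumes the hypothetical scalar system $s' = s + a$ with cost $\sum_h (s_h^2 + a_h^2)$ over a horizon of 10 steps, and a diagonal Gaussian over action sequences that is refit to the lowest-cost (elite) samples at each iteration; the sample counts and initial standard deviation are illustrative choices.

```python
import numpy as np

def sequence_cost(s0, actions):
    """Cost of an open-loop action sequence on the hypothetical system s' = s + a."""
    s, cost = s0, 0.0
    for a in actions:
        cost += s ** 2 + a ** 2
        s = s + a
    return cost

def cem_plan(s0, horizon=10, num_samples=200, num_elites=20, iters=30, seed=0):
    """Cross entropy method over action sequences a_{t:t+H-1} (diagonal Gaussian)."""
    rng = np.random.default_rng(seed)
    mean, std = np.zeros(horizon), 2.0 * np.ones(horizon)
    for _ in range(iters):
        samples = mean + std * rng.normal(size=(num_samples, horizon))
        costs = np.array([sequence_cost(s0, seq) for seq in samples])
        elites = samples[np.argsort(costs)[:num_elites]]     # lowest-cost sequences
        mean, std = elites.mean(axis=0), elites.std(axis=0) + 1e-6
    return mean

plan = cem_plan(5.0)
print(plan[:3], sequence_cost(5.0, plan))
```

For this problem the optimal cost from $s_0 = 5$ is about $25 \times 1.618 \approx 40.5$, which gives a reference point for how close CEM gets.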


35.4.2 Combining model-based and model-free 


In Section 35.4.1, we discussed MPC, which uses the model to decide which action to take at each 
step. However, this can be slow, and can suffer from problems when the model is inaccurate. An 
alternative is to use the learned model to help reduce the sample complexity of policy learning. 

There are many ways to do this. One approach is to generate rollouts from the model, and then 
train a policy or Q-function on the “hallucinated” data. This is the basis of the famous dyna method 
[Sut90]. In [Jan+19], they propose a similar method, but generate short rollouts from previously 
visited real states; this ensures the model only has to extrapolate locally. 
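A minimal tabular sketch of the Dyna idea (often called Dyna-Q) is shown next. It assumes a hypothetical deterministic chain of six states with a reward of 1 for reaching the right end, a learned model stored as a lookup table of observed transitions, and Q-learning updates applied both to real transitions and to transitions replayed from the model; the chain, hyperparameters, and names are illustrative.

```python
import numpy as np

N_STATES, GOAL = 6, 5          # hypothetical chain: states 0..5, reward 1 for reaching 5
LEFT, RIGHT = 0, 1

def env_step(s, a):
    s_next = max(s - 1, 0) if a == LEFT else s + 1
    done = s_next == GOAL
    return (1.0 if done else 0.0), s_next, done

def dyna_q(episodes=50, planning_steps=20, alpha=0.5, gamma=0.9, eps=0.1, seed=0):
    rng = np.random.default_rng(seed)
    Q = np.zeros((N_STATES, 2))
    model = {}                                   # learned model: (s, a) -> (r, s', done)

    def update(s, a, r, s_next, done):
        target = r if done else r + gamma * Q[s_next].max()
        Q[s, a] += alpha * (target - Q[s, a])

    for _ in range(episodes):
        s, done = 0, False
        while not done:
            if rng.random() < eps or Q[s, 0] == Q[s, 1]:
                a = int(rng.integers(2))         # explore, or break ties randomly
            else:
                a = int(np.argmax(Q[s]))
            r, s_next, done = env_step(s, a)
            update(s, a, r, s_next, done)        # direct RL from real experience
            model[(s, a)] = (r, s_next, done)    # model learning (deterministic env)
            keys = list(model.keys())
            for _ in range(planning_steps):      # planning on "hallucinated" transitions
                ps, pa = keys[rng.integers(len(keys))]
                update(ps, pa, *model[(ps, pa)])
            s = s_next
    return Q

Q = dyna_q()
print(np.argmax(Q[:GOAL], axis=1))   # greedy action in each non-goal state
```

Each real step is followed by many cheap model-based updates, which propagates the goal reward back along the chain far faster than the real experience alone would.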

In [Web+17], they train a model to predict future states and rewards, but then use the hidden 
states of this model as additional context for a policy-based learning method. This can help overcome 
partial observability. They call their method imagination-augmented agents. A related method 
appears in [Jad+17], who propose to train a model to jointly predict future rewards and other 
auxiliary signals, such as future states. This can help in situations when rewards are sparse or absent. 
Figure 35.7: (a) A cart-pole system being controlled by a policy learned by PILCO using just 17.5 seconds
of real-world interaction. The goal state is marked by the red cross. The initial state is where the cart is
stationary on the right edge of the workspace, and the pendulum is horizontal. For a video of the system
learning, see https://bit.ly/35fpLmk. (b) A low-quality robot arm being controlled by a block-stacking
policy learned by PILCO using just 230 seconds of real-world interaction. From Figures 11, 12 from [DFR15].
Used with kind permission of Marc Deisenroth.

35.4.3 MBRL using Gaussian processes


This section gives some examples of dynamics models that have been learned for low-dimensional
continuous control problems. Such problems frequently arise in robotics. Since the dynamics are
often nonlinear, it is useful to use a flexible and sample-efficient model family, such as Gaussian
processes (Section 18.1). We will use notation like $\mathbf{s}$ and $\mathbf{a}$ for states and actions to emphasize they
are vectors.
35.4.3.1 PILCO


We first describe the PILCO method [DR11; DFR15], which stands for “probabilistic inference for
learning control”. It is extremely data efficient for continuous control problems, enabling learning
from scratch on real physical robots in a matter of minutes.

PILCO assumes the world model has the form $s_{t+1} = f(s_t, a_t) + \epsilon_t$, where $\epsilon_t \sim \mathcal{N}(0, \Sigma)$, and $f$
is an unknown, continuous function.⁴ The basic idea is to learn a Gaussian process (Section 18.1)
approximation of $f$ based on some initial random trajectories, and then to use this model to generate
“fantasy” rollout trajectories of length T, which can be used to evaluate the expected cost of the
current policy, $J(\pi_\theta) = \sum_{t=1}^{T} \mathbb{E}_{a_t \sim \pi_\theta}[c(s_t)]$, where $s_0 \sim p_0$. This function and its gradients wrt $\theta$
can be computed deterministically, if a Gaussian assumption about the state distribution at each
step is made, because the Gaussian belief state can be propagated deterministically through the
GP model. Therefore, we can use deterministic batch optimization methods, such as Levenberg-Marquardt,
to optimize the policy parameters $\theta$, instead of applying SGD to sampled trajectories.
(See https://github.com/mathDR/jax-pilco for some JAX code.)
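The GP-learning step of PILCO can be sketched with plain GP regression. The example below fits a GP with an RBF kernel to the residual $s_{t+1} - s_t$ (the variant mentioned in footnote 4) from 30 random transitions of a hypothetical one-dimensional system. The kernel hyperparameters are fixed by hand rather than learned, and the moment-matching propagation of Gaussian beliefs that PILCO uses for multi-step predictions is not shown; the function names are illustrative.

```python
import numpy as np

def rbf_kernel(X1, X2, lengthscale=1.0, signal_var=1.0):
    sq_dist = ((X1[:, None, :] - X2[None, :, :]) ** 2).sum(-1)
    return signal_var * np.exp(-0.5 * sq_dist / lengthscale ** 2)

def gp_posterior(X_train, y_train, X_test, noise_var=1e-4):
    """Posterior mean and variance of a zero-mean GP with an RBF kernel."""
    K = rbf_kernel(X_train, X_train) + noise_var * np.eye(len(X_train))
    K_s = rbf_kernel(X_test, X_train)
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y_train))   # K^{-1} y
    mean = K_s @ alpha
    v = np.linalg.solve(L, K_s.T)
    var = np.diag(rbf_kernel(X_test, X_test)) - (v ** 2).sum(axis=0)
    return mean, var

# Hypothetical 1-D system s' = s + 0.1 sin(s) + 0.1 a, observed via random transitions.
rng = np.random.default_rng(0)
S = rng.uniform(-2.0, 2.0, size=30)
A = rng.uniform(-1.0, 1.0, size=30)
S_next = S + 0.1 * np.sin(S) + 0.1 * A
X = np.stack([S, A], axis=1)                     # GP inputs (s_t, a_t)
y = S_next - S                                   # residual targets s_{t+1} - s_t

X_test = np.array([[0.5, 0.2],                   # inside the training region
                   [10.0, 0.0]])                 # far from any training data
mean, var = gp_posterior(X, y, X_test)
pred_next = X_test[:, 0] + mean                  # predicted next states
print(pred_next, var)
```

The predictive variance is small near the training data and reverts to the prior variance far from it, which is the kind of model uncertainty that PILCO propagates when evaluating a policy.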

Due to its data efficiency, it is possible to apply PILCO to real robots. Figure 35.7a shows the
results of applying it to solve a cart-pole swing-up task, where the goal is to make the inverted
pendulum swing up by applying a horizontal force to move the cart back and forth. The state of the
system $\mathbf{s} \in \mathbb{R}^4$ consists of the position $x$ of the cart (with $x = 0$ being the center of the track), the


4. An alternative, which often works better, is to use $f$ to model the residual, so that $s_{t+1} = s_t + f(s_t, a_t) + \epsilon_t$.
velocity $\dot{x}$, the angle $\theta$ of the pendulum (measured from hanging downward), and the angular velocity
$\dot{\theta}$. The control signal $a \in \mathbb{R}$ is the force applied to the cart. The target state is $\mathbf{s}_* = (0, *, \pi, *)$,
corresponding to the cart being in the middle and the pendulum being vertical, with velocities
unspecified. The authors used an RBF controller with 50 basis functions, amounting to a total of
305 policy parameters. The controller was successfully trained using just 7 real world trials.⁵

Figure 35.7b shows the results of applying PILCO to solve a block stacking task using a low-quality
robot arm with 6 degrees of freedom. A separate controller was trained for each block. The
state space $\mathbf{s} \in \mathbb{R}^3$ is the 3d location of the center of the block in the arm’s gripper (derived from
an RGBD sensor), and the control $\mathbf{a} \in \mathbb{R}^4$ corresponds to the pulse widths of four servo motors. A
linear policy was successfully trained using as few as 10 real world trials.


35.4.3.2 GP-MPC 


[KD18a] have proposed an extension to PILCO that they call GP-MPC, since it combines a GP
dynamics model with model predictive control (Section 35.4.1). In particular, they use an open-loop
control policy to propose a sequence of actions, $a_{t:t+H-1}$, as opposed to sampling them from a policy.
They compute a Gaussian approximation to the future state trajectory, $p(s_{t+1:t+H}|a_{t:t+H-1}, s_t)$, by
moment matching, and use this to deterministically compute the expected reward and its gradient
wrt $a_{t:t+H-1}$ (as opposed to the policy parameters $\theta$). Using this, they can solve Equation (35.46)
to find $a^*_{t:t+H-1}$; finally, they execute the first step of this plan, $a^*_t$, and repeat the whole process.

The advantage of GP-MPC over policy-based PILCO is that it can handle constraints more easily,
and it can be more data efficient, since it continually updates the GP model after every step (instead
of at the end of a trajectory).


35.4.4 MBRL using DNNs 


Gaussian processes do not scale well to large sample sizes and high dimensional data. Deep neural 
networks (DNNs) work much better in this regime. However, they do not naturally model uncertainty, 
which can cause MPC methods to fail. We discuss various methods for representing uncertainty with 
DNNs in Section 17.1. Here, we mention a few approaches that have been used for MBRL. 

The deep PILCO method uses DNNs together with Monte Carlo dropout (Section 17.3.1) to 
represent uncertainty [GMAR16]. [Chu+18] proposed probabilistic ensembles with trajectory 
sampling or PETS, which represents uncertainty using an ensemble of DNNs (Section 17.3.9). 
Many other approaches are possible, depending on the details of the problem being tackled. 

Since these are all stochastic methods (as opposed to the GP methods above), they can suffer from
a high variance in the predicted returns, which can make it difficult for the MPC controller to pick
the best action. We can reduce variance with the common random number trick [KSN99], where
all rollouts share the same random seed, so differences in $J(\pi_\theta)$ can be attributed to changes in $\theta$
but not other factors. This technique was used in PEGASUS [NJ00]⁶ and in [HMD18].
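The effect of common random numbers is easy to demonstrate. The sketch below compares two hypothetical linear policies on a noisy scalar system by Monte Carlo, once with both policies sharing the same noise seeds and once with independent seeds, and measures the spread of the resulting estimates of $J(k_1) - J(k_2)$. The system, the policies, and the seeding scheme are all assumptions for illustration.

```python
import numpy as np

def noisy_return(k, seed, horizon=20):
    """Return of the policy a = k * s on a hypothetical noisy system s' = s + a + noise."""
    rng = np.random.default_rng(seed)          # the seed fixes the whole noise sequence
    s, total = 1.0, 0.0
    for _ in range(horizon):
        a = k * s
        total -= s ** 2 + a ** 2
        s = s + a + 0.5 * rng.normal()
    return total

def diff_estimates(k1, k2, num_trials=200, rollouts=10, common=True):
    """Repeated Monte Carlo estimates of J(k1) - J(k2)."""
    ests = []
    for trial in range(num_trials):
        seeds1 = [trial * 1000 + i for i in range(rollouts)]
        # Common random numbers: the second policy reuses exactly the same seeds.
        seeds2 = seeds1 if common else [10**6 + trial * 1000 + i for i in range(rollouts)]
        j1 = np.mean([noisy_return(k1, sd) for sd in seeds1])
        j2 = np.mean([noisy_return(k2, sd) for sd in seeds2])
        ests.append(j1 - j2)
    return np.array(ests)

crn = diff_estimates(-0.6, -0.5, common=True)
indep = diff_estimates(-0.6, -0.5, common=False)
print(crn.var(), indep.var())   # the shared-seed estimates fluctuate much less
```

With shared seeds, most of the noise affects both returns in the same way and cancels in the difference, so comparisons between nearby policies (or nearby values of $\theta$) become much more reliable for the same number of rollouts.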


5. 2 random initial trials, each 5 seconds, and then 5 policy-generated trials, each 2.5 seconds, totalling 17.5 seconds.
6. PEGASUS stands for “Policy Evaluation-of-Goodness And Search Using Scenarios”, where the term “scenario” refers
to one of the shared random samples.
Figure 35.8: (a) Illustration of an agent interacting with the VizDoom environment. (The yellow blobs
represent fireballs being thrown towards the agent by various enemies.) The agent has a world model, composed
of a vision system V and a memory RNN M, and has a controller C. (b) Detailed representation of the
memory model. Here $h_t$ is the deterministic hidden state of the RNN at time t, which is used to predict the
next latent of the VAE, $z_{t+1}$, using a mixture density network (MDN). Here $\tau$ is a temperature parameter
used to increase the variance of the predictions, to prevent the controller from exploiting model inaccuracies.
From Figures 4, 6 of [HS18]. Used with kind permission of David Ha.


35.4.5 MBRL using latent-variable models


In this section, we describe some methods that learn latent variable models, rather than trying to
predict dynamics directly in the observed space, which is hard to do when the states are images.
35.4.5.1 World models


The “world models” paper [HS18] showed how to learn a generative model of two simple video
games (CarRacing and a VizDoom-like environment), such that the model can be used to train a
policy entirely in simulation. The basic idea is shown in Figure 35.8. First, we collect some random
experience, and use this to fit a VAE model (Section 21.2) to reduce the dimensionality of the images,
$o_t \in \mathbb{R}^{64 \times 64 \times 3}$, to a latent $z_t \in \mathbb{R}^{32}$. Next, we train an RNN to predict $p(z_{t+1}|z_t, a_t, h_t)$, where $h_t$
is the deterministic RNN state, and $a_t$ is the continuous action vector (3-dimensional in both cases).
The emission model for the RNN is a mixture density network, in order to model multi-modal futures.
Finally, we train the controller using $z_t$ and $h_t$ as inputs; here $z_t$ is a compact representation of the
current frame, and $h_t$ is a compact representation of the predicted distribution over $z_{t+1}$.

The authors of [HS18] trained the controller using a derivative-free optimizer called CMA-ES
(covariance matrix adaptation evolutionary strategy, see Section 6.9.6.2). It can work better than
policy gradient methods, as discussed in Section 35.3.6. However, it does not scale to high dimensions.
To tackle this, the authors use a linear controller, which has only 867 parameters.⁷ By contrast, the
VAE has 4.3M parameters and the MDN-RNN 422k. Fortunately, these two models can be trained in an
unsupervised way from random rollouts, so sample efficiency is less critical than when training the
policy.

7. The input is a 32-dimensional $z_t$ plus a 256-dimensional $h_t$, and there are 3 outputs. So the number of parameters
is $(32 + 256) \times 3 + 3 = 867$, to account for the weights and biases.
So far, we have described how to use the representation learned by the generative model as
informative features for the controller, but the controller is still learned by interacting with the
real world. Surprisingly, we can also train the controller entirely in “dream mode”, in which the
generated images from the VAE decoder at time t are fed as input to the VAE encoder at time t + 1,
and the MDN-RNN is trained to predict the next reward $r_{t+1}$, as well as $z_{t+1}$. Unfortunately, this
method does not always work, since the model (which is trained in an unsupervised way) may fail to
capture task-relevant features (due to underfitting) and may memorize task-irrelevant features (due
to overfitting). The controller can learn to exploit weaknesses in the model (similar to an adversarial
attack) and achieve high simulated reward, but such a controller may not work well when transferred
to the real world.

One approach to combat this is to artificially increase the variance of the MDN model (by using
a temperature parameter $\tau$), in order to make the generated samples more stochastic. This forces
the controller to be robust to large variations; the controller will then treat the real world as just
another kind of noise. This is similar to the technique of domain randomization, which is sometimes
used for sim-to-real applications; see e.g., [MAZA18].
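The temperature trick can be sketched for a one-dimensional mixture of Gaussians. The scheme below is an assumption for illustration, one common way of applying a temperature $\tau$ to an MDN rather than necessarily the exact scheme of [HS18]: the mixture logits are divided by $\tau$ (flattening the mixture weights when $\tau > 1$) and each component's standard deviation is multiplied by $\sqrt{\tau}$. The mixture parameters are hypothetical.

```python
import numpy as np

def sample_mdn(logits, means, stds, tau, rng, size):
    """Sample a 1-D Gaussian mixture with temperature tau (assumed scheme)."""
    w = np.exp((logits - logits.max()) / tau)     # softmax(logits / tau), stabilized
    w /= w.sum()
    k = rng.choice(len(w), size=size, p=w)        # pick a mixture component per sample
    return means[k] + stds[k] * np.sqrt(tau) * rng.normal(size=size)

rng = np.random.default_rng(0)
logits = np.array([2.0, 0.0])                     # hypothetical MDN outputs
means = np.array([-1.0, 1.0])
stds = np.array([0.1, 0.1])
cold = sample_mdn(logits, means, stds, tau=0.5, rng=rng, size=20000)
hot = sample_mdn(logits, means, stds, tau=1.5, rng=rng, size=20000)
print(cold.var(), hot.var())                      # higher tau gives more varied predictions
```

Raising $\tau$ makes the less likely mode appear far more often, so a controller trained on such samples cannot rely on the model always predicting its most probable future.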


35.4.5.2 PlaNet and Dreamer 


In [HS18], they first learn the world model on random rollouts, and then train a controller. On harder 
problems, it is necessary to iterate these two steps, so the model can be trained on data collected by 
the controller, in an iterative fashion. 

In this section, we describe one method of this kind, known as PlaNet [Haf+19]. PlaNet
uses a POMDP model, where $z_t$ are the latent states, $s_t$ are the observations, $a_t$ are the actions,
and $r_t$ are the rewards. It fits a recurrent state space model (Section 29.13.2) of the form
$p(z_t|z_{t-1}, a_{t-1})p(s_t|z_t)p(r_t|z_t)$ using variational inference, where the posterior is approximated by
$q(z_t|s_{1:t}, a_{1:t-1})$. After fitting the model to some random trajectories, the system uses the inference
model to compute the current belief state, and then uses the cross entropy method to find an action
sequence for the next H steps to maximize expected reward, by optimizing in latent space. The
system then executes $a^*_t$, updates the model, and repeats the whole process. To encourage the
dynamics model to capture long term trajectories, they use the “latent overshooting” training method
described in Section 29.13.3. The PlaNet method outperforms model-free methods, such as A3C
(Section 35.3.3.1) and D4PG (Section 35.3.5), on various image-based continuous control tasks,
illustrated in Figure 35.9.

Although PlaNet is sample efficient, it is not computationally efficient. For example, they use
CEM with 1000 samples and 10 iterations to optimize trajectories with a horizon of length 12, which
requires 120,000 evaluations of the transition dynamics to choose a single action. [AY19] improve this
by replacing CEM with differentiable CEM, and then optimizing in a latent space of action sequences.
This is much faster, but the results are not quite as good. However, since the whole policy is now
differentiable, it can be fine-tuned using PPO (Section 35.3.4), which closes the performance gap at
negligible cost.

A recent extension of the PlaNet paper, known as Dreamer, was proposed in [Haf+20]. In this
paper, the online MPC planner is replaced by a policy network, $\pi(a_t|z_t)$, which is learned using
gradient-based actor-critic in latent space. The inference and generative models are trained by
maximizing the ELBO, as in PlaNet. The policy is trained by SGD to maximize expected total
reward as predicted by the value function, and the value function is trained by SGD to minimize
(a) Cartpole (b) Reacher (c) Cheetah (d) Finger (e) Cup (f) Walker


Figure 35.9: Illustration of some image-based control problems used in the PlaNet paper. Inputs are 64 × 64 × 3.
(a) The cartpole swingup task has a fixed camera so the cart can move out of sight, making this a partially
observable problem. (b) The reacher task has a sparse reward. (c) The cheetah running task includes both
contacts and a larger number of joints. (d) The finger spinning task includes contacts between the finger and
the object. (e) The cup task has a sparse reward that is only given once the ball is caught. (f) The walker
task requires balance and predicting difficult interactions with the ground when the robot is lying down. From
Figure 1 of [Haf+19]. Used with kind permission of Danijar Hafner.
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ig MSE between predicted future reward and the TD-A estimate (Section 35.2.2). They show that 
jo Dreamer gives better results than PlaNet, presumably because they learn a policy to optimize the long 
99 berm reward (as estimated by the value function), rather than relying on MPC based on short-term 
>, rollouts. 

35.4.6 Robustness to model errors
The main challenge with MBRL is that errors in the model can result in poor performance of the
resulting policy, due to the distribution shift problem (Section 19.2). That is, the model is trained
to predict states and rewards that it has seen using some behavior policy (e.g., the current policy),
and then is used to compute an optimal policy under the learned model. When the latter policy is
followed, the agent will experience a different distribution of states, under which the learned model
may not be a good approximation of the real environment.

We require the model to generalize in a robust way to new states and actions. (This is related to
the off-policy learning problem that we discuss in Section 35.5.) Failing that, the model should at
least be able to quantify its uncertainty (Section 19.3). These topics are the focus of much recent
research (see e.g., [Luo+19; Kur+19; Jan+19; Isl+19; Man+19; WB20; Eys+21]).

35.5 Off-policy learning

We have seen examples of off-policy methods such as Q-learning. They do not require that training
data be generated by the policy they try to evaluate or improve. Therefore, they tend to have greater
data efficiency than their on-policy counterparts, by taking advantage of data generated by other
policies. They are also easier to apply in practice, especially in domains where the costs and risks of
following a new policy must be considered. This section covers this important topic.

A key challenge in off-policy learning is that the data distribution is typically different from the
desired one, and this mismatch must be dealt with. For example, the probability of visiting a state s
at time t in a trajectory depends not only on the MDP’s transition model, but also on the policy
that is being followed. If we are to estimate $J(\pi)$, as defined in Equation (35.15), but the trajectories
are generated by a different policy $\pi'$, simply averaging rewards in the data gives us $J(\pi')$, not $J(\pi)$.
We have to somehow correct for the gap, or “bias”. Another challenge is that off-policy data can also
make an algorithm unstable and divergent, which we will discuss in Section 35.5.3.

Removing distribution mismatches is not unique to off-policy learning; it is also needed in
supervised learning to handle covariate shift (Section 19.2.3.1), and in causal effect estimation 
(Chapter 36), among others. Off-policy learning is also closely related to offline reinforcement 
learning (also called batch reinforcement learning): the former emphasizes the distributional 
mismatch between data and the agent’s policy, and the latter emphasizes that the data is static and 
no further online interaction with the environment is allowed [LP03; EGW05; Lev+20]. Clearly, in the 
offline scenario with fixed data, off-policy learning is typically a critical technical component. Recently, 
several datasets have been prepared to facilitate empirical comparisons of offline RL methods (see 
e.g., [Gul+20; Fu+20]). 

Finally, while this section focuses on MDPs, most methods can be simplified and adapted to the 
special case of contextual bandits (Section 34.4). In fact, off-policy methods have been successfully 
used in numerous industrial bandit applications (see e.g., [Li+10; Bot+13; SJ15; HLR16]). 


35.5.1 Basic techniques 


We start with four basic techniques, and will consider more sophisticated ones in subsequent sections. 


The off-policy data is assumed to be a collection of trajectories: $\mathcal{D} = \{\tau^{(i)}\}_{1 \le i \le n}$, where each
trajectory is a sequence as before: $\tau^{(i)} = (s_0^{(i)}, a_0^{(i)}, r_0^{(i)}, s_1^{(i)}, \ldots)$. Here, the reward and next states
are sampled according to the reward and transition models; the actions are chosen by a behavior
policy, denoted $\pi_b$, which is different from the target policy, $\pi_e$, that the agent is evaluating or
improving. When $\pi_b$ is unknown, we are in a behavior-agnostic off-policy setting.


35.5.1.1 Direct method 


A natural approach to off-policy learning starts with estimating the unknown reward and transition
models of the MDP from off-policy data. This can be done using regression and density estimation
methods on the reward and transition models, respectively, to obtain $\hat{R}$ and $\hat{P}$; see Section 35.4 for
further discussions. These estimated models then give us an inexpensive way to (approximately)
simulate the original MDP, and we can apply on-policy methods on the simulated data. This method
directly models the outcome of taking an action in a state, hence the name direct method; it is
sometimes also known as the regression estimator or plug-in estimator.

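A minimal sketch of the direct method for a small tabular MDP, assuming discrete states and actions: the reward model $\hat{R}$ is fit by averaging observed rewards (regression), the transition model $\hat{P}$ by empirical frequencies (density estimation), and $\pi_e$ is then evaluated inside the estimated model by iterative policy evaluation. The function names and the two-state data set are hypothetical.

```python
from collections import defaultdict


def fit_tabular_model(transitions, n_states, n_actions):
    """Estimate R_hat (mean reward) and P_hat (empirical next-state frequencies).

    (s, a) pairs that never appear in the data get R_hat = 0 and an all-zero row
    in P_hat, i.e. they are treated as terminating with zero reward.
    """
    count = defaultdict(int)
    reward_sum = defaultdict(float)
    next_count = defaultdict(int)
    for s, a, r, s_next in transitions:
        count[s, a] += 1
        reward_sum[s, a] += r
        next_count[s, a, s_next] += 1
    R_hat = [[reward_sum[s, a] / count[s, a] if count[s, a] else 0.0
              for a in range(n_actions)] for s in range(n_states)]
    P_hat = [[[next_count[s, a, s2] / count[s, a] if count[s, a] else 0.0
               for s2 in range(n_states)]
              for a in range(n_actions)] for s in range(n_states)]
    return R_hat, P_hat


def evaluate_in_model(R_hat, P_hat, pi_e, mu0, gamma, n_iters=1000):
    """Iterative policy evaluation of pi_e inside the estimated MDP; returns J_hat."""
    n_states, n_actions = len(R_hat), len(R_hat[0])
    V = [0.0] * n_states
    for _ in range(n_iters):
        V = [sum(pi_e[s][a] * (R_hat[s][a]
                               + gamma * sum(P_hat[s][a][s2] * V[s2] for s2 in range(n_states)))
                 for a in range(n_actions))
             for s in range(n_states)]
    return sum(mu0[s] * V[s] for s in range(n_states))


# Hypothetical data: action 0 stays, action 1 switches; the reward is 1 in state 1.
data = [(0, 0, 0.0, 0), (0, 1, 0.0, 1), (1, 0, 1.0, 1), (1, 1, 1.0, 0)]
R_hat, P_hat = fit_tabular_model(data, n_states=2, n_actions=2)
pi_e = [[1.0, 0.0], [1.0, 0.0]]  # target policy: always stay
J_hat = evaluate_in_model(R_hat, P_hat, pi_e, mu0=[0.0, 1.0], gamma=0.5)
```

Here every (s, a) pair is observed with a deterministic outcome, so the estimate equals the true value $1/(1-0.5) = 2$; with sparse or noisy data, errors in $\hat{P}$ compound over the horizon, which is the limitation discussed next.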
While the direct method is natural and sometimes effective, it has a few limitations. First, a small
estimation error in the simulator has a compounding effect in long-horizon problems (or equivalently,
when the discount factor $\gamma$ is close to 1). Therefore, an agent that is optimized against an MDP
simulator may overfit the estimation errors. Second, learning the MDP model, especially the
transition model, is generally difficult, which limits the method to domains where $\hat{R}$ and $\hat{P}$ can be
learned to high fidelity. See Section 35.4.6 for a related discussion.


35.5.1.2 Importance sampling 


The second approach relies on importance sampling (IS) (Section 11.5) to correct for distributional 
mismatches in the off-policy data. To demonstrate the idea, consider the problem of estimating the 
target policy value $J(\pi_e)$ with a fixed horizon T. Correspondingly, the trajectories in $\mathcal{D}$ are also of
length T. Then, the IS off-policy estimator, first adopted by [PSS00], is given by


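$$\hat{J}_{\text{IS}}(\pi_e) \triangleq \frac{1}{n}\sum_{i=1}^{n} \frac{p(\tau^{(i)}|\pi_e)}{p(\tau^{(i)}|\pi_b)} \sum_{t=0}^{T-1} \gamma^t r_t^{(i)} \tag{35.47}$$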


It can be verified that $\mathbb{E}_{\pi_b}[\hat{J}_{\text{IS}}(\pi_e)] = J(\pi_e)$, that is, $\hat{J}_{\text{IS}}(\pi_e)$ is unbiased, provided that $p(\tau|\pi_b) > 0$
whenever $p(\tau|\pi_e) > 0$. The importance ratio, $\frac{p(\tau|\pi_e)}{p(\tau|\pi_b)}$, is used to compensate for the fact that the
data is sampled from $\pi_b$ and not $\pi_e$. Furthermore, this ratio does not depend on the MDP models,
because for any trajectory $\tau = (s_0, a_0, r_0, s_1, \ldots, s_T)$, we have from Equation (34.62) that

$$\frac{p(\tau|\pi_e)}{p(\tau|\pi_b)} = \frac{p(s_0)\prod_{t=0}^{T-1}\pi_e(a_t|s_t)\,p_T(s_{t+1}|s_t,a_t)\,p_R(r_t|s_t,a_t,s_{t+1})}{p(s_0)\prod_{t=0}^{T-1}\pi_b(a_t|s_t)\,p_T(s_{t+1}|s_t,a_t)\,p_R(r_t|s_t,a_t,s_{t+1})} = \prod_{t=0}^{T-1}\frac{\pi_e(a_t|s_t)}{\pi_b(a_t|s_t)} \tag{35.48}$$
This simplification makes it easy to apply IS, as long as the target and behavior policies are known. If
the behavior policy is unknown, we can estimate it from $\mathcal{D}$ (using, e.g., logistic regression or DNNs), and
replace $\pi_b$ by its estimate $\hat{\pi}_b$ in Equation (35.48). For convenience, define the per-step importance
ratio at time t by $\rho_t(\tau) \triangleq \pi_e(a_t|s_t)/\pi_b(a_t|s_t)$, and similarly, $\hat{\rho}_t(\tau) \triangleq \pi_e(a_t|s_t)/\hat{\pi}_b(a_t|s_t)$.

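The following sketch checks two claims above on a hypothetical two-state, two-action MDP with horizon T = 2: the trajectory ratio $p(\tau|\pi_e)/p(\tau|\pi_b)$ equals the product of per-step ratios $\rho_t(\tau)$ (so the MDP models are not needed to compute it), and the IS estimator in Equation (35.47) is unbiased. All numbers are made up; expectations are computed exactly by enumerating trajectories, and a Monte Carlo estimate from simulated behavior data is added for comparison.

```python
import itertools
import random

# Hypothetical toy MDP with 2 states and 2 actions (all numbers are made up).
P0 = [0.5, 0.5]                        # p(s_0)
P = [[[0.9, 0.1], [0.2, 0.8]],         # P[s][a][s'] = p_T(s'|s,a)
     [[0.7, 0.3], [0.1, 0.9]]]
R = [[0.0, 1.0], [2.0, -1.0]]          # deterministic reward R(s,a)
GAMMA, T = 0.9, 2
PI_E = [[0.2, 0.8], [0.6, 0.4]]        # target policy pi_e[s][a]
PI_B = [[0.5, 0.5], [0.5, 0.5]]        # behavior policy pi_b[s][a]


def traj_prob(states, actions, pi):
    """p(tau | pi), which does depend on the MDP (P0 and P)."""
    p = P0[states[0]]
    for t in range(T):
        p *= pi[states[t]][actions[t]] * P[states[t]][actions[t]][states[t + 1]]
    return p


def discounted_return(states, actions):
    return sum(GAMMA ** t * R[states[t]][actions[t]] for t in range(T))


def is_weight(states, actions):
    """Importance ratio as a product of per-step ratios rho_t; uses only the policies."""
    w = 1.0
    for t in range(T):
        w *= PI_E[states[t]][actions[t]] / PI_B[states[t]][actions[t]]
    return w


def is_estimate(data):
    """Ordinary IS estimator of J(pi_e), Equation (35.47)."""
    return sum(is_weight(s, a) * discounted_return(s, a) for s, a in data) / len(data)


def sample_trajectory(pi, rng):
    states = [rng.choices(range(2), weights=P0)[0]]
    actions = []
    for t in range(T):
        actions.append(rng.choices(range(2), weights=pi[states[t]])[0])
        states.append(rng.choices(range(2), weights=P[states[t]][actions[t]])[0])
    return tuple(states), tuple(actions)


# Exact checks by enumerating every length-T trajectory.
all_trajs = [(s, a) for s in itertools.product(range(2), repeat=T + 1)
             for a in itertools.product(range(2), repeat=T)]
J_true = sum(traj_prob(s, a, PI_E) * discounted_return(s, a) for s, a in all_trajs)
J_is_expected = sum(traj_prob(s, a, PI_B) * is_weight(s, a) * discounted_return(s, a)
                    for s, a in all_trajs)

# Monte Carlo estimate from trajectories generated by the behavior policy.
rng = random.Random(0)
data = [sample_trajectory(PI_B, rng) for _ in range(20000)]
J_is = is_estimate(data)
```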
Although IS can in principle eliminate distributional mismatches, in practice its usability is often
limited by its high variance. Indeed, the importance ratio in Equation (35.47) can be
arbitrarily large if $p(\tau|\pi_e) \gg p(\tau|\pi_b)$. There are many improvements to the basic IS estimator.
One improvement is based on the observation that the reward $r_t$ is independent of the trajectory
beyond time t. This leads to a per-decision importance sampling variant that often yields lower
variance (see Section 11.6.2 for a statistical motivation, and [LBB20] for a further discussion):


$$\hat{J}_{\text{PDIS}}(\pi_e) \triangleq \frac{1}{n}\sum_{i=1}^{n}\sum_{t=0}^{T-1}\left(\prod_{t' \le t}\rho_{t'}(\tau^{(i)})\right)\gamma^t r_t^{(i)} \tag{35.49}$$


There are many other variants, such as self-normalized IS and truncated IS, both of which aim to
reduce variance, possibly at the cost of a small bias; precise expressions of these alternatives can be found,
e.g., in [Liu+18b]. In the next subsection, we will discuss another systematic way to improve IS.

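To compare the ordinary and per-decision estimators, the sketch below computes the exact mean and variance of each estimator's per-trajectory term under $\pi_b$, for a hypothetical one-state MDP (so a trajectory is just a sequence of T = 3 actions) with $\gamma = 1$. Both are unbiased here, and the per-decision version has markedly lower variance for this particular example; that ordering is typical but not guaranteed in general.

```python
import itertools

# Hypothetical one-state MDP, so a trajectory is just a sequence of T actions.
PI_E = [0.9, 0.1]    # target policy over actions {0, 1}
PI_B = [0.5, 0.5]    # behavior policy
REWARD = [1.0, 0.0]  # deterministic reward for each action
T, GAMMA = 3, 1.0


def is_term(actions):
    """One trajectory's contribution to J_IS: full-trajectory ratio times the return."""
    w, ret = 1.0, 0.0
    for t, a in enumerate(actions):
        w *= PI_E[a] / PI_B[a]
        ret += GAMMA ** t * REWARD[a]
    return w * ret


def pdis_term(actions):
    """Contribution to J_PDIS, Equation (35.49): r_t is weighted by prod_{t' <= t} rho_{t'}."""
    w, total = 1.0, 0.0
    for t, a in enumerate(actions):
        w *= PI_E[a] / PI_B[a]
        total += w * GAMMA ** t * REWARD[a]
    return total


def exact_mean_and_variance(term):
    """Mean and variance of a per-trajectory term when actions are drawn from PI_B."""
    mean = second_moment = 0.0
    for actions in itertools.product(range(2), repeat=T):
        p = 1.0
        for a in actions:
            p *= PI_B[a]
        x = term(actions)
        mean += p * x
        second_moment += p * x * x
    return mean, second_moment - mean ** 2


mean_is, var_is = exact_mean_and_variance(is_term)        # about (2.7, 31.61)
mean_pdis, var_pdis = exact_mean_and_variance(pdis_term)  # about (2.7, 11.96)
```

The true value is $J(\pi_e) = 3 \times 0.9 = 2.7$; the per-decision estimator avoids multiplying early rewards by ratios of actions taken after them.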
IS may also be applied to improve a policy against the policy value given in Equation (35.15).
However, directly applying the calculation of Equation (35.48) runs into a fundamental issue with IS,
which we will discuss in Section 35.5.2. For now, we may consider the following approximation of
policy value, averaging over the state distribution of the behavior policy:

$$J_b(\pi_\theta) \triangleq \mathbb{E}_{p_{\pi_b}(s)}\left[V_{\pi_\theta}(s)\right] = \mathbb{E}_{p_{\pi_b}(s)}\left[\sum_{a} \pi_\theta(a|s)\, Q_{\pi_\theta}(s,a)\right] \tag{35.50}$$


38 


Differentiating this and ignoring the term $\nabla_\theta Q_{\pi_\theta}(s,a)$, as suggested by [DWS12], gives a way to
(approximately) estimate the off-policy policy gradient using a one-step IS correction ratio:

$$\begin{aligned}
\nabla_\theta J_b(\pi_\theta) &\approx \mathbb{E}_{p_{\pi_b}(s)}\left[\sum_a \nabla_\theta \pi_\theta(a|s)\, Q_{\pi_\theta}(s,a)\right] \\
&= \mathbb{E}_{p_{\pi_b}(s)\pi_b(a|s)}\left[\frac{\pi_\theta(a|s)}{\pi_b(a|s)}\,\nabla_\theta \log \pi_\theta(a|s)\, Q_{\pi_\theta}(s,a)\right]
\end{aligned}$$

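A minimal sketch of the one-step IS policy-gradient estimate, assuming a single state (a bandit), a softmax policy $\pi_\theta(a) \propto \exp(\theta_a)$, and known action values; in this single-state case $Q$ does not depend on $\theta$, so dropping $\nabla_\theta Q$ loses nothing. The helper names and numbers are hypothetical. Feeding an action list whose frequencies match the behavior policy exactly makes the estimate equal to its expectation, which should coincide with the exact gradient.

```python
import math


def softmax(theta):
    m = max(theta)
    z = [math.exp(x - m) for x in theta]
    total = sum(z)
    return [x / total for x in z]


def is_policy_gradient(theta, q, behavior, actions):
    """One-step IS estimate of grad_theta J_b for a softmax policy in a single state.

    Each sampled action a (drawn from `behavior`) contributes
    pi_theta(a) / behavior(a) * grad log pi_theta(a) * Q(a),
    where grad log pi_theta(a) = onehot(a) - pi_theta for a softmax policy.
    """
    pi = softmax(theta)
    grad = [0.0] * len(theta)
    for a in actions:
        ratio = pi[a] / behavior[a]
        for j in range(len(theta)):
            grad_log = (1.0 if j == a else 0.0) - pi[j]
            grad[j] += ratio * grad_log * q[a] / len(actions)
    return grad


def exact_gradient(theta, q):
    """grad_theta sum_a pi_theta(a) Q(a), i.e. pi_theta(j) * (Q(j) - E_pi[Q]) per component."""
    pi = softmax(theta)
    baseline = sum(p * qa for p, qa in zip(pi, q))
    return [p * (qa - baseline) for p, qa in zip(pi, q)]


theta = [0.1, -0.3, 0.5]      # hypothetical policy parameters
q = [1.0, 2.0, -0.5]          # hypothetical Q(s, a) for the single state
behavior = [0.25, 0.25, 0.5]  # behavior policy pi_b(a)

# Empirical action frequencies equal `behavior` exactly, so the estimate equals its mean.
g_balanced = is_policy_gradient(theta, q, behavior, [0, 1, 2, 2])
g_exact = exact_gradient(theta, q)
```

With actions actually sampled from `behavior`, the same function gives a noisy but unbiased estimate of `g_exact`.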

Finally, we note that in the tabular MDP case, there exists a policy $\pi_*$ that is optimal in all states
(Section 34.5.5). This policy maximizes $J$ and $J_b$ simultaneously, so Equation (35.50) can be a good
proxy for Equation (35.15) as long as all states are “covered” by the behavior policy $\pi_b$. The situation
is similar when the set of value functions or policies under consideration is sufficiently expressive: an
example is a Q-learning-like algorithm called Retrace [Mun+16; ASN20]. Unfortunately, in general,
when we work with parametric families of value functions or policies, such uniform optimality is
lost, and the distribution of states has a direct impact on the solution found by the algorithm. We
will revisit this problem in Section 35.5.2.


35.5.1.3 Doubly robust 


It is possible to combine the direct and importance sampling methods discussed previously. To
develop intuition, consider the problem of estimating $J(\pi_e)$ in a contextual bandit (Section 34.4),
that is, when T = 1 in $\mathcal{D}$. The doubly robust (DR) estimator is given by


$$\hat{J}_{\text{DR}}(\pi_e) \triangleq \frac{1}{n}\sum_{i=1}^{n}\left(\hat{V}(s_0^{(i)}) + \frac{\pi_e(a_0^{(i)}|s_0^{(i)})}{\hat{\pi}_b(a_0^{(i)}|s_0^{(i)})}\left(r_0^{(i)} - \hat{Q}(s_0^{(i)}, a_0^{(i)})\right)\right) \tag{35.51}$$

where $\hat{Q}$ is an estimate of $Q_{\pi_e}$, which can be obtained using methods discussed in Section 35.2,
and $\hat{V}(s) = \mathbb{E}_{\pi_e(a|s)}[\hat{Q}(s,a)]$. If $\hat{\pi}_b = \pi_b$, the term $\hat{Q}$ is canceled by $\hat{V}$ on average, and we get the
IS estimate, which is unbiased; if $\hat{Q} = Q_{\pi_e}$, the term $\hat{Q}$ is canceled by the reward on average, and
we get the direct-method estimate, which is also unbiased. In other words, the estimator in
Equation (35.51) is unbiased as long as one of the estimates, $\hat{\pi}_b$ or $\hat{Q}$, is right. This observation
justifies the name doubly robust, which has its origin in causal inference (see e.g., [BR05]).

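The double robustness property can be checked numerically. This sketch computes the exact expectation of Equation (35.51) for a hypothetical two-context bandit, with each reward set to its mean value $Q_{\pi_e}(s,a)$ for simplicity (this does not change the expectation). The estimate is exact when either $\hat{\pi}_b$ or $\hat{Q}$ is correct, but generally biased when both are wrong. All names and numbers are illustrative.

```python
def dr_term(s, a, r, pi_e, pi_b_hat, q_hat):
    """One summand of Equation (35.51): V_hat(s) + pi_e / pi_b_hat * (r - Q_hat(s, a))."""
    v_hat = sum(p * q for p, q in zip(pi_e[s], q_hat[s]))
    return v_hat + pi_e[s][a] / pi_b_hat[s][a] * (r - q_hat[s][a])


def expected_dr(p_s, pi_b, q_true, pi_e, pi_b_hat, q_hat):
    """Exact expectation of the DR estimate when contexts s ~ p_s and actions a ~ pi_b."""
    return sum(ps * pb * dr_term(s, a, q_true[s][a], pi_e, pi_b_hat, q_hat)
               for s, ps in enumerate(p_s) for a, pb in enumerate(pi_b[s]))


# Hypothetical 2-context, 2-action bandit.
p_s = [0.4, 0.6]
pi_b = [[0.5, 0.5], [0.3, 0.7]]
pi_e = [[0.9, 0.1], [0.2, 0.8]]
q_true = [[1.0, 0.0], [0.5, 2.0]]
q_wrong = [[0.0, 0.0], [0.0, 0.0]]
pi_b_wrong = [[0.8, 0.2], [0.6, 0.4]]

J_true = sum(ps * sum(p * q for p, q in zip(pi_e[s], q_true[s])) for s, ps in enumerate(p_s))
J_good_propensity = expected_dr(p_s, pi_b, q_true, pi_e, pi_b, q_wrong)  # correct pi_b_hat
J_good_q = expected_dr(p_s, pi_b, q_true, pi_e, pi_b_wrong, q_true)      # correct Q_hat
J_both_wrong = expected_dr(p_s, pi_b, q_true, pi_e, pi_b_wrong, q_wrong)
```

With these numbers the true value is 1.38, both partially correct configurations recover it, and using both wrong estimates gives 1.935.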
The above DR estimator may be extended to MDPs recursively, starting from the last step. Given
a length-T trajectory τ, define $J_{\text{DR}}[T] \triangleq 0$, and for $t < T$,

$$J_{\text{DR}}[t] \triangleq \hat{V}(s_t) + \rho_t(\tau)\left(r_t + \gamma J_{\text{DR}}[t+1] - \hat{Q}(s_t, a_t)\right) \tag{35.52}$$


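where $\hat{Q}(s_t, a_t)$ is the estimated cumulative reward for the remaining $T - t$ steps. The DR estimator
of $J(\pi_e)$, denoted $\hat{J}_{\text{DR}}(\pi_e)$, is the average of $J_{\text{DR}}[0]$ over all n trajectories in $\mathcal{D}$ [JL16]. It can be
verified (as an exercise) that the recursive definition is equivalent to

$$J_{\text{DR}}[0] = \hat{V}(s_0) + \sum_{t=0}^{T-1}\left(\prod_{t'=0}^{t}\rho_{t'}(\tau)\right)\gamma^t\left(r_t + \gamma\hat{V}(s_{t+1}) - \hat{Q}(s_t, a_t)\right) \tag{35.53}$$

with the convention $\hat{V}(s_T) \triangleq 0$.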
t=0 \t’=0 


This form can be easily generalized to the infinite-horizon setting by letting $T \to \infty$ [TB16]. Besides
double robustness, the estimator has also been shown to achieve minimum variance under certain
conditions [JL16]. Finally, the DR estimator can be incorporated into policy gradient methods for policy
optimization, to reduce gradient estimation variance [HJ20].

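The recursive form (35.52) and the closed form (35.53) can be compared directly. The sketch below assumes that the per-step ratios, rewards, and model estimates for one trajectory are already available as lists (hypothetical values), with $\hat{V}(s_T) = 0$; setting $\hat{V}$ and $\hat{Q}$ to zero recovers the per-decision IS term of Equation (35.49).

```python
def dr_recursive(rho, rewards, v_hat, q_hat, gamma):
    """Equation (35.52), evaluated backwards from J_DR[T] = 0.

    rho, rewards, q_hat have length T; v_hat has length T + 1 with v_hat[T] = 0.
    """
    j = 0.0
    for t in reversed(range(len(rewards))):
        j = v_hat[t] + rho[t] * (rewards[t] + gamma * j - q_hat[t])
    return j


def dr_closed_form(rho, rewards, v_hat, q_hat, gamma):
    """Equation (35.53): cumulative ratios times the discounted one-step corrections."""
    total, w = v_hat[0], 1.0
    for t in range(len(rewards)):
        w *= rho[t]
        total += w * gamma ** t * (rewards[t] + gamma * v_hat[t + 1] - q_hat[t])
    return total


# Hypothetical length-3 trajectory summary and model estimates.
rho = [1.2, 0.5, 2.0]
rewards = [1.0, 0.0, 3.0]
v_hat = [2.0, 1.5, 2.5, 0.0]
q_hat = [1.8, 1.0, 2.9]
gamma = 0.9
j_rec = dr_recursive(rho, rewards, v_hat, q_hat, gamma)
j_closed = dr_closed_form(rho, rewards, v_hat, q_hat, gamma)

# With V_hat = Q_hat = 0, DR reduces to the per-decision IS term of Equation (35.49).
j_pdis = dr_closed_form(rho, rewards, [0.0] * 4, [0.0] * 3, gamma)
```

Here `j_pdis` equals $1.2 \cdot 1 + 0.6 \cdot 0.9 \cdot 0 + 1.2 \cdot 0.81 \cdot 3 = 4.116$.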

35.5.1.4 Behavior regularized method 


The three methods discussed previously do not impose any constraint on the target policy $\pi_e$.
Typically, the more different $\pi_e$ is from $\pi_b$, the less accurate our off-policy estimation can be.
Therefore, when we optimize a policy in offline RL, a natural strategy is to favor target policies that
are “close” to the behavior policy. Similar ideas are discussed in the context of conservative policy
gradient (Section 35.3.4).

One approach is to impose a hard constraint on the proximity between the two policies. For
example, we may modify the loss function of DQN (Equation (35.14)) as follows:

$$\mathcal{L}_1^{\text{BR}}(w) \triangleq \mathbb{E}_{(s,a,r,s')\sim\mathcal{D}}\left[\left(r + \gamma\max_{\pi:\,D(\pi,\pi_b)\le\epsilon}\mathbb{E}_{\pi(a'|s')}[Q_{w^-}(s',a')] - Q_w(s,a)\right)^2\right] \tag{35.54}$$

In the above, we replace the $\max_{a'}$ operation by an expectation over a policy that stays close enough
to the behavior policy, measured by some distance function D. For various instantiations and further
details, see, e.g., [FMP19; Kum+19a].

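We may also impose a soft constraint on the proximity, by penalizing target policies that are too
different. The DQN loss function can be adapted accordingly:

$$\mathcal{L}_2^{\text{BR}}(w) \triangleq \mathbb{E}_{(s,a,r,s')\sim\mathcal{D}}\left[\left(r + \gamma\max_{\pi}\left\{\mathbb{E}_{\pi(a'|s')}[Q_{w^-}(s',a')] - \alpha D(\pi(s'),\pi_b(s'))\right\} - Q_w(s,a)\right)^2\right] \tag{35.55}$$

where $\alpha > 0$ sets the strength of the penalty.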


This idea has been used in contextual bandits [SJ15] and empirically studied in MDPs by [WTN19]. 

There are many choices for the function D, such as the KL-divergence, for both hard and soft 
constraints. More detailed discussions and examples can be found in [Lev+20]. 

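When D is chosen to be the KL-divergence $D_{\mathbb{KL}}(\pi(s')\,\|\,\pi_b(s'))$ and actions are discrete, the inner maximization in Equation (35.55) has a closed form (a standard result for KL-regularized objectives): the maximizer is $\pi^*(a') \propto \pi_b(a')\exp(Q_{w^-}(s',a')/\alpha)$, and the optimal value is $\alpha\log\sum_{a'}\pi_b(a')\exp(Q_{w^-}(s',a')/\alpha)$. The sketch below, with hypothetical names and numbers, computes this quantity; as $\alpha \to 0$ it approaches the largest Q-value among actions supported by $\pi_b$, and as $\alpha \to \infty$ it approaches $\mathbb{E}_{\pi_b}[Q_{w^-}(s',\cdot)]$.

```python
import math


def kl(p, q):
    """KL(p || q) for discrete distributions (terms with p = 0 contribute 0)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def kl_regularized_max(q_next, pi_b_next, alpha):
    """Solve max_pi E_pi[Q(s', a')] - alpha * KL(pi || pi_b) at one next state s'.

    The maximizer is pi*(a') proportional to pi_b(a') * exp(Q(s', a') / alpha), and the
    optimal value is alpha * log sum_a' pi_b(a') * exp(Q(s', a') / alpha).
    """
    m = max(q_next)  # subtract the max for numerical stability
    weights = [pb * math.exp((q - m) / alpha) for q, pb in zip(q_next, pi_b_next)]
    z = sum(weights)
    pi_star = [w / z for w in weights]
    value = m + alpha * math.log(z)
    return value, pi_star


def soft_constraint_target(r, gamma, q_next, pi_b_next, alpha):
    """Bootstrap target r + gamma * max_pi {...} used inside the squared loss."""
    value, _ = kl_regularized_max(q_next, pi_b_next, alpha)
    return r + gamma * value


q_next = [1.0, 2.0, 0.5]     # hypothetical Q_{w^-}(s', .)
pi_b_next = [0.5, 0.2, 0.3]  # hypothetical behavior policy at s'
value, pi_star = kl_regularized_max(q_next, pi_b_next, alpha=0.5)
```

The regression target for $Q_w(s,a)$ is then `soft_constraint_target(r, gamma, q_next, pi_b_next, alpha)`.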
Finally, behavior regularization and previous methods like IS can be combined, where the former
helps reduce the variance and improve the generalization of the latter (e.g., [SJ15]). Furthermore, most
proposed behavior-regularized methods consider a one-step difference in D, comparing $\pi(s)$ and $\pi_b(s)$
conditioned on s. In many cases, it is desirable to consider the difference between the long-term
distributions, $p_{\pi_b}^\infty$ and $p_\pi^\infty$, which we will discuss next.


35.5.2 The curse of horizon


The IS and DR approaches presented in the previous section all rely on an importance ratio to
correct distributional mismatches. The ratio depends on the entire trajectory, and its variance grows
exponentially in the trajectory length T. Correspondingly, the off-policy estimate of either the policy
value or policy gradient can suffer an exponentially large variance (and thus very low accuracy), a
challenge called the curse of horizon [Liu+18b]. Policies found by approximate algorithms like
Q-learning and off-policy actor-critic often have hard-to-control error due to distribution mismatches.
This section discusses an approach to tackling this challenge, by considering corrections in the
state-action distribution, rather than in the trajectory distribution. This change is critical: [Liu+18b]
describes an example where the state-action distributions under the behavior and target policies
are identical, but the importance ratio of a trajectory grows exponentially large. It is now more
convenient to assume the off-policy data consists of a set of transitions: $\mathcal{D} = \{(s_i, a_i, r_i, s_i')\}_{1 \le i \le n}$,
where $(s_i, a_i) \sim p_{\mathcal{D}}$ (some fixed but unknown sampling distribution, such as $p_{\pi_b}^\infty$), and $r_i$ and
$s_i'$ are sampled from the MDP’s reward and transition models. Given a policy π, we aim to
estimate the correction ratio $\zeta_\pi(s,a) = p_\pi^\infty(s,a)/p_{\mathcal{D}}(s,a)$, as it allows us to rewrite the policy value
(Equation (35.15)) as

$$J(\pi) = \frac{1}{1-\gamma}\mathbb{E}_{p_\pi^\infty(s,a)}[R(s,a)] = \frac{1}{1-\gamma}\mathbb{E}_{p_{\mathcal{D}}(s,a)}[\zeta_\pi(s,a)\,R(s,a)] \tag{35.56}$$


For simplicity, we assume the initial state distribution $\mu_0$ is known, or can be easily sampled from.
This assumption is often easy to satisfy in practice.
The starting point is the following constrained optimization problem for any given π:

$$\max_{d \ge 0}\; -D_f(d \,\|\, p_{\mathcal{D}}) \quad \text{s.t.} \quad d(s,a) = (1-\gamma)\mu_0(s)\pi(a|s) + \gamma\sum_{\bar{s},\bar{a}} p(s|\bar{s},\bar{a})\, d(\bar{s},\bar{a})\,\pi(a|s) \quad \forall (s,a) \tag{35.57}$$


where $D_f$ is the f-divergence (Section 2.7.1). The constraint is a variant of Equation (34.81), giving
similar flow conditions in the space of $\mathcal{S} \times \mathcal{A}$ under policy π. Under mild conditions, $p_\pi^\infty$ is the only
solution that satisfies the flow constraints, so the objective does not affect the solution, but will
facilitate the derivation below. We can now obtain the Lagrangian, with multipliers $\{\nu(s,a)\}$, and
use the change of variables $\zeta(s,a) = d(s,a)/p_{\mathcal{D}}(s,a)$ to obtain the following optimization problem:

$$\begin{aligned}
\max_\zeta \min_\nu \mathcal{L}(\zeta,\nu) \triangleq\; & \mathbb{E}_{p_{\mathcal{D}}(s,a)}\left[-f(\zeta(s,a))\right] + (1-\gamma)\,\mathbb{E}_{\mu_0(s)\pi(a|s)}\left[\nu(s,a)\right] \\
& + \mathbb{E}_{\pi(a'|s')p(s'|s,a)p_{\mathcal{D}}(s,a)}\left[\zeta(s,a)\left(\gamma\nu(s',a') - \nu(s,a)\right)\right]
\end{aligned} \tag{35.58}$$


It can be shown that the saddle point of Equation (35.58) must coincide with the desired correction
ratio $\zeta_\pi$. In practice, we may parameterize ζ and ν, and apply two-timescale stochastic gradient
descent/ascent on the off-policy data $\mathcal{D}$ to solve for an approximate saddle point. This is the
DualDICE method [Nac+19a], which was later extended to GenDICE [Zha+20c].

Compared to the IS or DR approaches, Equation (35.58) does not compute the importance ratio of
a trajectory, and thus generally has lower variance. Furthermore, it is behavior-agnostic: it requires
neither estimating the behavior policy nor assuming that the data consists of a collection of trajectories. Finally,
this approach can be extended to be doubly robust (e.g., [UHJ20]), and to optimize a policy [Nac+19b]
against the true policy value $J(\pi)$ (as opposed to approximations like Equation (35.50)). For more
examples along this line of approach, see [ND20] and the references therein.

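For a small tabular MDP, the quantities in Equations (35.56) and (35.57) can be computed exactly, which shows what DICE-style estimators aim to recover. The sketch below (hypothetical MDP, policy, and data distribution $p_{\mathcal{D}}$) finds $p_\pi^\infty$ by iterating the flow constraint to its fixed point, forms $\zeta_\pi = p_\pi^\infty / p_{\mathcal{D}}$, and checks Equation (35.56) against ordinary policy evaluation. Unlike DualDICE, it uses the known transition model; it illustrates the target of the saddle-point problem, not how to solve it from samples.

```python
def discounted_occupancy(P, pi, mu0, gamma, n_iters=2000):
    """Normalized discounted occupancy p_pi^inf(s, a), obtained by iterating the
    flow constraint of Equation (35.57) to its fixed point (a gamma-contraction)."""
    n_s, n_a = len(pi), len(pi[0])
    d = [[0.0] * n_a for _ in range(n_s)]
    for _ in range(n_iters):
        inflow = [sum(P[sb][ab][s] * d[sb][ab] for sb in range(n_s) for ab in range(n_a))
                  for s in range(n_s)]
        d = [[((1 - gamma) * mu0[s] + gamma * inflow[s]) * pi[s][a] for a in range(n_a)]
             for s in range(n_s)]
    return d


def policy_value(P, R, pi, mu0, gamma, n_iters=2000):
    """J(pi) = E[sum_t gamma^t R(s_t, a_t)] by iterative policy evaluation."""
    n_s, n_a = len(pi), len(pi[0])
    V = [0.0] * n_s
    for _ in range(n_iters):
        V = [sum(pi[s][a] * (R[s][a] + gamma * sum(P[s][a][s2] * V[s2] for s2 in range(n_s)))
                 for a in range(n_a)) for s in range(n_s)]
    return sum(m * v for m, v in zip(mu0, V))


# Hypothetical 2-state, 2-action MDP and a fixed data distribution p_D(s, a).
P = [[[0.9, 0.1], [0.2, 0.8]],  # P[s][a][s']
     [[0.7, 0.3], [0.1, 0.9]]]
R = [[0.0, 1.0], [2.0, -1.0]]
mu0 = [1.0, 0.0]
pi = [[0.3, 0.7], [0.6, 0.4]]
gamma = 0.9
p_D = [[0.1, 0.2], [0.3, 0.4]]

d = discounted_occupancy(P, pi, mu0, gamma)
zeta = [[d[s][a] / p_D[s][a] for a in range(2)] for s in range(2)]  # correction ratio
J_dice = sum(p_D[s][a] * zeta[s][a] * R[s][a] for s in range(2) for a in range(2)) / (1 - gamma)
J_true = policy_value(P, R, pi, mu0, gamma)
```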

35.5.3 The deadly triad 


Other than introducing bias, off-policy data may also make a value-based RL method unstable and
even divergent. Consider the simple MDP depicted in Figure 35.10a, due to [Bai95]. It has 7 states
and 2 actions. Taking the dashed action takes the environment to the 6 upper states uniformly at
random, while the solid action takes it to the bottom state. The reward is 0 in all transitions, and
$\gamma = 0.99$. The value function $V_w$ uses a linear parameterization indicated by the expressions shown
inside the states, with $w \in \mathbb{R}^8$. The target policy π always chooses the solid action in every state.
Clearly, the true value function, $V_\pi(s) = 0$, can be exactly represented by setting $w = 0$.

Suppose we use a behavior policy b to generate a trajectory, which chooses the dashed and solid
actions with probabilities 6/7 and 1/7, respectively, in every state. If we apply TD(0) on this
trajectory, the parameters diverge to ∞ (Figure 35.10b), even though the problem appears simple!
In contrast, with on-policy data (that is, when b is the same as π), TD(0) with linear approximation
can be guaranteed to converge to a good value function approximation [TR97].

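The divergence can be reproduced in a few lines. The text applies sampled TD(0) to a trajectory; to keep the sketch deterministic, it instead uses the expected semi-gradient TD(0) update, weighting each state by how often the data visits it (uniformly under b, since each state is reached with probability 1/7). The feature layout (upper state i valued $2w_i + w_8$, bottom state valued $w_7 + 2w_8$, written 0-indexed below), the initial weights, and the step size are assumptions following the common presentation of [Bai95], not values read off Figure 35.10. With off-policy weighting the weights blow up; with on-policy weighting (all mass on the bottom state, where π stays) the bottom state's value decays towards 0.

```python
def features(s):
    """Assumed Baird features: upper states 0..5 have V = 2 * w[s] + w[7];
    the bottom state 6 has V = w[6] + 2 * w[7]."""
    x = [0.0] * 8
    if s < 6:
        x[s], x[7] = 2.0, 1.0
    else:
        x[6], x[7] = 1.0, 2.0
    return x


def value(w, s):
    return sum(xi * wi for xi, wi in zip(features(s), w))


def expected_td0_sweep(w, state_weights, alpha, gamma=0.99):
    """One expected semi-gradient TD(0) update for the target policy (always 'solid').

    Under the target policy every state moves to the bottom state 6 with reward 0,
    so the TD error at s is gamma * V(6) - V(s). Each state's update is weighted by
    how often the data visits it (`state_weights`).
    """
    v_bottom = value(w, 6)
    new_w = list(w)
    for s in range(7):
        delta = gamma * v_bottom - value(w, s)
        for i, xi in enumerate(features(s)):
            new_w[i] += alpha * state_weights[s] * delta * xi
    return new_w


w0 = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 10.0, 1.0]  # assumed initial weights
off_policy = [1.0 / 7] * 7                      # behavior b visits all 7 states uniformly
on_policy = [0.0] * 6 + [1.0]                   # target policy stays in the bottom state

w_off, w_on = list(w0), list(w0)
for _ in range(1000):
    w_off = expected_td0_sweep(w_off, off_policy, alpha=0.1)
    w_on = expected_td0_sweep(w_on, on_policy, alpha=0.1)
```

Only the weighting of the updates differs between the two runs, which isolates the off-policy ingredient of the deadly triad discussed next.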
The divergence behavior is demonstrated in many value-based bootstrapping methods, including
TD, Q-learning, and related approximate dynamic programming algorithms, where the value function
is represented either linearly (like the example above) or nonlinearly [Gor95; Ber19]. The root cause
of these divergence phenomena is that the contraction property in the tabular case (Equation (34.75))
may no longer hold when V is approximated by $V_w$. An RL algorithm can become unstable when it
has these three components: off-policy learning, bootstrapping (for faster learning, compared to MC),
and function approximation (for generalization in large-scale MDPs). This combination is known as the
deadly triad [SB18]. It highlights another important challenge introduced by off-policy learning,
and is a subject of ongoing research (e.g., [van+18; Kum+19a]).

A general way to ensure convergence in off-policy learning is to construct an objective
function, the minimization of which leads to a good value function approximation; see [SB18, Ch. 11]
for more background. A natural candidate is the discrepancy between the left- and right-hand sides of
the Bellman optimality equation, Equation (34.70), whose unique solution is $V_*$. However, the “max”
operator is not friendly to optimization. Instead, we may introduce an entropy term to smooth the
greedy policy, resulting in a differentiable squared loss in path consistency learning (PCL) [Nac+17]:

$$\min_{V,\pi}\ \mathcal{L}^{\text{PCL}}(V,\pi) \triangleq \mathbb{E}\left[\frac{1}{2}\left(r + \gamma V(s') - \lambda\log\pi(a|s) - V(s)\right)^2\right] \tag{35.59}$$


where the expectation is over (s, a, r, s') tuples drawn from some off-policy distribution (e.g., uniform
over $\mathcal{D}$). Minimizing this loss, however, does not result in the optimal value function and policy in
general, due to an issue known as “double sampling” [SB18, Sec. 11.5].

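To see why the PCL residual in Equation (35.59) is a sensible target, consider deterministic transitions, where the double-sampling issue does not arise. The sketch below (hypothetical three-state MDP) runs soft value iteration, $V(s) \leftarrow \lambda\log\sum_a\exp((R(s,a)+\gamma V(s'))/\lambda)$, sets $\pi(a|s) \propto \exp(Q(s,a)/\lambda)$, and checks that every residual $r + \gamma V(s') - \lambda\log\pi(a|s) - V(s)$ vanishes, whereas a value function that is not soft-optimal leaves nonzero residuals.

```python
import math


def log_sum_exp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def soft_q(R, nxt, V, gamma):
    """Q(s, a) = R(s, a) + gamma * V(next state); transitions are deterministic."""
    return [[R[s][a] + gamma * V[nxt[s][a]] for a in range(len(R[s]))]
            for s in range(len(R))]


def soft_value_iteration(R, nxt, gamma, lam, n_iters=2000):
    """Iterate V(s) <- lam * log sum_a exp(Q(s, a) / lam) to its fixed point."""
    V = [0.0] * len(R)
    for _ in range(n_iters):
        Q = soft_q(R, nxt, V, gamma)
        V = [lam * log_sum_exp([q / lam for q in row]) for row in Q]
    return V


def softmax_policy(Q, lam):
    return [[math.exp(q / lam - log_sum_exp([x / lam for x in row])) for q in row]
            for row in Q]


def pcl_residuals(R, nxt, V, pi, gamma, lam):
    """r + gamma * V(s') - lam * log pi(a|s) - V(s) for every (s, a)."""
    return [[R[s][a] + gamma * V[nxt[s][a]] - lam * math.log(pi[s][a]) - V[s]
             for a in range(len(R[s]))] for s in range(len(R))]


# Hypothetical deterministic MDP: nxt[s][a] is the next state, R[s][a] the reward.
R = [[0.0, 1.0], [2.0, 0.0], [0.0, 0.5]]
nxt = [[0, 1], [2, 0], [2, 1]]
gamma, lam = 0.9, 0.5

V_soft = soft_value_iteration(R, nxt, gamma, lam)
pi_soft = softmax_policy(soft_q(R, nxt, V_soft, gamma), lam)
res_opt = pcl_residuals(R, nxt, V_soft, pi_soft, gamma, lam)

V_zero = [0.0] * 3  # not the soft-optimal value function
pi_zero = softmax_policy(soft_q(R, nxt, V_zero, gamma), lam)
res_zero = pcl_residuals(R, nxt, V_zero, pi_zero, gamma, lam)
```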

This problem can be mitigated by introducing a dual function in the optimization [Dai+18]:

$$\min_{V,\pi}\max_{\nu}\ \mathcal{L}^{\text{SBEED}}(V,\pi;\nu) \triangleq \mathbb{E}\left[\nu(s,a)\left(r + \gamma V(s') - \lambda\log\pi(a|s) - V(s)\right) - \nu(s,a)^2/2\right] \tag{35.60}$$

where ν belongs to some function class (e.g., a DNN [Dai+18] or RKHS [FLL19]). It can be shown
that optimizing Equation (35.60) forces ν to model the Bellman error. So this approach is called
smoothed Bellman error embedding, or SBEED. In both PCL and SBEED, the objective can
be optimized by gradient-based methods on parameterized value functions and policies.


Figure 35.11: A graphical model for optimal control. States and actions are observed, while optimality 
variables are not. Adapted from Figure 1b of [Lev18]. 


35.6 Control as inference 


In this section, we will discuss another approach to policy optimization, by reducing it to probabilistic
inference. This is called control as inference; see, e.g., [Att03; TS06; Tou09; BT12; KGO12; HR17;
Lev18]. This approach allows one to incorporate domain knowledge in modeling, and to apply powerful
tools from approximate inference (see e.g., Chapter 7), in a consistent and flexible framework.


35.6.1 Maximum entropy reinforcement learning 


We now describe a graphical model that exemplifies such a reduction, which results in RL algorithms 
that are closely related to some discussed previously. The model allows a trade-off between reward 
and entropy maximization, and recovers the standard RL setting when the entropy part vanishes in 
the trade-off. Our discussion mostly follows the approach of [Lev18]. 

Figure 35.11 gives a probabilistic model, which not only captures state transitions as before, but
also introduces a new variable, $o_t$. This variable is binary, indicating whether the action at time t is
optimal or not, and has the following probability distribution:

$$p(o_t = 1|s_t, a_t) = \exp(\lambda^{-1}R(s_t, a_t)) \tag{35.61}$$

for some temperature parameter $\lambda > 0$ whose role will be clear soon. In the above, we have
assumed without much loss of generality that $R(s,a) \le 0$, so that Equation (35.61) gives a valid
probability. Furthermore, we can assume a non-informative, uniform action prior, $p(a_t|s_t)$, to simplify
the exposition, for we can always push $p(a_t|s_t)$ into Equation (35.61). Under these assumptions, the
likelihood of observing a length-T trajectory τ, when optimality is achieved at every step, is:


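$$\begin{aligned}
p(\tau|o_{0:T-1}=1) &\propto p(\tau, o_{0:T-1}=1) \propto p(s_0)\prod_{t=0}^{T-1} p(o_t=1|s_t,a_t)\,p_T(s_{t+1}|s_t,a_t) \\
&= p(s_0)\prod_{t=0}^{T-1} p_T(s_{t+1}|s_t,a_t)\,\exp\left(\frac{1}{\lambda}\sum_{t=0}^{T-1}R(s_t,a_t)\right)
\end{aligned} \tag{35.62}$$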


The intuition of Equation (35.62) is clearest when the state transitions are deterministic. In this
case, $p_T(s_{t+1}|s_t,a_t)$ is either 1 or 0, depending on whether the transition is dynamically feasible or
not. Hence, $p(\tau|o_{0:T-1}=1)$ is either proportional to $\exp(\lambda^{-1}\sum_{t=0}^{T-1}R(s_t,a_t))$ if τ is feasible, or 0
otherwise. Maximizing reward is equivalent to inferring a trajectory with maximum $p(\tau|o_{0:T-1}=1)$.
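The optimal policy in this probabilistic model is given by

$$\begin{aligned}
p(a_t|s_t, o_{t:T-1}=1) &= \frac{p(s_t,a_t|o_{t:T-1}=1)}{p(s_t|o_{t:T-1}=1)} = \frac{p(o_{t:T-1}=1|s_t,a_t)\,p(a_t|s_t)\,p(s_t)}{p(o_{t:T-1}=1|s_t)\,p(s_t)} \\
&\propto \frac{p(o_{t:T-1}=1|s_t,a_t)}{p(o_{t:T-1}=1|s_t)}
\end{aligned} \tag{35.63}$$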
The two probabilities in Equation (35.63) can be computed as follows, starting with $p(o_{T-1} = 1|s_{T-1}, a_{T-1}) = \exp(\lambda^{-1}R(s_{T-1}, a_{T-1}))$:

$$p(o_{t:T-1}=1|s_t,a_t) = \int_{\mathcal{S}} p(o_{t+1:T-1}=1|s_{t+1})\,p_T(s_{t+1}|s_t,a_t)\exp(\lambda^{-1}R(s_t,a_t))\,ds_{t+1} \tag{35.64}$$

$$p(o_{t:T-1}=1|s_t) = \int_{\mathcal{A}} p(o_{t:T-1}=1|s_t,a_t)\,p(a_t|s_t)\,da_t \tag{35.65}$$


The calculation above is expensive. In practice, we can approximate the optimal policy using a
parametric form, $\pi_\theta(a_t|s_t)$. The resulting probability of trajectory τ now becomes

$$p_\theta(\tau) = p(s_0)\prod_{t=0}^{T-1} p_T(s_{t+1}|s_t,a_t)\,\pi_\theta(a_t|s_t) \tag{35.66}$$


We then optimize θ to minimize $D_{\mathbb{KL}}(p_\theta(\tau)\,\|\,p(\tau|o_{0:T-1}=1))$, which can be simplified to

$$D_{\mathbb{KL}}(p_\theta(\tau)\,\|\,p(\tau|o_{0:T-1}=1)) = -\mathbb{E}_{p_\theta}\left[\sum_{t=0}^{T-1}\lambda^{-1}R(s_t,a_t) + \mathbb{H}(\pi_\theta(s_t))\right] + \text{const} \tag{35.67}$$

where the constant term only depends on the uniform action prior $p(a_t|s_t)$, but not θ. In other words,
the objective is to maximize total reward, with an entropy regularization favoring more uniform
policies. Thus this approach is called maximum entropy RL, or MERL. If $\pi_\theta$ can represent all
stochastic policies, a softmax version of the Bellman equation can be obtained for Equation (35.67):


$$Q_*(s_t,a_t) = \lambda^{-1}R(s_t,a_t) + \mathbb{E}_{p_T(s_{t+1}|s_t,a_t)}\left[\log\int_{\mathcal{A}}\exp(Q_*(s_{t+1},a_{t+1}))\,da_{t+1}\right] \tag{35.68}$$

with the convention that $Q_*(s_T,a) = 0$ for all a, and the optimal policy has a softmax form:
$\pi_*(a_t|s_t) \propto \exp(Q_*(s_t,a_t))$. Note that the $Q_*$ above is different from the usual optimal Q-function
(Equation (34.71)), due to the introduction of the entropy term. However, as $\lambda \to 0$, their difference
vanishes, and the softmax policy becomes greedy, recovering the standard RL setting.

The soft actor-critic (SAC) algorithm [Haa+18b; Haa+18c] is an off-policy actor-critic method
whose objective function is equivalent to Equation (35.67) (by taking T to ∞):

$$J^{\text{SAC}}(\theta) \triangleq \mathbb{E}_{p_{\pi_\theta}^\infty(s)\pi_\theta(a|s)}\left[R(s,a) + \lambda\,\mathbb{H}(\pi_\theta(s))\right] \tag{35.69}$$

Note that the entropy term also has the added benefit of encouraging exploration.


To compute the optimal policy, similar to other actor-critic algorithms, we will work with “soft”
action- and state-value function approximations, parameterized by w and u, respectively:

$$Q_w(s,a) = R(s,a) + \gamma\,\mathbb{E}_{p_T(s'|s,a)}\left[V_u(s')\right] \tag{35.70}$$

$$V_u(s) = \mathbb{E}_{\pi_\theta(a|s)}\left[Q_w(s,a) - \lambda\log\pi_\theta(a|s)\right] \tag{35.71}$$

This induces an improved policy (with entropy regularization): $\pi_w(a|s) = \exp(\lambda^{-1}Q_w(s,a))/Z_w(s)$,
where $Z_w(s) = \sum_a \exp(\lambda^{-1}Q_w(s,a))$ is the normalization constant. We then perform a soft policy
improvement step to update θ by minimizing $\mathbb{E}[D_{\mathbb{KL}}(\pi_\theta(s)\,\|\,\pi_w(s))]$, where the expectation may be
approximated by sampling s from a replay buffer $\mathcal{D}$.

In [Haa+18c; Haa+18b], they show that the SAC method outperforms the off-policy DDPG 
algorithm (Section 35.3.5) and the on-policy PPO algorithm (Section 35.3.4) by a wide margin on 
various continuous control tasks. For more details, see [Haa+18c]. 

There is a variant of soft actor-critic that only requires modeling the action-value function. It is
based on the observation that both the policy and the soft value function can be induced by the soft
action-value function as follows:

$$V_w(s) = \lambda\log\sum_a\exp(\lambda^{-1}Q_w(s,a)) \tag{35.72}$$

$$\pi_w(a|s) = \exp\left(\lambda^{-1}\left(Q_w(s,a) - V_w(s)\right)\right) \tag{35.73}$$


We then only need to learn w, using approaches similar to DQN (Section 35.2.6). The resulting 
algorithm, soft Q-learning [SAC17], is convenient if the number of actions is small (when A is 
discrete), or if the integral in obtaining Vw from Qw is easy to compute (when A is continuous). 

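Equations (35.72) and (35.73) amount to a log-sum-exp and a softmax over the action values of one state. A minimal sketch for discrete actions, with a hypothetical Q-vector, shows the policy sharpening towards the greedy action as λ shrinks, and $V_w(s)$ approaching $\max_a Q_w(s,a)$ from above.

```python
import math


def soft_value(q, lam):
    """V(s) = lam * log sum_a exp(Q(s, a) / lam), computed stably (Equation (35.72))."""
    m = max(q)
    return m + lam * math.log(sum(math.exp((x - m) / lam) for x in q))


def soft_policy(q, lam):
    """pi(a|s) = exp((Q(s, a) - V(s)) / lam) (Equation (35.73)); sums to 1 by construction."""
    v = soft_value(q, lam)
    return [math.exp((x - v) / lam) for x in q]


q = [1.0, 2.0, 0.5]  # hypothetical Q_w(s, .) for one state with 3 actions
for lam in (1.0, 0.1, 0.01):
    print(lam, round(soft_value(q, lam), 4), [round(p, 4) for p in soft_policy(q, lam)])
```

Subtracting the maximum before exponentiating avoids overflow when λ is small, which is exactly the regime where the soft policy approaches the greedy one.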
It is interesting to see that algorithms derived in the maximum entropy RL framework bear a
resemblance to PCL and SBEED in Section 35.5.3, both of which minimize an objective
function derived from the entropy-smoothed Bellman equation.


35.6.2 Other approaches 


VIREL is an alternative model to maximum entropy RL [Fel+19]. Similar to soft actor-critic, it uses
an approximate action-value function, $Q_w$, a stochastic policy, $\pi_\theta$, and a binary optimality random
variable $o_t$ at time t. A different probability model for $o_t$ is used:

$$p(o_t = 1|s_t, a_t) = \exp\left(\frac{Q_w(s_t,a_t) - \max_a Q_w(s_t,a)}{\lambda_w}\right) \tag{35.74}$$

The temperature parameter $\lambda_w$ is also part of the parameterization, and can be updated from data.
An EM method can be used to maximize the objective

$$\mathcal{L}(w,\theta) = \mathbb{E}_{p(s)}\left[\mathbb{E}_{\pi_\theta(a|s)}\left[\frac{Q_w(s,a)}{\lambda_w}\right] + \mathbb{H}(\pi_\theta(s))\right] \tag{35.75}$$

for some distribution p that can be conveniently sampled from (e.g., in a replay buffer). The algorithm
may be interpreted as an instance of actor-critic. In the E-step, the critic parameter w is fixed, and
the actor parameter θ is updated using gradient ascent with stepsize $\eta_\theta$ (for policy improvement):

$$\theta \leftarrow \theta + \eta_\theta\nabla_\theta\mathcal{L}(w,\theta) \tag{35.76}$$

In the M-step, the actor parameter is fixed, and the critic parameter is updated (for policy evaluation):

$$w \leftarrow w + \eta_w\nabla_w\mathcal{L}(w,\theta) \tag{35.77}$$


Finally, there are other possibilities of reducing optimal control to probabilistic inference, in
addition to MERL and VIREL. For example, we may aim to maximize the expectation of the
trajectory return G, by optimizing the policy parameter θ:

$$J(\pi_\theta) = \int G(\tau)\,p(\tau|\theta)\,d\tau \tag{35.78}$$

It can be interpreted as a pseudo-likelihood function, when G(τ) is treated as a probability density,
and solved (approximately) by a range of algorithms (see e.g., [PS07; Neu11; Abd+18]). Interestingly,
some of these methods have a similar objective as MERL (Equation (35.67)), although the distribution
involving θ appears in the second argument of $D_{\mathbb{KL}}$. As discussed in Section 2.7.1, this forward
KL-divergence is mode-covering, which in the context of RL is argued to be less preferred than the
mode-seeking, reverse KL-divergence used by MERL. For more details and references, see [Lev18].
Control as inference is also closely related to active inference; this is based on the free energy
principle (FEP), which is popular in neuroscience (see e.g., [Fri09; Buc+17; SKM18; Ger19; Maz+22]).
The FEP is equivalent to using variational inference (see Section 10.1) to perform state estimation
(perception) and parameter estimation (learning). In particular, consider a latent variable model
with hidden states s, observations y and parameters θ. Following Section 10.1.1, we define the
variational free energy to be $\mathcal{F}(y) = D_{\mathbb{KL}}(q(s,\theta|y)\,\|\,p(s,y,\theta))$. State estimation corresponds to
solving $\min_{q(s|y)}\mathcal{F}(y)$, and parameter estimation corresponds to solving $\min_{q(\theta|y)}\mathcal{F}(y)$, just as in
variational Bayes EM (Section 10.2.5). (Minimizing the VFE for certain hierarchical Gaussian models
also forms the foundation of predictive coding, which we discuss in Supplementary Section 8.1.4.)
To extend this to decision making problems we can define the expected free energy as $\mathcal{F}(\boldsymbol{a}) =
\mathbb{E}_{q(y|\boldsymbol{a})}[\mathcal{F}(y)]$, where $q(y|\boldsymbol{a})$ is the posterior predictive distribution over observations given the action
sequence $\boldsymbol{a}$. The connection to control as inference is explained in [Mil+20; WIP20; LOW21].

35.6.3 Imitation learning


In previous sections, an RL agent learns an optimal sequential decision-making policy so that the
total reward is maximized. Imitation learning (IL), also known as apprenticeship learning and
learning from demonstration (LfD), is a different setting, in which the agent does not observe
rewards, but has access to a collection $\mathcal{D}_{\text{exp}}$ of trajectories generated by an expert policy $\pi_{\text{exp}}$; that
is, $\tau = (s_0, a_0, s_1, a_1, \ldots, s_T)$ and $a_t \sim \pi_{\text{exp}}(s_t)$ for $\tau \in \mathcal{D}_{\text{exp}}$. The goal is to learn a good policy by
imitating the expert, in the absence of reward signals. IL finds many applications in scenarios where
we have demonstrations of experts (often humans) but designing a good reward function is not easy,
such as car driving and conversational systems. See [Osa+18] for a survey up to 2018.


39 
35.6.3.1 Imitation learning by behavior cloning


A natural method is behavior cloning, which reduces IL to supervised learning; see [Pom89] for an early application to autonomous driving. It interprets a policy as a classifier that maps states (inputs) to actions (labels), and finds a policy by minimizing the imitation error, such as

$$\min_{\pi} \mathbb{E}_{p^{\gamma}_{\pi_{\text{exp}}}(s)}\left[D_{\text{KL}}(\pi_{\text{exp}}(s) \,\|\, \pi(s))\right] \tag{35.79}$$
35.6. CONTROL AS INFERENCE


where the expectation wrt $p^{\gamma}_{\pi_{\text{exp}}}$ may be approximated by averaging over states in $\mathcal{D}_{\text{exp}}$. A challenge with this method is that the loss does not consider the sequential nature of IL: the future state distribution is not fixed but instead depends on earlier actions. Therefore, if we learn a policy $\pi$ that has a low imitation error under distribution $p^{\gamma}_{\pi_{\text{exp}}}$, as defined in Equation (35.79), it may still incur a large error under distribution $p^{\gamma}_{\pi}$ (when the policy $\pi$ is actually run). Further expert demonstrations or algorithmic augmentations are often needed to handle the distribution mismatch (see e.g., [DLM09; RGB11]).
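For a tabular (finite-state, finite-action) policy, a sample-based version of Equation (35.79) reduces to maximum likelihood: with $\pi_{\text{exp}}$ fixed, minimizing the KL term is equivalent to maximizing $\sum \log \pi(a|s)$ over the expert's (state, action) pairs, and the maximizer is the empirical action frequency in each state. The sketch below assumes trajectories are given as lists of hashable (state, action) pairs; the function name and the optional pseudo-count smoothing are illustrative choices, not part of the method as stated above.

```python
from collections import Counter, defaultdict

def behavior_cloning(trajectories, actions, smoothing=1.0):
    """Fit a tabular policy pi(a | s) to expert (state, action) pairs.

    With smoothing=0 this is the maximum-likelihood estimate (empirical action
    frequencies per state); smoothing > 0 adds pseudo-counts so that actions
    never taken by the expert keep nonzero probability.
    """
    counts = defaultdict(Counter)
    for traj in trajectories:
        for s, a in traj:
            counts[s][a] += 1
    policy = {}
    for s, c in counts.items():
        total = sum(c.values()) + smoothing * len(actions)
        policy[s] = {a: (c[a] + smoothing) / total for a in actions}
    return policy

# Hypothetical expert data: two short trajectories over states "s0", "s1".
D_exp = [[("s0", "left"), ("s1", "right")],
         [("s0", "left"), ("s1", "left")]]
print(behavior_cloning(D_exp, ["left", "right"], smoothing=0.0))
# {'s0': {'left': 1.0, 'right': 0.0}, 's1': {'left': 0.5, 'right': 0.5}}
```

A state the expert never visited has no entry at all in the returned table, which is a small-scale view of the distribution mismatch described above: if the learned policy drifts into such a state at run time, the cloned policy offers no guidance there.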


35.6.3.2 Imitation learning by inverse reinforcement learning 


An effective approach to IL is inverse reinforcement learning (IRL) or inverse optimal control (IOC). Here, we first infer a reward function that “explains” the observed expert trajectories, and then compute a (near-)optimal policy against this learned reward using any of the standard RL algorithms studied in earlier sections. The key step of reward learning (from expert trajectories) is the opposite of standard RL, hence the name inverse RL [NR00a].

It is clear that there are infinitely many reward functions for which the expert policy is optimal, for example, those obtained by applying optimality-preserving transformations [NHR99] to any one such reward. To address this challenge, we can follow the maximum entropy principle (Section 2.3.7), and use an energy-based probability model to capture how expert trajectories are generated [Zie+08]:


$$p(\tau) \propto \exp\left(\sum_{t=0}^{T-1} R_{\theta}(s_t, a_t)\right) \tag{35.80}$$


where $R_{\theta}$ is an unknown reward function with parameter $\theta$. Abusing notation slightly, we denote by $R_{\theta}(\tau) = \sum_{t=0}^{T-1} R_{\theta}(s_t, a_t)$ the cumulative reward along the trajectory $\tau$. This model assigns exponentially small probabilities to trajectories with lower cumulative rewards. The partition function, $Z_{\theta} \triangleq \int_{\tau} \exp(R_{\theta}(\tau))$, is in general intractable to compute, and must be approximated. Here, we can take a sample-based approach. Let $\mathcal{D}_{\text{exp}}$ and $\mathcal{D}$ be the sets of trajectories generated by an expert, and by some known distribution $q$, respectively. We may infer $\theta$ by maximizing the likelihood, $p(\mathcal{D}_{\text{exp}} \mid \theta)$, or equivalently, minimizing the negative log-likelihood loss


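$$\mathcal{L}(\theta) = -\frac{1}{|\mathcal{D}_{\text{exp}}|} \sum_{\tau \in \mathcal{D}_{\text{exp}}} R_{\theta}(\tau) + \log \frac{1}{|\mathcal{D}|} \sum_{\tau' \in \mathcal{D}} \frac{\exp(R_{\theta}(\tau'))}{q(\tau')} \tag{35.81}$$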


The term inside the log of the loss is an importance sampling estimate of $Z_{\theta}$ that is unbiased as long as $q(\tau) > 0$ for all $\tau$. However, in order to reduce the variance, we can choose $q$ adaptively as $\theta$ is being updated. The optimal sampling distribution (Section 11.5), $q_*(\tau) \propto \exp(R_{\theta}(\tau))$, is hard to obtain. Instead, we may find a policy $\hat{\pi}$ which induces a distribution that is close to $q_*$, for instance, using methods of maximum entropy RL discussed in Section 35.6.1. Interestingly, the process above produces the inferred reward $R_{\theta}$ as well as an approximate optimal policy $\hat{\pi}$. This approach is used by guided cost learning [FLA16], and has been found effective in robotics applications.
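The loss in Equation (35.81) is straightforward to evaluate once each trajectory is summarized numerically. The sketch below assumes, for illustration only, a reward that is linear in a feature vector $\phi(\tau)$ (the sum of per-step features along $\tau$), and that $\log q(\tau')$ is known for every sampled trajectory; the feature map and function names are hypothetical. The log of the importance-weighted average is computed with the log-sum-exp trick to avoid overflow.

```python
import math

def reward(theta, phi):
    # Hypothetical linear reward on trajectory features: R_theta(tau) = theta . phi(tau)
    return sum(t * f for t, f in zip(theta, phi))

def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def maxent_irl_loss(theta, expert_feats, sample_feats, sample_logq):
    """Importance-sampled negative log-likelihood, as in Equation (35.81).

    expert_feats: phi(tau) for tau in D_exp
    sample_feats, sample_logq: phi(tau') and log q(tau') for tau' in D
    """
    data_term = -sum(reward(theta, f) for f in expert_feats) / len(expert_feats)
    # log((1/|D|) * sum_tau' exp(R(tau')) / q(tau')), computed in log space
    log_w = [reward(theta, f) - lq for f, lq in zip(sample_feats, sample_logq)]
    log_z_hat = logsumexp(log_w) - math.log(len(sample_feats))
    return data_term + log_z_hat

# Sanity check on a toy space of two trajectories with features [1.0] and [0.0]:
# listing each exactly once under a uniform q makes the estimate of Z_theta exact.
theta = [2.0]
loss = maxent_irl_loss(theta, expert_feats=[[1.0]],
                       sample_feats=[[1.0], [0.0]],
                       sample_logq=[math.log(0.5), math.log(0.5)])
exact = -2.0 + math.log(math.exp(2.0) + 1.0)
print(loss, exact)
```

In a real setting $\mathcal{D}$ would be sampled from the current policy-induced distribution, and $\theta$ would be updated by gradient steps on this loss.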


35.6.3.3 Imitation learning by divergence minimization 


We now discuss a different, but related, approach to IL. Recall that the reward function depends only on the state and action in an MDP. This implies that if we can find a policy $\pi$ such that $p^{\gamma}_{\pi}(s, a)$ and $p^{\gamma}_{\pi_{\text{exp}}}(s, a)$ are close, then $\pi$ receives a similar long-term reward to $\pi_{\text{exp}}$, and is a good imitation of $\pi_{\text{exp}}$ in this regard. A number of IL algorithms find $\pi$ by minimizing the divergence between $p^{\gamma}_{\pi}$ and $p^{\gamma}_{\pi_{\text{exp}}}$. We will largely follow the exposition of [GZG19]; see [Ke+19b] for a similar derivation.

Let $f$ be a convex function, and $D_f$ the f-divergence (Section 2.7.1). From the above intuition, we want to minimize $D_f(p^{\gamma}_{\pi_{\text{exp}}} \,\|\, p^{\gamma}_{\pi})$. Then, using a variational approximation of $D_f$ [NWJ10a], we can solve the following optimization problem for $\pi$:
$$\min_{\pi} \max_{w} \; \mathbb{E}_{p^{\gamma}_{\pi_{\text{exp}}}(s,a)}[T_w(s,a)] - \mathbb{E}_{p^{\gamma}_{\pi}(s,a)}[f^*(T_w(s,a))] \tag{35.82}$$


where $T_w : \mathcal{S} \times \mathcal{A} \rightarrow \mathbb{R}$ is a function parameterized by $w$. The first expectation can be estimated using $\mathcal{D}_{\text{exp}}$, as in behavior cloning, and the second can be estimated using trajectories generated by policy $\pi$. Furthermore, to implement this algorithm, we often use a parametric policy representation $\pi_{\theta}$, and then perform stochastic gradient updates to find a saddle point of Equation (35.82).
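To make the inner maximization in Equation (35.82) concrete, the sketch below fixes $\pi$ and takes $f(u) = u \log u$, for which $D_f$ is the KL divergence and the convex conjugate is $f^*(t) = \exp(t - 1)$. It uses hypothetical occupancy measures over three $(s, a)$ pairs and replaces the parameterized $T_w$ with a free table of values. Any $T$ gives a lower bound on the divergence, and $T = f'(p_{\text{exp}}/p_{\pi}) = 1 + \log(p_{\text{exp}}/p_{\pi})$ makes the bound tight.

```python
import math

def f_star_kl(t):
    # Convex conjugate of f(u) = u log u
    return math.exp(t - 1.0)

def variational_objective(T, p, q):
    # E_p[T] - E_q[f*(T)] over a finite set of (s, a) pairs
    return (sum(pi * ti for pi, ti in zip(p, T))
            - sum(qi * f_star_kl(ti) for qi, ti in zip(q, T)))

p_exp = [0.5, 0.3, 0.2]   # hypothetical expert occupancy p^gamma_{pi_exp}(s, a)
p_pi = [0.2, 0.4, 0.4]    # hypothetical learner occupancy p^gamma_pi(s, a)

kl = sum(pe * math.log(pe / pp) for pe, pp in zip(p_exp, p_pi))
T_opt = [1.0 + math.log(pe / pp) for pe, pp in zip(p_exp, p_pi)]   # f'(p / q)
T_zero = [0.0, 0.0, 0.0]                                           # a poor critic

print(kl, variational_objective(T_opt, p_exp, p_pi))   # the two values agree up to rounding
print(variational_objective(T_zero, p_exp, p_pi) <= kl)  # True: any T is a lower bound
```

In the adversarial algorithms above, $T_w$ plays the role of this table but must generalize across state-action pairs, and the outer minimization then changes $p^{\gamma}_{\pi}$ through the policy parameters.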

With different choices of the convex function $f$, we can obtain many existing IL algorithms, such as generative adversarial imitation learning (GAIL) [HE16b] and adversarial inverse RL (AIRL) [FLL18], as well as new algorithms like f-divergence Max-Ent IRL (f-MAX) and forward adversarial inverse RL (FAIRL) [GZG19; Ke+19b].

Finally, the algorithms above typically require running the learned policy $\pi$ to approximate the second expectation in Equation (35.82). In risk- or cost-sensitive scenarios, collecting more data is not always possible. Instead, we are in the off-policy IL setting, working with trajectories collected by some policy other than $\pi$. Hence, we need to correct the mismatch between $p^{\gamma}_{\pi}$ and the off-policy trajectory distribution, for which techniques from Section 35.5 can be used. An example is ValueDICE [KNT20], which uses a distribution correction method similar to that of DualDICE (Section 35.5.2).
36 Causality


This chapter was written by Victor Veitch and Alex D’Amour. 


36.1 Introduction 


The bulk of machine learning considers relationships between observed variables with the goal of summarizing these relationships in a manner that allows predictions on similar data. However, for many problems, our main interest is to predict how a system would change if it were observed under different conditions. For instance, in healthcare, we are interested in whether a patient will recover if given a certain treatment (as opposed to whether treatment and recovery are associated in the observed data). Causal inference addresses how to formalize such problems, determine whether they can be solved, and, if so, how to solve them. This chapter covers the fundamentals of this subject. Code examples for the discussed methods are available at https://github.com/vveitch/causality-tutorials. For more information on the connections between ML and causal inference, see e.g., [Kad+22; Xia+21a].

To make the gap between observed data modeling and causal inference concrete, consider the relationships depicted in Figure 36.1a and Figure 36.1b. Figure 36.1a shows the relationship between deaths by drowning and ice cream production in the United States in 1931 (the pattern holds across most years). Figure 36.1b shows the relationship between smoking and lung cancer across various countries. In each case, there is a strong positive association. Faced with this association, we might ask: could we reduce drowning deaths by banning ice cream? Could we reduce lung cancer by banning cigarettes? We intuitively understand that these interventional questions have different answers, despite the fact that the observed associations are similar. Determining the causal effect of an intervention in the world requires some causal hypothesis of this kind about the world.

For concreteness, consider three possible explanations for the association between ice cream and 
drowning. Perhaps eating ice cream does cause people to drown—due to stomach cramps or similar. 
Or, perhaps, drownings increase demand for ice cream—the survivors eat huge quantities of ice cream 
to handle their grief. Or, the association may be due (at least in part) to a common cause: warm 
weather makes people more likely to eat ice cream and more likely to go swimming (and, hence, to 
drown). Under all three scenarios, we can observe exactly the same data, but the implications for 
an ice cream ban are very different. Hence, answering questions about what will happen under an 
intervention requires us to incorporate some causal knowledge of the world—e.g., which of these 
scenarios is plausible? 

Figure 36.1: Correlation is not causation. (a) Ice cream production is strongly associated with deaths by drowning. Ice cream production data from the US Department of Agriculture National Agricultural Statistics Service. Drowning data from the National Center for Health Statistics at the United States Centers for Disease Control. (b) Smoking is strongly associated with lung cancer. From ourworldindata.org/smoking-big-problem-in-brief. Used with kind permission of Max Roser.

Our goal in this chapter is to introduce the essentials of estimating causal effects. The high-level approach has three steps.


• Causal Estimands: The first step is to formally define the quantities we want to estimate. These are summaries of how the world would change under intervention, rather than summaries of the world as it has already been observed. E.g., we want to formalize “The expected number of drownings in the United States if we ban ice cream”.


• Identification: The next step is to identify the causal estimands with quantities that can, in principle, be estimated from observational data. This step involves codifying our causal knowledge of the world and translating this into a statement such as, “The causal effect is equal to the expected number of drownings after adjusting for month”. This step tells us what causal questions we could answer with perfect knowledge of the observed data distribution.


• Estimation: Finally, we must estimate the observable quantity using a finite data sample. The form of causal estimands favors certain efficient estimation procedures that allow us to exploit non-parametric (e.g., machine learning) predictive models.


In this chapter, we'll mainly focus on the estimation of the causal effect of an intervention averaged over all members of a population, known as the Average Treatment Effect or ATE. This is the most common problem in applied causal inference work. It is in some sense the simplest problem, and will allow us to concretely explain the use and importance of the fundamental causal concepts. These causal concepts include structural causal models, causal graphical models, the do-calculus, and efficient estimation using influence function techniques. This problem is also useful for understanding the role that standard predictive modeling and machine learning play in estimating causal quantities.
36.2. CAUSAL FORMALISM


36.2 Causal Formalism 


In causal inference, the goal is to use data to learn about how the outcome in the world would change 
under intervention. In order to make such inferences, we must also make use of our causal knowledge 
of the world. This requires a formalism that lets us make the notion of intervention precise and lets 
us encode our causal knowledge as assumptions. 


36.2.1 Structural Causal Models 


Consider a setting in which we observe four variables from a population of people: $A_i$, an indicator of whether or not person $i$ smoked at a particular age; $Y_i$, an indicator of whether or not person $i$ developed lung cancer at a later age; $H_i$, a “health consciousness” index that measures a person’s health-consciousness (perhaps constructed from a set of survey responses about attitudes toward health); and $G_i$, an indicator for whether the person has a genetic predisposition towards cancer. Suppose we observe a dataset of these variables drawn independently and identically from a population, $(A_i, Y_i, H_i, G_i) \overset{\text{iid}}{\sim} P^{\text{obs}}$, where “obs” stands for “observed”.

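In standard practice, we model data like these using probabilistic models. Notably, there are many different ways to specify a probabilistic model for the same observed distribution. For example, we could write a probabilistic model for $P^{\text{obs}}$ as

$$\begin{aligned}
A &\sim P^{\text{obs}}(A) && (36.1)\\
H \mid A &\sim P^{\text{obs}}(H \mid A) && (36.2)\\
Y \mid A, H &\sim P^{\text{obs}}(Y \mid A, H) && (36.3)\\
G \mid A, H, Y &\sim P^{\text{obs}}(G \mid A, H, Y) && (36.4)
\end{aligned}$$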

This is a valid factorization, and sampling variables in this order would yield valid samples from the joint distribution $P^{\text{obs}}$. However, this factorization does not map well to a mechanistic understanding of how these variables are causally related in the world. In particular, it is perhaps more plausible that health-consciousness $H$ causally precedes smoking status $A$, since a person’s health-consciousness would influence their decision to smoke.

These intuitions about causal ordering are intimately tied to the notion of intervention. Here, 
we will focus on a notion of intervention that can be represented in terms of “structural” models 
that describe mechanistic relationships between variables. The fundamental objects that we will 
reason about are structural causal models, or SCM’s. SCM’s resemble probabilistic models, but 
they encode additional assumptions (see also Section 4.7). Specifically, SCM’s serve two purposes: 
they describe a probabilistic model and they provide semantics for transforming the data-generating 
process through intervention. 

Formally, SCM’s describe a mechanistic data generating process with an ordered sequence of equations that resemble assignment operations in a program. Each variable in a system is determined by combining other modeled variables (the causes) with exogenous “noise” according to some (unknown) deterministic function. For instance, a plausible SCM for $P^{\text{obs}}$ might be

3 

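$$\begin{aligned}
G &\leftarrow f_G(\xi_0) && (36.5)\\
H &\leftarrow f_H(\xi_1) && (36.6)\\
A &\leftarrow f_A(H, \xi_2) && (36.7)\\
Y &\leftarrow f_Y(G, H, A, \xi_3) && (36.8)
\end{aligned}$$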


where the (unknown) functions $f$ are fixed, and the variables $\xi$ are unmeasured causes, modeled as independent random “noise” variables. Conceptually, the functions $f_G, f_H, f_A, f_Y$ describe deterministic physical relationships in the real world, while the variables $\xi$ are hidden causes that are sufficient to distinguish each unit $i$ in the population. Because we assume that each observed unit $i$ is drawn at random from the population, we model $\xi$ as random noise.

SCM’s imply probabilistic models, but not the other way around. For example, our example SCM implies a probabilistic model for the observed data based on the factorization $P^{\text{obs}}(G, H, A, Y) = P^{\text{obs}}(G)\,P^{\text{obs}}(H)\,P^{\text{obs}}(A \mid H)\,P^{\text{obs}}(Y \mid G, H, A)$. Thus, we could sample from the SCM in the same way we would from a probabilistic model: draw a set of noise variables $\xi$ and evaluate each assignment operation in the SCM in order.

Beyond the probabilistic model, an SCM encodes additional assumptions about the effects of interventions. This can be formalized using the do-calculus (as in the verb “to do”), which we describe in Section 36.8. In brief, interventions are represented by replacing assignment statements. For example, if we were interested in the distribution of $Y$ in the hypothetical scenario that smoking were eliminated, we could replace the assignment for $A$ in the SCM (Equation (36.7)) with $A \leftarrow 0$. We would denote this by $P(Y \mid do(A = 0))$. Because the $f$ functions in the SCM are assumed to be invariant mechanistic relationships, the SCM encodes the assumption that this edited SCM generates the data that we would see if we really applied this intervention in the world. Thus, the ordering of statements in an SCM is load-bearing: it implies substantive assumptions about how the world changes in response to interventions. This is in contrast to more standard probabilistic models, where variables can be rearranged by applications of Bayes’ rule without changing the substantive implications of the model. (See also Section 4.7.3.)
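The two operations just described, ancestral sampling from an SCM and intervening by replacing one assignment, fit in a few lines of code. The structure below follows Equations (36.5)-(36.8), but the concrete functions (threshold rules on uniform noise) and all numbers are invented for illustration; they are not estimates from any data.

```python
import random

def sample_scm(rng, do_a=None):
    """Draw one unit from a hypothetical SCM with the structure of (36.5)-(36.8).

    If do_a is given, the assignment for A is replaced by the constant do_a.
    """
    xi = [rng.random() for _ in range(4)]     # independent exogenous noise xi_0..xi_3
    g = int(xi[0] < 0.2)                      # G <- f_G(xi_0)
    h = int(xi[1] < 0.5)                      # H <- f_H(xi_1)
    if do_a is None:
        a = int(xi[2] < (0.2 if h else 0.6))  # A <- f_A(H, xi_2): health-conscious people smoke less
    else:
        a = do_a                              # A <- do_a: the dependence on H is cut
    y = int(xi[3] < 0.05 + 0.2 * a + 0.1 * g - 0.03 * h)   # Y <- f_Y(G, H, A, xi_3)
    return g, h, a, y

def mean_y(rng, n, do_a=None):
    # Monte Carlo estimate of E[Y] (observational if do_a is None, else E[Y | do(A = do_a)])
    return sum(sample_scm(rng, do_a)[3] for _ in range(n)) / n

rng = random.Random(0)
# Exact value under this SCM: 0.05 + 0.1 * 0.2 - 0.03 * 0.5 = 0.055
print(mean_y(rng, 100_000, do_a=0))
```

Note that the intervention leaves $f_G$, $f_H$ and $f_Y$ untouched; only the line for $A$ changes, which is exactly the invariance assumption described above.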

We note that structural causal models may not capture all possible notions of causality. For example, laws based on conserved quantities or equilibria (e.g., the ideal gas law) do not trivially map to SCMs, though these are fundamental in disciplines such as physics and economics. Nonetheless, we will confine our discussion to SCMs.


36.2.2 Causal DAGs 


SCM’s encode many details about the assumed generative process of a system, but often it is useful to reason about causal problems at a higher level of abstraction. In particular, it is often useful to separate the causal structure of a problem from the particular functional form of those causal relationships. Causal graphs provide this level of abstraction. A causal graph specifies which variables causally affect other variables, but leaves the parametric form of the structural equations $f$ unspecified. Given an SCM, the corresponding causal graph can be drawn as follows: for each line of the SCM, draw arrows from the variables on the right hand side to the variable on the left hand side. The causal DAG for our smoking-cancer example is shown in Figure 36.2. In this way, causal DAGs are related to SCMs in the same way that probabilistic graphical models (PGMs) are related to probabilistic models.

Figure 36.2: (a) Causal graph illustrating relationships between smoking $A$, cancer $Y$, health consciousness $H$, and genetic cancer predisposition $G$. (b) “Mutilated” causal graph illustrating relationships under an intervention on smoking $A$.

In fact, in the same way that SCMs imply a probabilistic model, causal DAGs imply a PGM. Functionally, causal graphs behave as probabilistic graphical models (Chapter 4). They imply conditional independence relationships between the variables in the observed data in the same way. They obey the Markov property: if $X \rightarrow Y \rightarrow Z$ then $X \perp Z \mid Y$; recall d-separation (Section 4.2.4.1). Additionally, if $X \rightarrow Y \leftarrow Z$ then, usually, $X \not\perp Z \mid Y$ (even if $X$ and $Z$ are marginally independent). In this case, $Y$ is called a collider for $X$ and $Z$.
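The chain and collider statements can be checked by exact enumeration on binary variables. The sketch below uses hypothetical conditional probabilities for the chain and $Y = X \oplus Z$ (XOR) for the collider, with exact rational arithmetic so that independence is tested by equality rather than up to floating-point error.

```python
from fractions import Fraction
from itertools import product

HALF = Fraction(1, 2)
BITS = (0, 1)

def chain(x, y, z):
    # X -> Y -> Z with X ~ Bern(1/2); Y copies X and Z copies Y, each w.p. 3/4
    py = Fraction(3, 4) if y == x else Fraction(1, 4)
    pz = Fraction(3, 4) if z == y else Fraction(1, 4)
    return HALF * py * pz

def collider(x, y, z):
    # X -> Y <- Z with X, Z independent fair coins and Y = X XOR Z
    return HALF * HALF * (1 if y == (x ^ z) else 0)

def indep_given_y(joint):
    """Exact check that X is independent of Z given Y, for a pmf joint(x, y, z)."""
    for y in BITS:
        p_y = sum(joint(x, y, z) for x, z in product(BITS, repeat=2))
        for x, z in product(BITS, repeat=2):
            p_x = sum(joint(x, y, zz) for zz in BITS) / p_y
            p_z = sum(joint(xx, y, z) for xx in BITS) / p_y
            if joint(x, y, z) / p_y != p_x * p_z:
                return False
    return True

def marginally_indep(joint):
    """Exact check that X is independent of Z once Y is summed out."""
    for x, z in product(BITS, repeat=2):
        p_xz = sum(joint(x, y, z) for y in BITS)
        p_x = sum(joint(x, y, zz) for y, zz in product(BITS, repeat=2))
        p_z = sum(joint(xx, y, z) for xx, y in product(BITS, repeat=2))
        if p_xz != p_x * p_z:
            return False
    return True

print(indep_given_y(chain), marginally_indep(chain))        # True False
print(indep_given_y(collider), marginally_indep(collider))  # False True
```

The patterns are reversed: conditioning on the middle variable removes the dependence in a chain but creates one at a collider.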

Conceptually, the difference between causal DAGs and PGMs is that probabilistic graphical models 
encode our assumptions about statistical relationships, whereas causal graphs encode our (stronger) 
assumptions about causal relationships. Such causal relationships can be used to derive how statistical 
relationships would change under intervention. 

Causal graphs also allow us to reason about the causal and non-causal origins of statistical dependencies in observed data without specifying a full SCM. In a causal graph, two variables, say $A$ and $D$, can be statistically associated in different ways. First, there can be a directed path from (ancestor) $A$ to (descendant) $D$. In this case, $A$ is a causal ancestor of $D$ and interventions on $A$ will propagate through to change $D$; $P(D \mid do(A = a)) \neq P(D \mid do(A = a'))$. For example, smoking is a causal ancestor of cancer in our example. Alternatively, $A$ and $D$ could share a common cause: there is some variable $C$ such that there is a directed path from $C$ to $A$ and from $C$ to $D$. If $A$ and $D$ are associated only through such a path, then interventions on $A$ will not change the distribution of $D$. However, it is still the case that $P(D \mid A = a) \neq P(D \mid A = a')$: observing different values of $A$ changes our guess for the value of $D$. The reason is that $A$ carries information about $C$, which carries information about $D$. For example, if we lived in a world where smoking had no effect on developing cancer (e.g., everybody vapes), there would nevertheless be an association between smoking and cancer because of the path $A \leftarrow H \rightarrow Y$. The existence of such “backdoor paths” is one core reason that statistical and causal association are not the same. Of course, more complicated variants of these associations are possible (e.g., $C$ is itself only associated with $A$ through a backdoor path), but this already captures the key distinction between causal and non-causal paths.
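The backdoor-path argument can be checked exactly in a small world where smoking has no effect on cancer. In the sketch below (all probabilities hypothetical), health consciousness $H$ affects both $A$ and $Y$, but $Y$ does not depend on $A$. The observational contrast is nonzero, while the interventional one, computed by cutting the edge $H \rightarrow A$ so that $H$ keeps its marginal distribution, is exactly zero.

```python
from fractions import Fraction as F

# Hypothetical binary world with no effect of smoking A on cancer Y.
p_h = {1: F(1, 2), 0: F(1, 2)}            # P(H = h); H = 1 means "health conscious"
p_a1_given_h = {1: F(1, 10), 0: F(1, 2)}  # P(A = 1 | H = h)
p_y1_given_h = {1: F(1, 20), 0: F(1, 5)}  # P(Y = 1 | H = h); no dependence on A

def p_y1_given_a(a):
    # Observational P(Y=1 | A=a) = sum_h P(h | A=a) P(Y=1 | h), via Bayes' rule on H -> A
    w = {h: p_h[h] * (p_a1_given_h[h] if a == 1 else 1 - p_a1_given_h[h]) for h in (0, 1)}
    return sum(w[h] * p_y1_given_h[h] for h in (0, 1)) / sum(w.values())

def p_y1_do_a(a):
    # Interventional P(Y=1 | do(A=a)): cut H -> A, so H keeps its marginal P(h).
    # The argument a is unused because Y does not depend on A in this world.
    return sum(p_h[h] * p_y1_given_h[h] for h in (0, 1))

print(p_y1_given_a(1), p_y1_given_a(0))   # 7/40 29/280: an association
print(p_y1_do_a(1), p_y1_do_a(0))         # 1/8 1/8: no causal effect
```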

Recall that our aim in introducing SCMs and causal graphs is to enable us to formalize our causal 
knowledge of the world and to make precise what interventional quantities we’d like to estimate. 
Writing down a causal graph gives a simple formal way to encode our knowledge of the causal structure of a problem. Usefully, this causal structure is sufficient to directly reason about the implications of interventions without fully specifying the underlying SCM. The key observation is that if a variable $A$ is intervened on then, after intervention, none of the other variables are causes of $A$. That is, when we replace a line of an SCM with a statement directly assigning a variable a particular value, we cut off all dependencies that variable had on its causal parents. Accordingly, in the causal graph, the intervened-on variable has no parents. This leads us to the graph surgery notion of intervention: an intervention that sets $A$ to $a$ is the operation that deletes all incoming edges to $A$ in the graph, and then conditions on $A = a$ in the resulting probability distribution (which is defined by the conditional independence structure of the post-surgery graph). We’ll use Pearl’s do notation to denote this operation. $P(X \mid do(A = a))$ is the distribution of $X$ given $A = a$ under the mutilated graph that results from deleting all edges going into $A$. Similarly, $\mathbb{E}[X \mid do(A = a)] \triangleq \mathbb{E}_{P(X \mid do(A = a))}[X]$. Thus, we can formalize statements such as “The average effect of receiving drug A” as

$$\text{ATE} = \mathbb{E}[Y \mid do(A = 1)] - \mathbb{E}[Y \mid do(A = 0)], \tag{36.9}$$

where ATE stands for Average Treatment Effect.
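Graph surgery is a purely structural operation, so it can be sketched on a DAG stored as a map from each node to its set of parents. The representation and function names below are illustrative assumptions; the DAG is the one implied by the SCM in Equations (36.5)-(36.8).

```python
def mutilate(parents, node):
    """Graph surgery: return a copy of the DAG with every edge into `node` deleted."""
    return {v: (set() if v == node else set(ps)) for v, ps in parents.items()}

def edges(parents):
    # List the DAG's edges as sorted (parent, child) pairs
    return sorted((p, v) for v, ps in parents.items() for p in ps)

# DAG implied by Equations (36.5)-(36.8): each node maps to its set of causal parents.
dag = {"G": set(), "H": set(), "A": {"H"}, "Y": {"G", "H", "A"}}
mutilated = mutilate(dag, "A")

print(edges(dag))        # [('A', 'Y'), ('G', 'Y'), ('H', 'A'), ('H', 'Y')]
print(edges(mutilated))  # [('A', 'Y'), ('G', 'Y'), ('H', 'Y')]
```

Only the edge $H \rightarrow A$ disappears; the edges out of $A$ are kept, since the intervention still propagates to its descendants.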

For concreteness, consider our running example. We contrast the distribution that results from conditioning on $A$ with the distribution that results from intervening on $A$:

$$\begin{aligned}
P(Y, H, G \mid A = a) &= P(Y \mid H, G, A = a)\,P(G)\,P(H \mid A = a) && (36.10)\\
P(Y, H, G \mid do(A = a)) &= P(Y \mid H, G, A = a)\,P(G)\,P(H) && (36.11)
\end{aligned}$$


The key difference between these two distributions is that the standard conditional distribution describes a population where health consciousness $H$ has the distribution that we observe among individuals with smoking status $A = a$, while the interventional distribution describes a population where health consciousness $H$ follows the marginal distribution among all individuals. For example, we would expect $P(H \mid A = \text{smoker})$ to put more mass on lower values of $H$ than the marginal health consciousness distribution $P(H)$, which would also include non-smokers. The intervention distribution thus incorporates a hypothesis of how smoking would affect the subpopulation of individuals who tend to be too health conscious to smoke in the observed data.
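Equations (36.10) and (36.11) differ only in the weight placed on $H$. The sketch below evaluates both exactly for a hypothetical binary version of the running example (all numbers invented; health-conscious people are assumed to smoke less and to have slightly lower cancer risk), and compares $P(H = 1 \mid A = 1)$ with $P(H = 1)$, and $\mathbb{E}[Y \mid A = 1]$ with $\mathbb{E}[Y \mid do(A = 1)]$.

```python
from fractions import Fraction as F
from itertools import product

p_g1 = F(1, 5)                              # P(G = 1)
p_h1 = F(1, 2)                              # P(H = 1)
p_a1_given_h = {0: F(1, 2), 1: F(1, 10)}    # P(A = 1 | H = h)

def p_y1(g, h, a):
    # P(Y = 1 | G = g, H = h, A = a)
    return F(3, 20) + F(a, 5) + F(g, 10) - F(h, 10)

def bern(p1, v):
    # P(V = v) for a Bernoulli variable with P(V = 1) = p1
    return p1 if v == 1 else 1 - p1

def p_h_given_a(h, a):
    # Bayes' rule on H -> A (G is independent of H and A)
    num = bern(p_h1, h) * bern(p_a1_given_h[h], a)
    den = sum(bern(p_h1, hh) * bern(p_a1_given_h[hh], a) for hh in (0, 1))
    return num / den

def mean_y_conditional(a):
    # Equation (36.10): H weighted by P(H | A = a)
    return sum(p_y1(g, h, a) * bern(p_g1, g) * p_h_given_a(h, a)
               for g, h in product((0, 1), repeat=2))

def mean_y_do(a):
    # Equation (36.11): H weighted by its marginal P(H)
    return sum(p_y1(g, h, a) * bern(p_g1, g) * bern(p_h1, h)
               for g, h in product((0, 1), repeat=2))

print(p_h_given_a(1, 1), bern(p_h1, 1))      # 1/6 1/2
print(mean_y_conditional(1), mean_y_do(1))   # 53/150 8/25
```

Because smokers in this toy model are less health conscious than the population as a whole, conditioning on $A = 1$ overstates the cancer risk that universal smoking would produce.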


36.2.3 Identification


A central challenge in causal inference is that many different SCM’s can produce identical distributions of observed data. This means that, on the basis of observed data alone, we cannot uniquely identify the SCM that generated it. This is true no matter how large a data sample is available to us.

For example, consider the setting where there is a treatment $A$ that may or may not have an effect on outcome $Y$, and where both the treatment and outcome are known to be affected by some unobserved common binary cause $U$. Now, we might be interested in the causal estimand $\mathbb{E}[Y \mid do(A = 1)]$. In general, we can’t learn this quantity from the observed data. The problem is that we can’t tell apart the case where the treatment has a strong effect from the case where the treatment has no effect, but $U = 1$ both causes people to tend to be treated and increases the probability of a positive outcome. The same observation shows we can’t learn the (more complicated) interventional distribution $P(Y \mid do(A = 1))$ (if we could learn this, then we’d get the average effect automatically).
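A minimal pair of SCMs makes this example concrete. Both are hypothetical and deliberately extreme: $U \sim \text{Bern}(1/2)$ and $A \leftarrow U$ in each, but in one $Y \leftarrow U$ (no treatment effect) and in the other $Y \leftarrow A$. They induce the same observational distribution over $(A, Y)$, yet disagree about $\mathbb{E}[Y \mid do(A = 1)]$.

```python
from fractions import Fraction as F

# Two hypothetical SCMs sharing an unobserved binary cause U ~ Bern(1/2).
def scm_no_effect(u, do_a=None):
    a = u if do_a is None else do_a
    y = u                     # Y is driven by U only
    return a, y

def scm_full_effect(u, do_a=None):
    a = u if do_a is None else do_a
    y = a                     # Y is driven by A only
    return a, y

def observational_joint(scm):
    # Exact observational distribution over (A, Y), marginalizing U
    joint = {}
    for u in (0, 1):
        key = scm(u)
        joint[key] = joint.get(key, 0) + F(1, 2)
    return joint

def mean_y_do(scm, a):
    # E[Y | do(A = a)], averaging over U
    return sum(F(1, 2) * scm(u, do_a=a)[1] for u in (0, 1))

print(observational_joint(scm_no_effect) == observational_joint(scm_full_effect))  # True
print(mean_y_do(scm_no_effect, 1), mean_y_do(scm_full_effect, 1))                # 1/2 1
```

No amount of $(A, Y)$ data can distinguish the two, which is what non-identifiability means here.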



Thus, an important part of causal inference is to augment the observed data with knowledge about 
the underlying causal structure of the process under consideration. Often, these assumptions can 
narrow the space of SCM’s sufficiently so that there is only one value of the causal estimand that is 
compatible with the observed data. We say that the causal estimand is identified or identifiable 
under a given set of assumptions if those assumptions are sufficient to provide a unique answer. 
There are many different sets of sufficient conditions that yield identifiable causal effects; we call 
each set of sufficient conditions an identification strategy. 

Given a set of assumptions about the underlying SCM, the most common way to show that a causal estimand is identified is by construction. Specifically, if the causal estimand can be written entirely in terms of observable probability distributions, then it is identified. We call such a function of observed distributions a statistical estimand. Once such a statistical estimand has been recovered, we can then construct and analyze an estimator for that quantity using standard statistical tools. As an example of a statistical estimand, in the SCM above, it can be shown that the ATE, as defined in Equation (36.9), is equal to the following statistical estimand:


$$\text{ATE} \overset{(*)}{=} \tau^{\text{ATE}} \triangleq \mathbb{E}\big[\mathbb{E}[Y \mid H, A = 1] - \mathbb{E}[Y \mid H, A = 0]\big], \tag{36.12}$$


where the equality (*) only holds because of some specific properties of the SCM. Note that the RHS above only involves conditional expectations between observed variables (there are no do operators), so $\tau^{\text{ATE}}$ is only a function of observable probability distributions.
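Once Equation (36.12) is available, estimation is ordinary statistics. The sketch below simulates data from a hypothetical binary SCM with the structure of Equations (36.5)-(36.8) (true ATE of 0.2 by construction) and compares the naive difference in means with a simple plug-in estimate of $\tau^{\text{ATE}}$ that averages within-stratum differences over the empirical distribution of $H$. It assumes every $(H, A)$ cell is non-empty (positivity); this stratified estimator is only an illustration, not the efficient estimators the chapter builds toward.

```python
import random
from collections import defaultdict

def simulate(n, seed=0):
    # Hypothetical SCM with the structure of Equations (36.5)-(36.8); true ATE = 0.2
    rng = random.Random(seed)
    data = []
    for _ in range(n):
        g = int(rng.random() < 0.2)
        h = int(rng.random() < 0.5)
        a = int(rng.random() < (0.1 if h else 0.5))
        y = int(rng.random() < 0.15 + 0.2 * a + 0.1 * g - 0.1 * h)
        data.append((h, a, y))        # G is not needed for the adjustment
    return data

def naive_difference(data):
    # E[Y | A = 1] - E[Y | A = 0]: ignores the backdoor path A <- H -> Y
    def mean_y(a):
        ys = [y for _, aa, y in data if aa == a]
        return sum(ys) / len(ys)
    return mean_y(1) - mean_y(0)

def adjusted_ate(data):
    # Plug-in version of Equation (36.12): average over P(H) of E[Y|H,A=1] - E[Y|H,A=0]
    sums, counts, h_counts = defaultdict(float), defaultdict(int), defaultdict(int)
    for h, a, y in data:
        sums[h, a] += y
        counts[h, a] += 1
        h_counts[h] += 1
    n = len(data)
    return sum((h_counts[h] / n) * (sums[h, 1] / counts[h, 1] - sums[h, 0] / counts[h, 0])
               for h in h_counts)

data = simulate(200_000)
# Under this SCM the naive contrast is about 0.248 in expectation; the adjusted one is about 0.2.
print(naive_difference(data), adjusted_ate(data))
```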

There are many kinds of assumptions we might make about the SCM governing the process under consideration. For example, the following are assertions we might make about the system in our running example:


1. The probability of developing cancer is additive on the logit scale in $A$, $G$, and $H$ (i.e., logistic regression is a well-specified model).


2. For each individual, smoking can never decrease the probability of developing cancer. 


3. Whether someone smokes is influenced by their health consciousness H, but not by their genetic 
predisposition to cancer G. 


These assumptions range from strong parametric assumptions fully specifying the form of the SCM 
equations, to non-parametric assumptions that only specify what the inputs to each equation are, 
leaving the form fully unspecified. Typically, assumptions that fully specify the parametric form are 
very strong, and would require far more detailed knowledge of the system under consideration than 
we actually have. The goal in identification arguments is to find a set of assumptions that are weak 
enough that they might be plausibly true for the system under consideration, but which are also 
strong enough to allow for identification of the causal effect. 

If we are not willing to make any assumptions about the functional form of the SCM, then our 
assumptions are just about which variables affect (and do not affect) the other variables. In this sense, 
such which-affects-which assumptions are minimal. These assumptions are exactly the assumptions 
captured by writing down a (possibly incomplete) causal DAG, showing which variables are parents 
of each other variable. The graph may be incomplete because we may not know whether each possible 
edge is present in the physical system. For example, we might be unsure whether the gene G actually 
has a causal effect on health consciousness H. It is natural to ask to what extent we can identify 
causal effects only on the basis of partially specified causal DAGs. It turns out much progress can be
made based on such non-parametric assumptions; we discuss this in detail in Section 36.8. 

We will also discuss certain assumptions that cannot be encoded in a causal graph, but that are 
still weaker than assuming that full functional forms are known. For example, we might assume that 
the outcome is affected additively by the treatment and any confounders, with no interaction terms 
between them. These weaker assumptions can enable causal identification even when assuming the 
causal graph alone does not. 

It is worth emphasizing that every causal identification strategy relies on assumptions that have some content that cannot be validated in the observed data. This follows directly from the ill-posedness of causal problems: if the assumptions used to identify causal quantities could be validated, that would imply that the causal estimand was identifiable from the observed data alone. However, since we know that there are many values of the causal estimand that are compatible with observed data, it follows that the assumptions in our identification strategy must have unobservable implications.
Structural causal models let us formalize and study a hierarchy of different kinds of query about the system under consideration. The most familiar is observational queries: questions that are purely about statistical associations (e.g., “Are smoking and lung cancer associated in the population this sample was drawn from?”). Next is interventional queries: questions about causal relationships at the population level (e.g., “How much does smoking increase the probability of cancer in a given population?”). The rest of this chapter is focused on the definition, identification, and estimation of interventional queries. Finally, there are counterfactual queries: questions about causal relationships at the level of specific individuals, had something been different (e.g., “Would Alice have developed cancer had she not smoked?”). This causal hierarchy was popularized by [Pea09a, Ch. 1].

Interventional queries concern the prospective effect of an intervention on an outcome; for example, if we intervene and prevent a randomly sampled individual from smoking, what is the probability they develop lung cancer? Ultimately, the probability statement here is about our uncertainty about the “noise” variables $\xi$ in the SCM. These are the unmeasured factors specific to the randomly selected individual. The distribution is determined by the population from which that individual is sampled. Thus, interventional queries are statements about populations. Interventional queries can be written in terms of conditional distributions using do-notation, e.g., $P(Y \mid do(A = 0))$. In our example, this represents the distribution of lung cancer outcomes for an individual selected at random and prevented from smoking.

Counterfactual queries concern how an observed outcome might have been different had an intervention been applied in the past. Counterfactual queries are often framed in terms of attributing a given outcome to a particular cause. For example, would Alice have developed cancer had she not smoked? Did most smokers with lung cancer develop cancer because they smoked? Counterfactual queries are so called because they require a comparison of counterfactual outcomes within individuals. In the formalism of SCMs, counterfactual outcomes for an individual i are generated by running the same values of $\epsilon_i$ through differently intervened SCMs. Counterfactual outcomes are often written in terms of potential outcomes notation. In our running smoking example, this would look like:


$$ Y_i(a) = f_Y(G_i, H_i, a, \epsilon_{Y,i}). \qquad (36.13) $$

That is, $Y_i(a)$ is the outcome we would have seen had A been set to a while all of $G_i$, $H_i$, $\epsilon_{Y,i}$ were kept fixed.

It is important to understand what distinguishes interventional queries from fundamentally counterfactual ones. Just because a query can be written in terms of potential outcomes does not make it a counterfactual query. For example, the average treatment effect, which is the canonical interventional query, is easy to write in potential outcomes notation:


$$ \mathrm{ATE} = \mathbb{E}[Y_i(1) - Y_i(0)]. \qquad (36.14) $$


Instead, the key dividing line between counterfactual and interventional queries is whether the query 
requires knowing the joint distribution of potential outcomes within individuals, or whether marginal 
distributions of potential outcomes across individuals will suffice. An important signature of a 
counterfactual query is conditioning on the value of one potential outcome. For example, “the lung 
cancer rate among smokers who developed cancer, had they not smoked” is a counterfactual query, 
and can be written as: 


$$ \mathbb{E}[Y_i(0) \mid Y_i(1) = 1, A_i = 1]. \qquad (36.15) $$


Answering this query requires knowing how individual-level cancer outcomes are related (through $\epsilon_{Y,i}$) across the worlds where each individual i did and did not smoke. Notably, this query cannot be rewritten using do-notation, because it requires distinguishing $Y_i(0)$ from $Y_i(1)$ within the same individual. By contrast, the ATE can: $\mathbb{E}[Y \mid do(A = 1)] - \mathbb{E}[Y \mid do(A = 0)]$.
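To make the dividing line concrete, here is a small exact calculation. It is a sketch under assumptions we chose for illustration: outcomes are binary, treatment is assigned independently of $(Y_i(0), Y_i(1))$ (so the condition $A_i = 1$ in (36.15) can be dropped), and the two joint distributions below are hypothetical, built to share the same marginals. The ATE only sees the marginals, so it agrees across the two; the cross-world query does not.

```python
from fractions import Fraction as F

# Two hypothetical joint laws over (Y(0), Y(1)); keys are (y0, y1).
# Both have P(Y(0) = 1) = 1/5 and P(Y(1) = 1) = 2/5.
independent = {(0, 0): F(12, 25), (0, 1): F(8, 25),
               (1, 0): F(3, 25), (1, 1): F(2, 25)}
nested = {(0, 0): F(3, 5), (0, 1): F(1, 5),     # Y(0) = 1 only if Y(1) = 1
          (1, 0): F(0), (1, 1): F(1, 5)}

def p_one(joint, arm):
    """Marginal P(Y(arm) = 1); arm 0 or 1 picks the potential outcome."""
    return sum(p for ys, p in joint.items() if ys[arm] == 1)

def ate(joint):
    """Interventional query: depends only on the two marginals."""
    return p_one(joint, 1) - p_one(joint, 0)

def cross_world(joint):
    """Counterfactual query E[Y(0) | Y(1) = 1]: needs the joint law."""
    return joint[(1, 1)] / p_one(joint, 1)

for name, joint in [("independent", independent), ("nested", nested)]:
    print(name, ate(joint), cross_world(joint))
# independent 1/5 1/5
# nested 1/5 1/2
```

No amount of data on $Y_i(0)$ and $Y_i(1)$ separately can tell these two worlds apart, which is why the counterfactual query needs extra assumptions.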

Counterfactual queries require categorically more assumptions for identification than interventional ones. For identifying interventional queries, knowing the DAG structure of an SCM is often sufficient, while for counterfactual queries, some assumptions about the functional forms in the SCM are necessary. This is because only one potential outcome is ever observed for each individual, so the dependence between potential outcomes within individuals is not observable. For example, the data in our running example provide no information on how individual-level smoking and non-smoking cancer risk are related. Thus, answering a question like “Did smokers who developed cancer have lower non-smoking cancer risk than smokers who did not develop cancer?” requires additional assumptions about how characteristics encoded in $\epsilon_i$ are translated to cancer outcomes. To answer this question without such assumptions, we would need to observe smokers who developed cancer in the alternate world where they did not smoke. Because they compare how individuals would have turned out under different generating processes, counterfactual queries are often referred to as “cross-world” quantities. On the other hand, interventional queries only require understanding the marginal distributions of potential outcomes $Y_i(0)$ and $Y_i(1)$ across individuals; thus, no cross-world information is necessary at the individual level.

We conclude this section by noting that counterfactual outcomes and potential outcomes notation are often conceptually useful, even if they are not used to explicitly answer counterfactual queries. Many causal queries are more intuitive to formalize in terms of potential outcomes. E.g., “Would I have smoked if I was more health conscious?” may be more intuitive than “Would a randomly sampled individual from the same population have smoked had they been subject to an intervention that made them more health conscious?”. In fact, some schools of causal inference use potential outcomes, rather than DAGs, as their primary conceptual building block [see IR15]. Causal graphs and potential outcomes both provide ways to formulate interventional queries and causal assumptions. Ultimately, these are mathematically equivalent. Nevertheless, practically, they have different strengths. The main advantage of potential outcomes is that counterfactual statements often map more directly to our mechanistic understanding of the world. This can make it easier to articulate causal desiderata and causal assumptions we may wish to use. On the other hand, the potential outcomes notation does not automatically distinguish between interventional and counterfactual queries. Additionally, causal graphs often give an intuitive and easy way of articulating assumptions about structural causal models involving many variables, whereas potential outcomes quickly get unwieldy. In short: both formalizations have distinct advantages, and those advantages are simply about how easy it is to translate our causal understanding of the world into crisp mathematical assumptions.

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license
We now turn to the business of estimating causal effects from data. We begin with randomized control trials, which are experiments designed to make the causal concerns as simple as possible. The simplest situation for causal estimation is when there are no common causes of A and Y. The world is rarely so obliging as to make this the case. However, sometimes we can design an experiment to enforce the no-common-causes structure. In randomized control trials we assign each participant to either the treatment or control group at random. Because random assignment does not depend on any property of the units in the study, there are no causes of treatment assignment, and hence also no common causes of Y and A.

In this case, it’s straightforward to see that P(Y | do(A = a)) = P(Y | A = a). This is essentially by definition of the graph surgery: since A has no parents, the mutilated graph is the same as the original graph. Indeed, the graph surgery definition is chosen to make this true: any sensible formalization of causality should have this identification result.


It is common to use RCTs to study the average treatment effect,

$$ \mathrm{ATE} = \mathbb{E}[Y \mid do(A = 1)] - \mathbb{E}[Y \mid do(A = 0)]. \qquad (36.16) $$

This is the expected difference between being assigned treatment and assigned no treatment for a randomly chosen member of the population. It’s easy to see that in an RCT this causal quantity is identified as a parameter $\tau^{RCT}$ of the observational distribution:

$$ \tau^{RCT} = \mathbb{E}[Y \mid A = 1] - \mathbb{E}[Y \mid A = 0]. $$


Then, a natural estimator is:¹

$$ \hat{\tau}^{RCT} = \frac{1}{n_A} \sum_{i: A_i = 1} Y_i - \frac{1}{n - n_A} \sum_{i: A_i = 0} Y_i, \qquad (36.17) $$

where $n_A$ is the number of units who received treatment. That is, we estimate the average treatment effect as the difference between the average outcome of the treated group and the average outcome of the untreated (control) group.
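A minimal sketch of estimator (36.17) on simulated data. The data-generating process is our own assumption for illustration: assignment is a fair coin flip, a covariate x shifts the outcome but not the assignment, and the true effect is 1.0. Because assignment ignores x, the plain difference in means recovers the effect.

```python
import random
import statistics

def simulate_rct(n, effect=1.0, seed=0):
    """Hypothetical RCT: treatment is a fair coin, independent of the covariate x."""
    rng = random.Random(seed)
    data = []
    for _ in range(n):
        x = int(rng.random() < 0.5)        # pre-treatment covariate (affects Y only)
        a = int(rng.random() < 0.5)        # random assignment ignores x
        y = effect * a + 2.0 * x + rng.gauss(0.0, 1.0)
        data.append((a, y))
    return data

def difference_in_means(data):
    """Estimator (36.17): mean outcome of treated minus mean outcome of controls."""
    treated = [y for a, y in data if a == 1]
    control = [y for a, y in data if a == 0]
    return statistics.fmean(treated) - statistics.fmean(control)

tau_hat = difference_in_means(simulate_rct(20_000))
print(round(tau_hat, 2))
```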

Randomized control trials are the gold standard for estimating causal effects. This is because we know by design that there are no confounders that can produce alternative causal explanations of the data. In particular, the assumption of the triangle DAG (there are no unobserved confounders) is enforced by design. However, there are limitations. Most obviously, randomized control trials are sometimes infeasible to conduct. This could be due to expense, regulatory restrictions, or more fundamental difficulties (e.g., in developmental economics, the response of interest is sometimes collected decades after treatment). Additionally, it may be difficult to ensure that the participants in an RCT are representative of the population where the treatment will be deployed. For instance, participants in drug trials may skew younger and poorer than the population of patients who will ultimately take the drug.

1. There is a literature on efficient estimation of causal effects in RCTs going back to Fisher [Fis25] that employs more sophisticated estimators. See also Lin [Lin13a] and Bloniarz et al. [Blo+16] for more modern treatments.

Figure 36.3: A causal DAG illustrating a situation where treatment A and outcome Y are both influenced by observed confounders X.


36.4 Confounder Adjustment 


We now turn to the problem of estimating causal effects using observational (i.e., not experimental) 
data. The most common application of causal inference is estimating the average treatment effect 
(ATE) of an intervention. The ATE is also commonly called the average causal effect, or ACE. 
Here, we focus on the important special case where the treatment A is binary, and we observe the 
outcome Y as well as a set of common causes X that influence both A and Y. 


36.4.1 Causal Estimand, Statistical Estimand, and Identification 


Consider a problem where we observe treatment A, outcome Y, and covariates X, which are drawn 
i.i.d. from some unknown distribution P. We wish to learn the average treatment effect: the expected 
difference between being assigned treatment and assigned no treatment for a randomly chosen member 
of the population. Following the discussion in the introduction, there are three steps to learning this 
quantity: mathematically formalize the causal estimand, give conditions for the causal estimand to 
be identified as a statistical estimand, and, finally, estimate this statistical estimand from data. We 
now turn to the first two steps. 

The average treatment effect is defined to be the difference between the average outcome if we intervened and set A to be 1, versus the average outcome if we intervened and set A to be 0. Using the do notation, we can write this formally as


$$ \mathrm{ATE} = \mathbb{E}[Y \mid do(A = 1)] - \mathbb{E}[Y \mid do(A = 0)]. \qquad (36.18) $$
The next step is to articulate sufficient conditions for the ATE to be identified as a statistical estimand (a parameter of distribution P). The key issue is the possible presence of confounders. Confounders are “common cause” variables that affect both the treatment and outcome. When there are confounding variables in observed data, the sub-population of people who are observed to have received one level of the treatment A will differ from the rest of the population in ways that are relevant to their observed Y. For example, there is a strong positive association between horseback riding in childhood (treatment) and healthiness as an adult (outcome) [RB16]. However, both of these quantities are influenced by wealth X. The population of people who rode horses as children (A = 1) is wealthier than the population of people who did not. Accordingly, the horseback-riding population will have better health outcomes even if there is no actual causal benefit of horseback riding for adult health.

We'll express the assumptions required for causal identification in the form of a causal DAG. Namely, we consider the simple triangle DAG in Figure 36.3, where the treatment and outcome are influenced by observed confounders X. It turns out that the assumption encoded by this DAG suffices for identification. To understand why this is so, recall that the target causal effect is defined according to the distribution we would see if the edge from X to A was removed (that’s the meaning of do). The key insight is that because the intervention only modifies the relationship between X and A, the structural equation that generates outcomes Y given X and A, illustrated in Figure 36.3 as A → Y ← X, is the same even after the X → A edge is removed. For example, we might believe that the physiological processes by which smoking status A and confounders X produce lung cancer Y remain the same, regardless of how the decision to smoke or not smoke was made. Secondly, because the intervention does not change the composition of the population, we would also expect the distribution of background characteristics X to be the same between the observational and intervened processes.

With these insights about invariances between observed and interventional data, we can derive a statistical estimand for the ATE as follows.
Theorem 2 (Adjustment with No Unobserved Confounders). We observe A, Y, X ∼ P. Suppose that

1. (Confounders observed) The data obeys the causal structure in Figure 36.3. In particular, X contains all common causes of A and Y and no variable in X is caused by A or Y.

2. (Overlap) 0 < P(A = 1 | X = x) < 1 for all values of x. That is, there are no individuals for whom treatment is always or never assigned.

Then, the average treatment effect is identified as ATE = τ, where

$$ \tau = \mathbb{E}\big[\mathbb{E}[Y \mid A = 1, X]\big] - \mathbb{E}\big[\mathbb{E}[Y \mid A = 0, X]\big]. \qquad (36.19) $$


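Proof. First, we expand the ATE using the tower property of expectation, conditioning on X. Then, we apply the invariances discussed above:

$$
\begin{aligned}
\mathrm{ATE} &= \mathbb{E}[Y \mid do(A = 1)] - \mathbb{E}[Y \mid do(A = 0)] && (36.20) \\
&= \mathbb{E}\big[\mathbb{E}[Y \mid do(A = 1), X]\big] - \mathbb{E}\big[\mathbb{E}[Y \mid do(A = 0), X]\big] && (36.21) \\
&= \mathbb{E}\big[\mathbb{E}[Y \mid A = 1, X]\big] - \mathbb{E}\big[\mathbb{E}[Y \mid A = 0, X]\big] && (36.22)
\end{aligned}
$$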
The final equality is the key to passing from a causal to observational quantity. This follows because, from the causal graph, the conditional distribution of Y given A, X is the same in both the original graph and in the mutilated graph created by removing the edge from X to A. This mutilated graph defines P(Y | do(A = 1), X), so the equality holds.

The condition that 0 < P(A = 1 | X = x) < 1 is required for the first equality (the tower property) to be well defined.


Note that Equation (36.19) is a function of only conditional expectations and distributions that 
appear in the observed data distribution (in particular, it contains no “do” operators). Thus, if we 
can fully characterize the observed data distribution P, we can map that distribution to a unique 
ATE. 

It is useful to note how τ differs from the naive estimand E[Y | A = 1] − E[Y | A = 0] that just reports the treatment-outcome association without adjusting for confounding. The comparison is especially clear when we write out the outer expectation in τ explicitly as an integral over X:

$$ \tau = \int \mathbb{E}[Y \mid A = 1, X]\, P(X)\, dX - \int \mathbb{E}[Y \mid A = 0, X]\, P(X)\, dX. \qquad (36.23) $$


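We can write the naive estimand in a similar form by applying the tower property of expectation:

$$ \mathbb{E}[Y \mid A = 1] - \mathbb{E}[Y \mid A = 0] = \int \mathbb{E}[Y \mid A = 1, X]\, P(X \mid A = 1)\, dX - \int \mathbb{E}[Y \mid A = 0, X]\, P(X \mid A = 0)\, dX. \qquad (36.24) $$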


The key difference is the probability distribution over X that is being integrated over. The observational difference in means integrates over the distinct conditional distributions of confounders X, depending on the value of A. On the other hand, in the ATE estimand τ, we integrate over the same distribution P(X) for both levels of the treatment.
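The contrast between (36.23) and (36.24) can be checked exactly for a discrete confounder. The sketch below assumes a made-up population in the spirit of the horseback-riding example: a binary wealth indicator X with P(X = 1) = 0.5, a propensity of 0.2 or 0.8, and E[Y | A, X] = A + 2X, so the true effect is 1 in every stratum. Only the weighting over X differs between the two estimands.

```python
# Hypothetical population with one binary confounder X (e.g., wealth).
P_X = {0: 0.5, 1: 0.5}    # P(X = x)
G = {0: 0.2, 1: 0.8}      # propensity P(A = 1 | X = x): wealthier units ride more

def q(a, x):
    """Assumed E[Y | A = a, X = x]; the effect of A is 1 within every stratum."""
    return 1.0 * a + 2.0 * x

def p_x_given_a(x, a):
    """P(X = x | A = a) via Bayes' rule."""
    def p_a(xx):
        return G[xx] if a == 1 else 1.0 - G[xx]
    return p_a(x) * P_X[x] / sum(p_a(xx) * P_X[xx] for xx in P_X)

# Adjusted estimand (36.23): both arms integrate over the same P(X).
tau = sum((q(1, x) - q(0, x)) * P_X[x] for x in P_X)
# Naive contrast (36.24): each arm integrates over its own P(X | A = a).
naive = (sum(q(1, x) * p_x_given_a(x, 1) for x in P_X)
         - sum(q(0, x) * p_x_given_a(x, 0) for x in P_X))
print(f"tau = {tau:.2f}, naive = {naive:.2f}")  # tau = 1.00, naive = 2.20
```

The naive contrast overstates the effect because the treated arm is dominated by the wealthy stratum, exactly the mechanism described in the text.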


Overlap In addition to the assumption on the causal structure, identification requires that there is 
sufficient random variation in how treatments are assigned. 


Definition 1. A distribution P on A, X satisfies overlap if 0 < P(A = 1 | x) < 1 for all x. It satisfies strict overlap if ε < P(A = 1 | x) < 1 − ε for all x and some ε > 0.


Overlap is the requirement that any unit could have either received the treatment or not.

To see the necessity of overlap, consider estimating the effectiveness of a drug in a study where patient sex is a confounder, but the drug was only ever prescribed to male patients. Then, conditional on a patient being female, we would know that patient was assigned to control. Without further assumptions, it’s impossible to know the effect of the drug on a population with female patients, because there would be no data to inform the expected outcome for treated female patients, that is, E[Y | A = 1, X = female]. In this case, the statistical estimand in Equation (36.19) would not be identifiable. In the same vein, strict overlap ensures that the conditional distributions at each stratum of X can be estimated in finite samples.
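For a discrete X, a simple diagnostic is to estimate P(A = 1 | x) by the treated fraction in each stratum and flag strata outside (ε, 1 − ε). The sketch below is a hypothetical helper, not a standard library routine; the data reproduce the male-only prescription example with invented counts, and ε = 0.05 is an arbitrary choice. With many covariates one would instead fit a propensity model and inspect its predictions, which this sketch does not do.

```python
from collections import defaultdict

def overlap_report(data, eps=0.05):
    """Empirical propensity per stratum of a discrete X, flagging strata that
    violate strict overlap eps < P(A = 1 | x) < 1 - eps."""
    counts = defaultdict(lambda: [0, 0])     # x -> [number treated, number total]
    for a, x in data:
        counts[x][0] += a
        counts[x][1] += 1
    report = {}
    for x, (n_treated, n_total) in counts.items():
        p_hat = n_treated / n_total
        report[x] = (p_hat, not (eps < p_hat < 1 - eps))
    return report

# Hypothetical data (a, x): the drug is only ever prescribed to male patients.
data = [(1, "male")] * 40 + [(0, "male")] * 60 + [(0, "female")] * 100
report = overlap_report(data)
print(report)  # {'male': (0.4, False), 'female': (0.0, True)}
```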

Overlap can be particularly limiting in settings where we are adjusting for a large number of 
covariates (in an effort to satisfy no unobserved confounding). Then, certain combinations of traits 
may be very highly predictive of treatment assignment, even if individual traits are not. E.g., male 
patients over age 70 with BMI greater than 25 are very rarely assigned the drug. If such groups
represent a significant fraction of the target population, or have significantly different treatment 
effects, then this issue can be problematic. In this case, the strict overlap assumption puts very strong 
restrictions on observational studies: for an observational study to satisfy overlap, most dimensions 
of the confounders X would need to closely mimic the balance we would expect in an RCT [D’A+21]. 


36.4.2 ATE Estimation with Observed Confounders 
We now return to estimating the ATE using observational (i.e., not experimental) data. We’ve shown that in the case where we observe all common causes of the treatment and outcome, the ATE is causally identified with a statistical estimand τ. We now consider several strategies for estimating this quantity using a finite data sample. Broadly, these techniques are known as backdoor adjustment.²

Recall that the defining characteristic of a confounding variable is that it affects both treatment 
and outcome. Thus, an adjustment strategy may aim to account for the influence of confounders on 
the observed outcome, the influence of confounders on treatment, or both. We discuss each of these 
strategies in turn. 




36.4.2.1 Outcome Model Adjustment


We begin with an approach to covariate adjustment that relies on modeling the conditional expectation of the outcome Y given treatment A and confounders X. This strategy is often referred to as g-computation³ or outcome adjustment. To begin, we define

Definition 2. The conditional expected outcome is the function Q given by

$$ Q(a, x) = \mathbb{E}[Y \mid A = a, X = x]. \qquad (36.25) $$


Substituting this definition into the definition of our estimand τ, Equation (36.19), we have $\tau = \mathbb{E}[Q(1, X) - Q(0, X)]$. This suggests a procedure for estimating τ: fit a model $\hat{Q}$ for Q and then report

$$ \hat{\tau}^{Q} = \frac{1}{n} \sum_i \big[\hat{Q}(1, x_i) - \hat{Q}(0, x_i)\big]. \qquad (36.26) $$

To fit $\hat{Q}$, recall that $\mathbb{E}[Y \mid a, x] = \operatorname{argmin}_Q \mathbb{E}[(Y - Q(A, X))^2]$. That is, the minimizer (among all functions) of the squared loss risk is the conditional expected outcome.⁴ So, to approximate Q, we simply use mean squared error to fit a predictor that predicts Y from A and X.

The estimation procedure takes several steps. We first fit a model $\hat{Q}$ to predict Y. Then, for each unit i, we predict that unit’s outcome had they received treatment, $\hat{Q}(1, x_i)$, and we predict their outcome had they not received treatment, $\hat{Q}(0, x_i)$.⁵ If the unit actually did receive treatment ($a_i = 1$), then $\hat{Q}(0, x_i)$ is our guess about what would have happened in the counterfactual case that they did not. The estimated expected gain from treatment for this individual is $\hat{Q}(1, x_i) - \hat{Q}(0, x_i)$: the difference in expected outcome between being treated and not treated. Finally, we estimate the outer expectation with respect to P(X), the true population distribution of the confounders, using the empirical distribution $\hat{P}(X) = \frac{1}{n} \sum_i \delta_{x_i}$. In effect, this means we replace the expectation (over an unknown distribution) with an average over the observed data.

2. As we discuss in Section 36.8, this backdoor adjustment references the estimand returned by the do-calculus to eliminate confounding from a backdoor path. This also generalizes the approaches discussed here to some cases where we do not observe all common causes.

3. The “g” stands for generalized, for now-inscrutable historical reasons [Rob86].

4. To be precise, this definition applies when X and Y are square-integrable, and the minimization is taken over measurable functions.

5. This interpretation is justified by the same conditions as Theorem 2.
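The g-computation steps above can be sketched as follows. Everything here is an illustrative assumption: the simulated data (binary X, propensity 0.2 or 0.8, true ATE of 1) and the choice of $\hat{Q}$. Because X is discrete, the cell means of Y in each (a, x) cell minimize squared error among all functions of (a, x), so they stand in for the "fit a predictor with mean squared error" step; with rich covariates one would swap in a flexible regression model.

```python
import random
import statistics
from collections import defaultdict

def simulate(n, seed=0):
    """Hypothetical confounded data: binary X raises both P(A = 1) and Y; true ATE = 1."""
    rng = random.Random(seed)
    rows = []
    for _ in range(n):
        x = int(rng.random() < 0.5)
        a = int(rng.random() < (0.8 if x else 0.2))
        y = 1.0 * a + 2.0 * x + rng.gauss(0.0, 1.0)
        rows.append((x, a, y))
    return rows

def fit_q(rows):
    """Q-hat for a discrete X: the sample mean of Y in each (a, x) cell, which is the
    squared-error minimizer among all functions of (a, x)."""
    cells = defaultdict(list)
    for x, a, y in rows:
        cells[(a, x)].append(y)
    means = {key: statistics.fmean(ys) for key, ys in cells.items()}
    return lambda a, x: means[(a, x)]

def g_computation(rows):
    """Outcome-adjustment estimator (36.26): average Q-hat(1, x_i) - Q-hat(0, x_i)."""
    q_hat = fit_q(rows)
    return statistics.fmean(q_hat(1, x) - q_hat(0, x) for x, _, _ in rows)

rows = simulate(20_000)
naive = (statistics.fmean(y for _, a, y in rows if a == 1)
         - statistics.fmean(y for _, a, y in rows if a == 0))
print(round(g_computation(rows), 2), round(naive, 2))
```

The adjusted estimate should land near 1, while the unadjusted difference in means lands near 2.2 for this simulated population.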


Linear regression It’s worth saying something more about the special case where Q is modeled as a linear function of both the treatment and all the covariates. That is, the case where we assume the identification conditions of Theorem 2 and we additionally assume that the true causal law (the SCM) governing Y yields $Q(A, X) = \mathbb{E}[Y \mid A, X] = \mathbb{E}[f_Y(A, X, \epsilon_Y) \mid A, X] = \beta_0 + \beta_A A + \beta_X X$. Plugging in, we see that $Q(1, X) - Q(0, X) = \beta_A$ (and so also $\tau = \beta_A$). Then, the estimator for the average treatment effect reduces to the estimator for the regression coefficient $\beta_A$. This “fit linear regression and report the regression coefficient” approach remains a common way of estimating the association between two variables in practice. The expected-outcome-adjustment procedure here may be viewed as a generalization of this procedure that removes the linear parametric assumption.
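A sketch of the linear special case, under an assumed linear SCM with E[Y | A, X] = 0.5 + A + 2X and a continuous confounder X. To stay within the standard library, the coefficient $\beta_A$ of the OLS fit of Y on (1, A, X) is computed with the Frisch-Waugh-Lovell identity (residualize A on (1, X), then regress Y on that residual), which gives the same number as the full multiple regression.

```python
import random
import statistics

def simulate(n, seed=0):
    """Hypothetical linear SCM: E[Y | A, X] = 0.5 + 1.0 * A + 2.0 * X, and X confounds A."""
    rng = random.Random(seed)
    rows = []
    for _ in range(n):
        x = rng.gauss(0.0, 1.0)
        a = int(rng.random() < (0.8 if x > 0 else 0.2))
        y = 0.5 + 1.0 * a + 2.0 * x + rng.gauss(0.0, 1.0)
        rows.append((x, a, y))
    return rows

def ols_coef_on_a(rows):
    """Coefficient on A from OLS of Y on (1, A, X), via Frisch-Waugh-Lovell:
    residualize A on (1, X), then regress Y on that residual."""
    xs, ts, ys = zip(*rows)
    x_bar, t_bar = statistics.fmean(xs), statistics.fmean(ts)
    slope = (sum((x - x_bar) * (t - t_bar) for x, t in zip(xs, ts))
             / sum((x - x_bar) ** 2 for x in xs))
    resid = [(t - t_bar) - slope * (x - x_bar) for x, t in zip(xs, ts)]
    return sum(r * y for r, y in zip(resid, ys)) / sum(r * r for r in resid)

rows = simulate(20_000)
beta_a = ols_coef_on_a(rows)
print(round(beta_a, 2))
```

If the true outcome model were nonlinear in X, this coefficient would generally no longer equal τ, which is the motivation for the more general outcome-adjustment procedure.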


36.4.2.2 Propensity Score Adjustment 


Outcome model adjustment relies on modeling the relationship between the confounders and the 
outcome. A popular alternative is to model the relationship between the confounders and the 
treatment. This strategy adjusts for confounding by directly addressing sampling bias in the treated 
and control groups. This bias arises from the relationship between the confounders and the treatment. 
Intuitively, the effect of confounding may be viewed as due to the difference between P(X|A = 1) 
and P(X|A = 0)—e.g., the population of people who rode horses as children is wealthier than the 
population of people who did not. When we observe all confounding variables X, this degree of over- 
or under-representation can be adjusted away by reweighting samples such that the confounders X 
have the same distribution in the treated and control groups. When the confounders are balanced 
between the two groups, then any differences between them must be attributable to the treatment. 

A key quantity for balancing treatment and control groups is the propensity score, which 
summarises the relationship between confounders and treatment. 


Definition 3. The propensity score is the function g given by g(x) = P(A = 1 | X = x).


To make use of the propensity score in adjustment, we first rewrite the estimand τ in a suggestive form, leveraging the fact that A ∈ {0, 1}:

$$ \tau = \mathbb{E}\left[\frac{Y A}{g(X)} - \frac{Y (1 - A)}{1 - g(X)}\right]. \qquad (36.27) $$

This identity can be verified by noting that $\mathbb{E}[Y A \mid X] = \mathbb{E}[Y \mid A = 1, X]\, P(A = 1 \mid X) + 0$, rearranging for $\mathbb{E}[Y \mid A = 1, X]$, doing the same for $\mathbb{E}[Y \mid A = 0, X]$, and substituting into Equation (36.19). Note that the identity is just a mathematical fact about the statistical estimand: it does not rely on any causal assumptions, and holds whether or not τ can be interpreted as a causal effect.

This expression suggests the inverse probability of treatment weighted estimator, or IPTW estimator:

$$ \hat{\tau}^{IPTW} = \frac{1}{n} \sum_i \left[\frac{Y_i A_i}{\hat{g}(X_i)} - \frac{Y_i (1 - A_i)}{1 - \hat{g}(X_i)}\right]. \qquad (36.28) $$
Here, ĝ is an estimate of the propensity score function. Recall from Section 14.2.1 that if a model is well-specified and the loss function is a proper scoring rule, then the risk minimizer $g^* = \operatorname{argmin}_g \mathbb{E}[L(A, g(X))]$ will be $g^*(X) = P(A = 1 \mid X)$. That is, we can estimate the propensity score by fitting a model that predicts A from X. Cross-entropy and squared loss are both proper scoring rules, so we may use standard supervised learning methods.
In summary, the procedure is to estimate the propensity score function (with machine learning), and then to plug the estimated propensity scores $\hat{g}(x_i)$ into Equation (36.28). The IPTW estimator computes a difference of weighted averages between the treated and untreated group. The effect is to upweight the outcomes of units who were unlikely to be treated but who nevertheless, by chance, actually received treatment (and similarly for untreated). Intuitively, such units are typical for the untreated population. So, their outcomes under treatment are informative about what would have happened had a typical untreated unit received treatment.
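A minimal sketch of the two-step IPTW procedure on simulated data. The data-generating process is our own assumption (binary X, propensity 0.2 or 0.8, true ATE of 1). Because X is discrete, the treated fraction within each stratum serves as $\hat{g}$; it is the cross-entropy minimizer among functions of x. For continuous or high-dimensional X one would substitute a probabilistic classifier, which is not shown here.

```python
import random
import statistics
from collections import defaultdict

def simulate(n, seed=0):
    """Hypothetical confounded data with binary X; true propensity 0.2 or 0.8, true ATE = 1."""
    rng = random.Random(seed)
    rows = []
    for _ in range(n):
        x = int(rng.random() < 0.5)
        a = int(rng.random() < (0.8 if x else 0.2))
        y = 1.0 * a + 2.0 * x + rng.gauss(0.0, 1.0)
        rows.append((x, a, y))
    return rows

def fit_propensity(rows):
    """g-hat for a discrete X: the fraction treated within each stratum of X."""
    treated, total = defaultdict(int), defaultdict(int)
    for x, a, _ in rows:
        treated[x] += a
        total[x] += 1
    probs = {x: treated[x] / total[x] for x in total}
    return lambda x: probs[x]

def iptw(rows, g_hat):
    """IPTW estimator (36.28): inverse-propensity weighted treated minus control outcomes."""
    return statistics.fmean(
        y * a / g_hat(x) - y * (1 - a) / (1 - g_hat(x)) for x, a, y in rows)

rows = simulate(20_000)
g_hat = fit_propensity(rows)
print(round(iptw(rows, g_hat), 2))
```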
A word of warning is in order. Although the IPTW is asymptotically valid and popular in practice, it can be very unstable in finite samples. If estimated propensity scores are extreme for some values of x (that is, very close to 0 or 1), then the corresponding IPTW weights can be very large, resulting in a high-variance estimator. In some cases, this instability can be mitigated by instead using the Hajek version of the estimator:


$$ \hat{\tau}^{h\text{-}IPTW} = \frac{\sum_i Y_i A_i / \hat{g}(X_i)}{\sum_i A_i / \hat{g}(X_i)} - \frac{\sum_i Y_i (1 - A_i) / (1 - \hat{g}(X_i))}{\sum_i (1 - A_i) / (1 - \hat{g}(X_i))}. \qquad (36.29) $$


However, extreme weights can persist even after self-normalization, either because there are truly strata of X where treatment assignment is highly imbalanced, or because the propensity score estimation method has overfit. In such cases, it is common to apply heuristics such as weight clipping. See Khan and Ugander [KU21] for a longer discussion of inverse-propensity type estimators, including some practical improvements.
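The self-normalized estimator (36.29) and propensity clipping can be sketched as below. Assumptions, all ours: a continuous confounder with a logistic propensity (slope 1.5) so some scores approach 0 or 1, a true ATE of 1, and, for brevity, the true propensity standing in for a fitted ĝ. The clipping threshold of 0.05 is an arbitrary illustration of the heuristic; it caps each weight at 1/0.05 = 20 at the price of some bias.

```python
import math
import random

def simulate(n, seed=0):
    """Hypothetical data: continuous confounder X; the true propensity is logistic in X,
    so it approaches 0 or 1 for large |X|. True ATE = 1."""
    rng = random.Random(seed)
    rows = []
    for _ in range(n):
        x = rng.gauss(0.0, 1.0)
        g = 1.0 / (1.0 + math.exp(-1.5 * x))
        a = int(rng.random() < g)
        y = 1.0 * a + 2.0 * x + rng.gauss(0.0, 1.0)
        rows.append((g, a, y))        # the true g stands in for a fitted g-hat
    return rows

def hajek_iptw(rows, clip=0.0):
    """Self-normalized IPTW (36.29). Propensities are first clipped to
    [clip, 1 - clip]; clip = 0 leaves them unchanged."""
    num1 = den1 = num0 = den0 = 0.0
    for g, a, y in rows:
        g = min(max(g, clip), 1.0 - clip)
        if a == 1:
            num1 += y / g
            den1 += 1.0 / g
        else:
            num0 += y / (1.0 - g)
            den0 += 1.0 / (1.0 - g)
    return num1 / den1 - num0 / den0

rows = simulate(50_000)
print(round(hajek_iptw(rows), 2), round(hajek_iptw(rows, clip=0.05), 2))
```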

33 

36.4.2.3 Double Machine Learning

We have seen how to estimate the average treatment effect using either the relationship between confounders and outcome, or the relationship between confounders and treatment. In each case, we follow a two-step estimation procedure. First, we fit models for the expected outcome or the propensity score. Second, we plug these fitted models into a downstream estimator of the effect.

Unsurprisingly, the quality of the estimate of τ depends on the quality of the estimates $\hat{Q}$ or $\hat{g}$. This is problematic because Q and g may be complex functions that require large numbers of samples to estimate. Even though we’re only interested in the 1-dimensional parameter τ, the naive estimators described thus far can have very slow rates of convergence. This leads to unreliable inference or very large confidence intervals.

Remarkably, there are strategies for combining $\hat{Q}$ and $\hat{g}$ in estimators that, in principle, do better than using either $\hat{Q}$ or $\hat{g}$ alone. The Augmented Inverse Probability of Treatment Weighted Estimator (AIPTW) is one such estimator. It is defined as


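$$ \hat{\tau}^{AIPTW} = \frac{1}{n} \sum_i \left[\hat{Q}(1, X_i) - \hat{Q}(0, X_i) + A_i \frac{Y_i - \hat{Q}(1, X_i)}{\hat{g}(X_i)} - (1 - A_i) \frac{Y_i - \hat{Q}(0, X_i)}{1 - \hat{g}(X_i)}\right]. \qquad (36.30) $$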


That is, $\hat{\tau}^{AIPTW}$ is the outcome adjustment estimator plus a stabilization term that depends on the propensity score. This estimator is a particular case of a broader class of estimators that are referred to as semi-parametrically efficient or double machine-learning estimators [Che+17e; Che+17d]. We’ll use the latter terminology here.
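A minimal sketch of the AIPTW estimator (36.30). The simulated data (binary X, propensity 0.2 or 0.8, true ATE of 1) and the nuisance models are illustrative assumptions: stratum means for $\hat{Q}$ and stratum treated fractions for $\hat{g}$, which are reasonable only because X is discrete. `fit_nuisances` and `aiptw` are hypothetical helper names, not a library API.

```python
import random
import statistics
from collections import defaultdict

def simulate(n, seed=0):
    """Hypothetical confounded data with binary X; true propensity 0.2 or 0.8, true ATE = 1."""
    rng = random.Random(seed)
    rows = []
    for _ in range(n):
        x = int(rng.random() < 0.5)
        a = int(rng.random() < (0.8 if x else 0.2))
        y = 1.0 * a + 2.0 * x + rng.gauss(0.0, 1.0)
        rows.append((x, a, y))
    return rows

def fit_nuisances(rows):
    """Stratum-mean Q-hat(a, x) and stratum-frequency g-hat(x) for a discrete X."""
    cells = defaultdict(list)
    treated, total = defaultdict(int), defaultdict(int)
    for x, a, y in rows:
        cells[(a, x)].append(y)
        treated[x] += a
        total[x] += 1
    q = {key: statistics.fmean(ys) for key, ys in cells.items()}
    g = {x: treated[x] / total[x] for x in total}
    return (lambda a, x: q[(a, x)]), (lambda x: g[x])

def aiptw(rows, q_hat, g_hat):
    """AIPTW estimator (36.30): outcome-model contrast plus inverse-weighted residuals."""
    return statistics.fmean(
        q_hat(1, x) - q_hat(0, x)
        + a * (y - q_hat(1, x)) / g_hat(x)
        - (1 - a) * (y - q_hat(0, x)) / (1 - g_hat(x))
        for x, a, y in rows)

rows = simulate(20_000)
q_hat, g_hat = fit_nuisances(rows)
tau_aiptw = aiptw(rows, q_hat, g_hat)
print(round(tau_aiptw, 2))
```

With these saturated stratum models the residual correction averages to zero within each cell, so the result coincides with g-computation; the correction term matters when $\hat{Q}$ is misspecified or smoothed.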

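We now turn to understanding the sense in which double machine learning estimators are robust to misestimation of the nuisance functions Q and g. To this end, we define the influence curve of τ̂ to be the function φ defined by⁶

$$ \phi(X_i, A_i, Y_i; Q, g, \tau) \triangleq Q(1, X_i) - Q(0, X_i) + A_i \frac{Y_i - Q(1, X_i)}{g(X_i)} - (1 - A_i) \frac{Y_i - Q(0, X_i)}{1 - g(X_i)} - \tau. \qquad (36.31) $$

Below we abbreviate the left-hand side as $\phi(X_i; Q, g, \tau)$.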


By design, $\hat{\tau}^{AIPTW} - \tau = \frac{1}{n} \sum_i \phi(X_i; \hat{Q}, \hat{g}, \tau)$. We begin by considering what would happen if we simply knew Q and g, and didn’t have to estimate them. In this case, the estimator $\hat{\tau}^{ideal}$ would satisfy $\hat{\tau}^{ideal} - \tau = \frac{1}{n} \sum_i \phi(X_i; Q, g, \tau)$ and, by the central limit theorem, we would have:


$$ \sqrt{n}\,(\hat{\tau}^{ideal} - \tau) \xrightarrow{d} \mathrm{Normal}\big(0, \mathbb{E}[\phi(X_i; Q, g, \tau)^2]\big). \qquad (36.32) $$


This result characterizes the estimation uncertainty in the best possible case. If we knew Q and g, we could rely on this result for, e.g., finding confidence intervals for our estimate.

The question is: what happens when Q and g need to be estimated? For general estimators and nuisance function models, we don’t expect the √n-rate of Equation (36.32) to hold. For instance, $\sqrt{n}(\hat{\tau}^{Q} - \tau)$ only converges if $\sqrt{n}\, \mathbb{E}[(\hat{Q} - Q)^2]^{1/2} \to 0$. That is, for the naive estimator we only get the √n rate for estimating τ if we can also estimate Q at the √n rate, which is a much harder task! This is the issue that the double machine learning estimator helps with.

To understand how, we decompose the error in estimating τ as follows:


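$$
\begin{aligned}
&\sqrt{n}\,(\hat{\tau}^{AIPTW} - \tau) && (36.33) \\
&= \frac{1}{\sqrt{n}} \sum_i \phi(X_i; Q, g, \tau) && (36.34) \\
&\quad + \frac{1}{\sqrt{n}} \sum_i \big[\phi(X_i; \hat{Q}, \hat{g}, \tau) - \phi(X_i; Q, g, \tau)\big] - \sqrt{n}\, \mathbb{E}\big[\phi(X; \hat{Q}, \hat{g}, \tau) - \phi(X; Q, g, \tau)\big] && (36.35) \\
&\quad + \sqrt{n}\, \mathbb{E}\big[\phi(X; \hat{Q}, \hat{g}, \tau) - \phi(X; Q, g, \tau)\big] && (36.36)
\end{aligned}
$$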

We recognize the first term, Equation (36.34), as $\sqrt{n}(\hat{\tau}^{ideal} - \tau)$, the estimation error in the optimal case where we know Q and g. Ideally, we’d like the error of $\hat{\tau}^{AIPTW}$ to be asymptotically equal to this ideal case, which will happen if the other two terms go to 0.


6. Influence curves are the foundation of what follows, and the key to generalizing the analysis beyond the ATE. 
Unfortunately, going into the general mathematics would require a major digression, so we omit it. However, see 
references at the end of the chapter for some pointers to the relevant literature. 
The second term, Equation (36.35), is a penalty we pay for using the same data to estimate Q, g and to compute τ̂. For many model classes, it can be shown that such “empirical process” terms go to 0. This can also be guaranteed in general by using different data for fitting the nuisance functions and for computing the estimator (see the next section).

The third term, Equation (36.36), captures the penalty we pay for misestimating the nuisance functions. This is where the particular form of the AIPTW is key. With a little algebra, we can show that
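$$
\begin{aligned}
\mathbb{E}\big[\phi(X; \hat{Q}, \hat{g}, \tau) - \phi(X; Q, g, \tau)\big] &= \mathbb{E}\Big[\frac{1}{\hat{g}(X)} \big(\hat{g}(X) - g(X)\big)\big(\hat{Q}(1, X) - Q(1, X)\big)\Big] && (36.37) \\
&\quad + \mathbb{E}\Big[\frac{1}{1 - \hat{g}(X)} \big(\hat{g}(X) - g(X)\big)\big(\hat{Q}(0, X) - Q(0, X)\big)\Big]. && (36.38)
\end{aligned}
$$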


The important point is that estimation errors of Q and g are multiplied together. Using the Cauchy-Schwarz inequality, we find that $\sqrt{n}\, \mathbb{E}[\phi(X; \hat{Q}, \hat{g}, \tau) - \phi(X; Q, g, \tau)] \to 0$ as long as $\sqrt{n}\, \max_a \mathbb{E}[(\hat{Q}(a, X) - Q(a, X))^2]^{1/2}\, \mathbb{E}[(\hat{g}(X) - g(X))^2]^{1/2} \to 0$. That is, the misestimation penalty will vanish so long as the product of the misestimation errors is $o(n^{-1/2})$. For example, this means that τ can be estimated at the (optimal) √n rate even when the estimation error of each of $\hat{Q}$ and $\hat{g}$ only decreases as $o(n^{-1/4})$.

The upshot here is that the double machine learning estimator has the special property that the weak condition $\sqrt{n}\, \mathbb{E}[(\hat{Q}(A, X) - Q(A, X))^2]^{1/2}\, \mathbb{E}[(\hat{g}(X) - g(X))^2]^{1/2} \to 0$ suffices to imply that

$$ \sqrt{n}\,(\hat{\tau}^{AIPTW} - \tau) \xrightarrow{d} \mathrm{Normal}\big(0, \mathbb{E}[\phi(X_i; Q, g, \tau)^2]\big) \qquad (36.39) $$

(though strictly speaking this requires some additional technical conditions we haven’t discussed).


This is not true for the earlier estimators we discussed, which require a much faster rate of convergence for the nuisance function estimation.

The AIPTW estimator has two further nice properties that are worth mentioning. First, it is non-parametrically efficient. This means that this estimator has the smallest possible variance of any estimator that does not make parametric assumptions; namely, $\mathbb{E}[\phi(X_i; Q, g, \tau)^2]$. This means, for example, that this estimator yields the smallest confidence intervals of any approach that does not rely on parametric assumptions. Second, it is doubly robust: the estimator is consistent (converges to the true τ as n → ∞) as long as at least one of either $\hat{Q}$ or $\hat{g}$ is consistent.
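Double robustness can be seen in a small simulation. The setup is an illustrative assumption: binary X, true propensity 0.2 or 0.8, true outcome model A + 2X, and two deliberately misspecified nuisances (a constant-zero $\hat{Q}$ and a constant-0.5 $\hat{g}$). The AIPTW estimate stays near the true effect of 1 whenever at least one nuisance is correct, and drifts when both are wrong (to about 2.2 in this particular setup).

```python
import random
import statistics

def simulate(n, seed=0):
    """Hypothetical confounded data with binary X; true propensity 0.2 or 0.8, true ATE = 1."""
    rng = random.Random(seed)
    rows = []
    for _ in range(n):
        x = int(rng.random() < 0.5)
        a = int(rng.random() < (0.8 if x else 0.2))
        y = 1.0 * a + 2.0 * x + rng.gauss(0.0, 1.0)
        rows.append((x, a, y))
    return rows

def aiptw(rows, q_hat, g_hat):
    """AIPTW estimator (36.30) for given nuisance functions."""
    return statistics.fmean(
        q_hat(1, x) - q_hat(0, x)
        + a * (y - q_hat(1, x)) / g_hat(x)
        - (1 - a) * (y - q_hat(0, x)) / (1 - g_hat(x))
        for x, a, y in rows)

def true_q(a, x):
    return 1.0 * a + 2.0 * x        # correct outcome model

def true_g(x):
    return 0.8 if x else 0.2        # correct propensity

def bad_q(a, x):
    return 0.0                      # misspecified: ignores A and X

def bad_g(x):
    return 0.5                      # misspecified: ignores X

rows = simulate(50_000)
for label, q_hat, g_hat in [("both correct", true_q, true_g),
                            ("Q wrong, g right", bad_q, true_g),
                            ("Q right, g wrong", true_q, bad_g),
                            ("both wrong", bad_q, bad_g)]:
    print(label, round(aiptw(rows, q_hat, g_hat), 2))
```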
33 


36.4.2.4 Cross Fitting


36 The term Equation (36.35) in the error decomposition above is the penalty we pay for reusing the 
37same data to both fit Q, g and to compute the estimator. For many choices of model for Q, g, this 
38term goes to 0 quickly as n gets large and we achieve the (best case) y/n error rate. However, this 
39 property doesn’t always hold. 

As an alternative, we can always randomly split the available data and use one part for model fitting and the other to compute the estimator. Effectively, this means the nuisance function estimation and estimator computation are done using independent samples. It can then be shown that the reuse penalty vanishes. However, this comes at the price of reducing the amount of data available for each of nuisance function estimation and estimator computation.

This strategy can be improved upon by a cross fitting approach. We divide the data into $K$ folds. For each fold $j$ we use the other $K - 1$ folds to fit the nuisance function models $\hat{Q}^{-j}, \hat{g}^{-j}$. Then, for each datapoint $i$ in fold $j$, we take $\hat{Q}(a, x_i) = \hat{Q}^{-j}(a, x_i)$ for $a \in \{0, 1\}$ and $\hat{g}(x_i) = \hat{g}^{-j}(x_i)$. That is, the estimated conditional outcomes and propensity score for each datapoint are predictions from a model that was not trained on that datapoint. Then, we estimate $\tau$ by plugging $\{\hat{Q}(a, x_i), \hat{g}(x_i)\}_i$ into Equation (36.30). It can be shown that this cross fitting procedure has the same asymptotic guarantee (the central limit theorem at the $\sqrt{n}$ rate) as described above.
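The cross fitting recipe can be sketched in NumPy as follows. This is an illustrative sketch under simplifying assumptions: the helper names (`fit_linear`, `fit_logistic`, `cross_fit_aiptw`) are hypothetical, linear outcome models fit separately per treatment arm and a Newton-fitted logistic propensity model stand in for arbitrary machine-learning models, folds come from a random permutation, and the combination step assumes the standard AIPTW form of Equation (36.30).

```python
import numpy as np

def fit_linear(features, target):
    """OLS with an intercept; returns a prediction function."""
    design = np.column_stack([np.ones(len(features)), features])
    coef, *_ = np.linalg.lstsq(design, target, rcond=None)
    return lambda f: np.column_stack([np.ones(len(f)), f]) @ coef

def fit_logistic(features, labels, n_iter=25):
    """Logistic regression fit by Newton-Raphson; returns a probability function."""
    design = np.column_stack([np.ones(len(features)), features])
    w = np.zeros(design.shape[1])
    for _ in range(n_iter):
        p = 1.0 / (1.0 + np.exp(-design @ w))
        grad = design.T @ (labels - p)
        hess = design.T @ (design * (p * (1.0 - p))[:, None])
        w = w + np.linalg.solve(hess, grad)
    return lambda f: 1.0 / (1.0 + np.exp(-np.column_stack([np.ones(len(f)), f]) @ w))

def cross_fit_aiptw(x, a, y, n_folds=5, seed=0):
    """Cross-fitted AIPTW: predictions for fold j come from models fit on the other folds."""
    n = len(y)
    folds = np.random.default_rng(seed).permutation(n) % n_folds
    q0, q1, g = np.empty(n), np.empty(n), np.empty(n)
    for j in range(n_folds):
        test, train = folds == j, folds != j
        q1_model = fit_linear(x[train & (a == 1)], y[train & (a == 1)])  # Q_hat^{-j}(1, .)
        q0_model = fit_linear(x[train & (a == 0)], y[train & (a == 0)])  # Q_hat^{-j}(0, .)
        g_model = fit_logistic(x[train], a[train])                        # g_hat^{-j}
        q1[test], q0[test], g[test] = q1_model(x[test]), q0_model(x[test]), g_model(x[test])
    scores = q1 - q0 + a * (y - q1) / g - (1 - a) * (y - q0) / (1 - g)
    return scores.mean(), scores

# Hypothetical data with a single confounder and true effect 1.
rng = np.random.default_rng(1)
n = 20_000
x = rng.normal(size=(n, 1))
a = rng.binomial(1, 1.0 / (1.0 + np.exp(-x[:, 0])))
y = 1.0 * a + 2.0 * x[:, 0] + rng.normal(size=n)
tau_hat, scores = cross_fit_aiptw(x, a, y)
```

Returning the per-unit scores as well as their mean makes it straightforward to reuse them for variance estimation.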


36.4.3 Uncertainty Quantification


In addition to the point estimate $\hat{\tau}$ of the average treatment effect, we'd also like to report a measure of the uncertainty in our estimate, for example in the form of a confidence interval. The asymptotic normality of $\sqrt{n}\,\hat{\tau}$ (Equation (36.39)) provides a means for this quantification. Namely, we could base confidence intervals and similar on the limiting variance $\mathbb{E}[\varphi(X_i; Q, g, \tau)^2]$. Of course, we don't actually know any of $Q$, $g$, or $\tau$. However, it turns out that it suffices to estimate the asymptotic variance with $\frac{1}{n}\sum_i \varphi(X_i; \hat{Q}, \hat{g}, \hat{\tau})^2$ [Che+17e]. That is, we can estimate the uncertainty by simply plugging our fitted nuisance models and our point estimate of $\tau$ into

$$\hat{\mathbb{V}}[\hat{\tau}] = \frac{1}{n}\sum_i \varphi(X_i; \hat{Q}, \hat{g}, \hat{\tau})^2. \quad (36.40)$$

This estimated variance can then be used to compute confidence intervals in the usual manner. For example, we'd report a 95% confidence interval for $\tau$ as $\hat{\tau} \pm 1.96\sqrt{\hat{\mathbb{V}}[\hat{\tau}]/n}$.
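A minimal sketch of Equation (36.40) and the resulting interval, assuming the AIPTW influence function takes the standard form $\varphi(X_i; Q, g, \tau) = Q(1,X_i) - Q(0,X_i) + \frac{A_i(Y_i - Q(1,X_i))}{g(X_i)} - \frac{(1-A_i)(Y_i - Q(0,X_i))}{1-g(X_i)} - \tau$. The function name `aiptw_with_ci` and the simulated data are hypothetical.

```python
import numpy as np

def aiptw_with_ci(y, a, q0, q1, g, z=1.96):
    """AIPTW estimate, plug-in variance (Eq. 36.40), and a normal-approximation interval."""
    n = len(y)
    scores = q1 - q0 + a * (y - q1) / g - (1 - a) * (y - q0) / (1 - g)
    tau_hat = scores.mean()
    phi = scores - tau_hat                 # assumed influence-function values
    var_hat = np.mean(phi ** 2)            # V_hat[tau_hat]
    half_width = z * np.sqrt(var_hat / n)
    return tau_hat, var_hat, (tau_hat - half_width, tau_hat + half_width)

rng = np.random.default_rng(2)
n = 50_000
x = rng.normal(size=n)
g = 1.0 / (1.0 + np.exp(-x))
a = rng.binomial(1, g)
y = 1.0 * a + 2.0 * x + rng.normal(size=n)
# For brevity, the true nuisance functions stand in for fitted (ideally cross-fitted) ones.
tau_hat, var_hat, (lo, hi) = aiptw_with_ci(y, a, 2.0 * x, 1.0 + 2.0 * x, g)
```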

Alternatively, we could quantify the uncertainty by bootstrapping. Note, however, that this would require refitting the nuisance functions on each bootstrap sample. Depending on the model and data, this can be prohibitively expensive computationally.
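For comparison, a nonparametric bootstrap that refits the nuisance model on every resample might look like the sketch below. To keep it cheap, it assumes a hypothetical outcome-regression estimator with a linear $\hat{Q}$ per treatment arm; with a flexible machine-learning model, each of the `n_boot` refits would be a full training run, which is exactly the cost the text warns about.

```python
import numpy as np

def regression_adjustment(x, a, y):
    """Outcome-regression ATE with a linear Q_hat fit separately in each arm."""
    design = np.column_stack([np.ones(len(x)), x])
    b1, *_ = np.linalg.lstsq(design[a == 1], y[a == 1], rcond=None)
    b0, *_ = np.linalg.lstsq(design[a == 0], y[a == 0], rcond=None)
    return np.mean(design @ b1 - design @ b0)

def bootstrap_se(estimator, x, a, y, n_boot=200, seed=0):
    """Standard deviation of the estimate over resamples; the model is refit every time."""
    rng = np.random.default_rng(seed)
    n = len(y)
    replicates = []
    for _ in range(n_boot):
        idx = rng.integers(0, n, size=n)              # resample units with replacement
        replicates.append(estimator(x[idx], a[idx], y[idx]))
    return np.std(replicates, ddof=1)

rng = np.random.default_rng(12)
n = 2_000
x = rng.normal(size=n)
a = rng.binomial(1, 1.0 / (1.0 + np.exp(-x)))
y = 1.0 * a + 2.0 * x + rng.normal(size=n)
tau_hat = regression_adjustment(x, a, y)
se = bootstrap_se(regression_adjustment, x, a, y)
```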


36.4.4 Matching 


One particularly popular approach to adjustment-based causal estimation is matching. Intuitively, the idea is to match each treated unit to an untreated unit that has the same (or at least similar) values of the confounding variables and then compare the observed outcomes of the treated unit and its matched control. If we match on the full set of common causes, then the difference in outcomes is, intuitively, a noisy estimate of the effect the treatment had on that treated unit. We'll now build this up a bit more carefully. In the process we'll see that matching can be understood as, essentially, a particular kind of outcome model adjustment.

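For simplicity, consider the case where $X$ is a discrete random variable. Define $\mathcal{A}_x$ to be the set of treated units with covariate value $x$, and $\mathcal{C}_x$ to be the set of untreated units with covariate value $x$. In this case, the matching estimator is:

$$\hat{\tau}^{\text{matching}} = \sum_x \hat{P}(x)\Big(\frac{1}{|\mathcal{A}_x|}\sum_{i \in \mathcal{A}_x} Y_i - \frac{1}{|\mathcal{C}_x|}\sum_{j \in \mathcal{C}_x} Y_j\Big), \quad (36.41)$$

where $\hat{P}(x)$ is an estimator of $P(X = x)$, e.g., the fraction of units with $X = x$. Now, we can rewrite $Y_i = Q(A_i, X_i) + \xi_i$, where $\xi_i$ is a unit-specific noise term defined by the equation. In particular, we have that $\mathbb{E}[\xi_i | A_i, X_i] = 0$. Substituting this in, we have:

$$\hat{\tau}^{\text{matching}} = \sum_x \hat{P}(x)\big(Q(1, x) - Q(0, x)\big) + \sum_x \hat{P}(x)\Big(\frac{1}{|\mathcal{A}_x|}\sum_{i \in \mathcal{A}_x} \xi_i - \frac{1}{|\mathcal{C}_x|}\sum_{j \in \mathcal{C}_x} \xi_j\Big). \quad (36.42)$$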
We can recognize the first term as an estimator of the usual target parameter $\tau$ (it equals $\tau$ if $\hat{P}(x) = P(x)$). The second term is a weighted sum of differences of averages of random variables with expectation 0, so it converges to 0 as long as $|\mathcal{A}_x|$ and $|\mathcal{C}_x|$ each go to infinity as we see more and more data. Thus, we see that the matching estimator is a particular way of estimating the parameter $\tau$. The procedure can be extended to continuous covariates by introducing some notion of values of $X$ being close, and then matching treated and control units with close covariate values.

There are two points we should emphasize here. First, notice that the argument here has nothing to do with causal identification. Matching is a particular technique for estimating the observational parameter $\tau$. Whether or not $\tau$ can be interpreted as an average treatment effect is determined by the conditions of Theorem 2; the particular estimation strategy doesn't say anything about this. Second, notice that in essence matching amounts to a particular choice of model for $Q$. Namely, $\hat{Q}(1, x) = \frac{1}{|\mathcal{A}_x|}\sum_{i \in \mathcal{A}_x} Y_i$, and similarly for $\hat{Q}(0, x)$. That is, we estimate the conditional expected outcome as a sample mean over units with the same covariate value. Whether this is a good idea depends on the quality of this model for $Q$. In situations where better models are possible (e.g., a machine-learning model fits the data well), we might expect to get a more accurate estimate by using that conditional expected outcome predictor directly.
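The sketch below implements Equation (36.41) for a discrete covariate and checks the second point numerically: with $\hat{P}(x)$ taken to be the empirical frequency, the matching estimate coincides with the plug-in estimate $\frac{1}{n}\sum_i(\hat{Q}(1,X_i) - \hat{Q}(0,X_i))$ that uses cell means as $\hat{Q}$. The data-generating process and function names are hypothetical, and the code assumes every covariate cell contains both treated and untreated units.

```python
import numpy as np

rng = np.random.default_rng(3)
n = 60_000
x = rng.integers(0, 3, size=n)                     # discrete confounder X in {0, 1, 2}
a = rng.binomial(1, np.array([0.2, 0.5, 0.8])[x])  # treatment probability depends on X
y = 1.0 * a + 1.0 * x + rng.normal(size=n)         # true effect tau = 1

def matching_estimate(x, a, y):
    """Equation (36.41) with P_hat(x) equal to the fraction of units with X = x."""
    total = 0.0
    for v in np.unique(x):
        in_cell = x == v
        treated_mean = y[in_cell & (a == 1)].mean()   # average over A_x
        control_mean = y[in_cell & (a == 0)].mean()   # average over C_x
        total += in_cell.mean() * (treated_mean - control_mean)
    return total

def cell_mean_plug_in(x, a, y):
    """Plug-in estimate using Q_hat(t, x) = mean of Y among units with A = t and X = x."""
    levels = np.unique(x)
    q = {t: np.array([y[(a == t) & (x == v)].mean() for v in levels]) for t in (0, 1)}
    cell = np.searchsorted(levels, x)               # index of each unit's covariate cell
    return np.mean(q[1][cell] - q[0][cell])

naive = y[a == 1].mean() - y[a == 0].mean()         # confounded comparison
tau_match = matching_estimate(x, a, y)
tau_plug_in = cell_mean_plug_in(x, a, y)
```

The two estimates agree up to floating-point rounding, which is the algebraic sense in which matching is outcome-model adjustment with a cell-mean model.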
There is another important case we mention in passing. In general, when using adjustment-based identification, it suffices to adjust for any function $\phi(X)$ of $X$ such that $A \perp X \mid \phi(X)$. To see that adjusting for only $\phi(X)$ suffices, first notice that $g(X) = P(A = 1|X) = P(A = 1|\phi(X))$ only depends on $\phi(X)$, and then recall that we can write the target parameter as $\tau = \mathbb{E}\big[\frac{YA}{g(X)} - \frac{Y(1 - A)}{1 - g(X)}\big]$, whence $\tau$ only depends on $X$ through $g(X)$. That is, replacing $X$ by a reduced version $\phi(X)$ such that $g(X) = P(A = 1|\phi(X))$ can't make any difference to $\tau$. Indeed, the most popular choice of $\phi(X)$ is the propensity score itself, $\phi(X) = g(X)$. This leads to propensity score matching, a two-step procedure where we first fit a model for the propensity score, and then run matching based on the estimated propensity score values for each unit. Again, this is just a particular estimation procedure for the observational parameter $\tau$, and says nothing about whether it's valid to interpret $\tau$ as a causal effect.
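A two-step propensity score matching sketch, under simplifying assumptions: the propensity scores are the true ones from a hypothetical simulation, standing in for a fitted model; each treated unit is matched, with replacement, to the single control unit with the nearest score; and because the simulated effect is constant, the average of matched differences (which targets the effect on the treated units) should be close to the true value of 1. The helper `nearest_control` is hypothetical.

```python
import numpy as np

def nearest_control(ps_treated, ps_control):
    """Index into ps_control of the closest control score for each treated unit."""
    order = np.argsort(ps_control)
    sorted_ps = ps_control[order]
    pos = np.searchsorted(sorted_ps, ps_treated)          # insertion points
    left = np.clip(pos - 1, 0, len(sorted_ps) - 1)
    right = np.clip(pos, 0, len(sorted_ps) - 1)
    pick_right = np.abs(sorted_ps[right] - ps_treated) < np.abs(sorted_ps[left] - ps_treated)
    return order[np.where(pick_right, right, left)]

rng = np.random.default_rng(13)
n = 20_000
x = rng.normal(size=n)
ps = 1.0 / (1.0 + np.exp(-x))              # stand-in for a fitted propensity model
a = rng.binomial(1, ps)
y = 1.0 * a + 2.0 * x + rng.normal(size=n)

treated = np.flatnonzero(a == 1)
controls = np.flatnonzero(a == 0)
match = controls[nearest_control(ps[treated], ps[controls])]
tau_psm = np.mean(y[treated] - y[match])   # average matched difference
```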


36.4.5 Practical Considerations and Procedures


When performing causal analysis, many issues can arise in practice, some of which we discuss below.


36.4.5.1 What to Adjust For


Choosing which variables to adjust for is a key detail in estimating causal effects using covariate adjustment. The criterion is clear when one has a full causal graph relating $A$, $Y$, and all covariates $X$ to each other. Namely, adjust for all variables that are actually causal parents of $A$ and $Y$. In fact, with access to the full graph, this criterion can be generalized somewhat; see Section 36.8.

In practice, we often don't actually know the full causal graph relating all of our variables. As a result, it is common to apply simple heuristics to determine which variables to adjust for. Unfortunately, these heuristics have serious limitations. However, exploring these is instructive.

A key condition in Theorem 2 is that the covariates $X$ that we adjust for must include all the common causes. In the absence of a full causal graph, it is tempting to condition on as many observed variables as possible to try to ensure this condition holds. However, this can be problematic. For instance, suppose that $M$ is a mediator of the effect of $A$ on $Y$, i.e., $M$ lies on one of the directed paths between $A$ and $Y$. Then, conditioning on $M$ will block this path, removing some of the causal effect. Note that this does not always result in an attenuated, or smaller-magnitude, effect estimate. The effect through a given mediator may run in the opposite direction of other causal pathways from the treatment; thus conditioning on a mediator can inflate or even flip the sign of a treatment effect. Alternatively, if $C$ is a collider between $A$ and $Y$ (a variable that is caused by both), then conditioning on $C$ will induce an extra statistical dependency between $A$ and $Y$.

Figure 36.4: The M-bias causal graph. Here, $A$ and $Y$ are not confounded. However, conditioning on the covariate $X$ opens a backdoor path, passing through $U_1$ and $U_2$ (because $X$ is a collider). Thus, adjusting for $X$ creates bias. This is true even when $X$ is a pre-treatment variable.

Both pitfalls of the “condition on everything” heuristic discussed above involve conditioning on variables that are downstream of the treatment $A$. A natural response to this is to limit conditioning to pre-treatment variables, or those that are causally upstream of the treatment. Importantly, if there is a valid adjustment set in the observed covariates $X$, then there will also be a valid adjustment set among the pre-treatment covariates. This is because any open backdoor path between $A$ and $Y$ must include a parent of $A$, and the set of pre-treatment covariates includes these parents. However, it is still possible that conditioning on the full set of pre-treatment variables can induce new backdoor paths between $A$ and $Y$ through colliders. In particular, if there is a covariate $D$ that is separately confounded with the treatment $A$ and the outcome $Y$, then $D$ is a collider, and conditioning on $D$ opens a new backdoor path. This phenomenon is known as m-bias because of the shape of the graph [Pea09c]; see Figure 36.4.
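M-bias is easy to reproduce in a small linear simulation. The sketch below assumes a hypothetical linear-Gaussian version of the M graph: independent unobserved $U_1$ and $U_2$, $A \leftarrow U_1$, $Y \leftarrow \tau A + U_2$, and a pre-treatment collider $D \leftarrow U_1 + U_2$, with all noise terms standard normal. Regressing $Y$ on $A$ alone recovers $\tau$, while adding $D$ opens the path $A \leftarrow U_1 \rightarrow D \leftarrow U_2 \rightarrow Y$; for these particular variances, a short calculation gives a coefficient on $A$ of $\tau - 0.2$.

```python
import numpy as np

def treatment_coef(regressors, y):
    """OLS of y on an intercept plus the given regressors; returns the first slope."""
    design = np.column_stack([np.ones(len(y))] + list(regressors))
    coef, *_ = np.linalg.lstsq(design, y, rcond=None)
    return coef[1]

rng = np.random.default_rng(4)
n = 200_000
u1 = rng.normal(size=n)                  # unobserved cause of A and D
u2 = rng.normal(size=n)                  # unobserved cause of D and Y
a = u1 + rng.normal(size=n)              # U1 -> A
d = u1 + u2 + rng.normal(size=n)         # U1 -> D <- U2: a pre-treatment collider
tau = 1.0
y = tau * a + u2 + rng.normal(size=n)    # U2 -> Y; A and Y share no common cause

unadjusted = treatment_coef([a], y)      # close to tau
adjusted = treatment_coef([a, d], y)     # close to tau - 0.2 in this model
```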

A practical refinement of the pre-treatment variable heuristic is given by VanderWeele [VT11]. This heuristic suggests conditioning on all pre-treatment variables that are causes of the treatment, outcome, or both. The essential qualifier in this heuristic is that the variable is causally upstream of treatment and/or outcome. This eliminates the possibility of conditioning on covariates that are only confounded with treatment and outcome, avoiding m-bias. Notably, this heuristic requires more causal knowledge than the above heuristics, but does not require detailed knowledge of how different covariates are causally related to each other.

The VanderWeele [VT11] criterion is a useful rule of thumb, but other practical considerations often arise. For example, if one has more knowledge about the causal structure among covariates, it is possible to optimize adjustment sets to minimize the variance of the resulting estimator [RS20]. One important example of reducing variance by pruning adjustment sets is the exclusion of variables that are known to only be a parent of the treatment, and not of the outcome (so-called instruments, as discussed in Section 36.5).
Finally, adjustment set selection criteria operate under the assumption that there actually exists a valid adjustment set among observed covariates. When there is no set of observed covariates in $X$ that blocks all backdoor paths, any adjusted estimate will be biased. Importantly, in this case, the bias does not necessarily decrease as one conditions on more variables. For example, conditioning on an instrumental variable often results in an estimate that has higher bias, in addition to the higher variance discussed above. This phenomenon is known as bias amplification or z-bias; see Section 36.7.2. A general rule of thumb is that variables that explain much more variation in the treatment than in the outcome can potentially amplify bias, and should be treated with caution.
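Bias amplification shows up in a similar linear sketch, assuming a hypothetical model with an unobserved confounder $U$ and an instrument $Z$: $A \leftarrow Z + U + \epsilon_A$ and $Y \leftarrow \tau A + U + \epsilon_Y$, all noise standard normal. No valid adjustment set is observed, so both regressions are biased, but adjusting for $Z$ strips out variation in $A$ that is unrelated to $U$; for these variances the bias grows from $1/3$ to $1/2$.

```python
import numpy as np

def treatment_coef(regressors, y):
    """OLS of y on an intercept plus the given regressors; returns the first slope."""
    design = np.column_stack([np.ones(len(y))] + list(regressors))
    coef, *_ = np.linalg.lstsq(design, y, rcond=None)
    return coef[1]

rng = np.random.default_rng(5)
n = 200_000
u = rng.normal(size=n)                   # unobserved confounder
z = rng.normal(size=n)                   # instrument: affects A only
a = z + u + rng.normal(size=n)
tau = 1.0
y = tau * a + u + rng.normal(size=n)

bias_without_z = treatment_coef([a], y) - tau      # about Cov(A, U) / Var(A) = 1/3
bias_with_z = treatment_coef([a, z], y) - tau      # about 1/2: the bias is amplified
```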
36.4.5.2 Overlap


Recall that in addition to no-unobserved-confounders, identification of the average treatment effect requires overlap: the condition that $0 < P(A = 1|x) < 1$ for the population distribution $P$. With infinite data, any amount of overlap will suffice for estimating the causal effect. In realistic settings, even near failures can be problematic. Equation (36.39) gives an expression for the (asymptotic) variance of our estimate: $\mathbb{E}[\varphi(X_i; Q, g, \tau)^2]/n$. Notice that $\varphi(X_i; Q, g, \tau)$ involves terms that are proportional to $1/g(X_i)$ and $1/(1 - g(X_i))$. Accordingly, the variance of our estimator will balloon if there are units where $g(x) \approx 0$ or $g(x) \approx 1$ (unless such units are rare enough that they don't contribute much to the expectation).

In practice, a simple way to deal with potential overlap violations is to fit a model $\hat{g}$ for the treatment assignment probability (which we need to do anyway) and check that the values $\hat{g}(x)$ are not too extreme. In the case that some values are too extreme, the simplest resolution is to cheat. We can simply exclude all the data with extreme values of $\hat{g}(x)$. This is equivalent to considering the average treatment effect over only the subpopulation where overlap is satisfied. This changes the interpretation of the estimand. The restricted subpopulation ATE may or may not provide a satisfactory answer to the real-world problem at hand, and this needs to be justified based on knowledge of the real-world problem.
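A minimal sketch of this check-and-trim strategy, assuming hypothetical propensity scores from a strongly confounded simulation and illustrative cutoffs of 0.05 and 0.95. In practice $\hat{g}$ would come from a fitted (ideally cross-fitted) model, and the cutoffs are a judgment call that changes the estimand.

```python
import numpy as np

def overlap_mask(g_hat, lower=0.05, upper=0.95):
    """True for units whose estimated propensity lies inside [lower, upper]."""
    return (g_hat >= lower) & (g_hat <= upper)

rng = np.random.default_rng(6)
n = 100_000
x = rng.normal(size=n)
g_hat = 1.0 / (1.0 + np.exp(-3.0 * x))   # stand-in for fitted propensity scores
a = rng.binomial(1, g_hat)

ipw = np.where(a == 1, 1.0 / g_hat, 1.0 / (1.0 - g_hat))   # inverse propensity weights
keep = overlap_mask(g_hat)
frac_dropped = 1.0 - keep.mean()
print(f"dropped {frac_dropped:.1%}; max weight {ipw.max():.0f} -> {ipw[keep].max():.1f}")
```

Trimming caps every remaining weight at $1/0.05 = 20$, at the cost of discarding roughly a third of the units in this simulation.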



36.4.5.3 Choice of Estimand and Average Treatment Effect on the Treated


Usually, our goal in estimating a causal effect is qualitative. We want to know what the sign of the effect is, and whether it's large or small. The utility of the ATE is that it provides a concrete query we can use to get a handle on the qualitative question. However, it is not sacrosanct; sometimes we're better off choosing an alternative causal estimand that still answers the qualitative question but which is easier to estimate statistically. The average treatment effect on the treated, or ATT,

$$\text{ATT} \triangleq \mathbb{E}_{X|A=1}\big[\mathbb{E}[Y|X, do(A = 1)] - \mathbb{E}[Y|X, do(A = 0)]\big], \quad (36.43)$$

is one such estimand that is frequently useful.

The ATT is useful when many members of the population are very unlikely to receive treatment, but the treated units had a reasonably high probability of receiving the control. This can happen if, e.g., we sample control units from the general population, but the treatment units all self-selected into treatment from a smaller subpopulation. In this case, it's not possible to (non-parametrically) determine the treatment effect for the control units where no similar unit took treatment. The ATT solves this obstacle by simply omitting such units from the average.

If we have the causal structure of Figure 36.3 and the overlap condition $P(A = 1|X = x) < 1$ holds for all $x$, then the ATT is causally identified as

$$\tau^{\text{ATT}} = \mathbb{E}_{X|A=1}\big[\mathbb{E}[Y|A = 1, X] - \mathbb{E}[Y|A = 0, X]\big]. \quad (36.44)$$


Note that the required overlap condition here is weaker than for identifying the ATE. (The proof is the same as for Theorem 2.)

The estimation strategies for the ATE translate readily to estimation strategies for the ATT. 
Namely, estimate the nuisance functions the same way and then simply replace averages over all 
data points by averages over the treated datapoints only. In principle, it’s possible to do a little 
better than this by making use of the untreated datapoints as well. A corresponding double machine 
learning estimator is 


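$$\hat{\tau}^{\text{ATT-AIPTW}} = \frac{1}{n}\sum_i \left[\frac{A_i\,(Y_i - \hat{Q}(0, X_i))}{\hat{P}(A = 1)} - \frac{(1 - A_i)\,\hat{g}(X_i)}{\hat{P}(A = 1)\,(1 - \hat{g}(X_i))}\,(Y_i - \hat{Q}(0, X_i))\right]. \quad (36.45)$$

The variance of this estimator can be estimated via

$$\varphi^{\text{ATT}}(X_i; Q, g, \tau) = \frac{A_i\,(Y_i - Q(0, X_i))}{P(A = 1)} - \frac{(1 - A_i)\,g(X_i)}{P(A = 1)\,(1 - g(X_i))}\,(Y_i - Q(0, X_i)) - \frac{A_i\,\tau}{P(A = 1)}, \quad (36.46)$$

$$\hat{\mathbb{V}}[\hat{\tau}^{\text{ATT-AIPTW}}] = \frac{1}{n}\sum_i \varphi^{\text{ATT}}(X_i; \hat{Q}, \hat{g}, \hat{\tau}^{\text{ATT-AIPTW}})^2. \quad (36.47)$$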


Notice that the estimator for the ATT doesn’t require estimating Q(1, X). This can be a considerable 
advantage when the treated units are rare. 
See Chernozhukov et al. [Che+17e] for details. 
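Equations (36.45) to (36.47) translate directly into code. The sketch is illustrative: `att_aiptw` is a hypothetical helper, the true $Q(0, \cdot)$ and $g$ of a hypothetical simulation stand in for fitted (ideally cross-fitted) models, and $\hat{P}(A = 1)$ is the fraction of treated units. The simulated effect $1 + X$ is heterogeneous, so the ATT differs from the ATE of 1. Note that only $\hat{Q}(0, \cdot)$ is used.

```python
import numpy as np

def att_aiptw(y, a, q0, g):
    """ATT-AIPTW estimate (Eq. 36.45) and plug-in variance (Eqs. 36.46-36.47)."""
    p_treated = a.mean()                                   # P_hat(A = 1)
    resid0 = y - q0                                        # Y_i - Q_hat(0, X_i)
    terms = a * resid0 / p_treated - (1 - a) * g * resid0 / (p_treated * (1 - g))
    att_hat = terms.mean()
    phi = terms - a * att_hat / p_treated                  # phi^ATT at the estimates
    return att_hat, np.mean(phi ** 2)

rng = np.random.default_rng(7)
n = 100_000
x = rng.normal(size=n)
g = 1.0 / (1.0 + np.exp(-x))
a = rng.binomial(1, g)
y = a * (1.0 + x) + 2.0 * x + rng.normal(size=n)   # unit-level effect is 1 + x
att_hat, var_hat = att_aiptw(y, a, 2.0 * x, g)
sample_att = np.mean(1.0 + x[a == 1])              # realized average effect among the treated
half_width = 1.96 * np.sqrt(var_hat / n)
ci = (att_hat - half_width, att_hat + half_width)
```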


36.4.6 Summary and Practical Advice 


We have seen a number of estimators that follow the general procedure: 


1. Fit statistical or machine-learning models $\hat{Q}(a, x)$ as a predictor for $Y$, and/or $\hat{g}(x)$ as a predictor for $A$,
2. Compute the predictions $\hat{Q}(0, x_i)$, $\hat{Q}(1, x_i)$, $\hat{g}(x_i)$ for each data point, and
3. Combine these predictions into an estimate of the average treatment effect.


Importantly, no single estimation approach is a silver bullet. For example, the double machine-learning estimator has appealing theoretical properties, such as asymptotic efficiency guarantees and a recipe for estimating uncertainty without needing to bootstrap the model fitting. However, in terms of the quality of point estimates, the double ML estimators can sometimes underperform their more naive counterparts [KS07]. In fact, there are cases where each of outcome regression, propensity weighting, or doubly robust methods will outperform the others.
One difficulty in choosing an estimator in practice is that there are fewer guardrails in causal inference than there are in standard predictive modeling. In predictive modeling, we construct a train-test split and validate our prediction models using the true labels or outcomes in the held-out dataset. However, for causal problems, the causal estimands are functionals of a different data-generating process from the one that we actually observed. As a result, it is impossible to empirically validate many aspects of causal estimation using standard techniques.

The effectiveness of a given approach is often determined by how much we trust the specification of our propensity score or outcome regression models $g(x)$ and $Q(a, x)$, and how well the treatment and control groups overlap in the dataset. Using flexible models for the nuisance functions $g$ and $Q$ can alleviate some of the concerns about model misspecification, but our freedom to use such models is often constrained by dataset size. When we have the luxury of large data, we can use flexible models; on the other hand, when the dataset is relatively small, we may need to use a smaller parametric family or stringent regularization to obtain stable estimates of $Q$ and $g$. Similarly, if overlap is poor in some regions of the covariate space, then flexible models for $Q$ may be highly variable, and inverse propensity score weights may be large. In these cases, IPTW or AIPTW estimates may fluctuate wildly as a function of large weights. Meanwhile, outcome regression estimates will be sensitive to the specification of the $Q$ model and its regularization, and can incur bias that is difficult to measure if the specification or regularization does not match the true outcome process.

There are a number of practical steps that we can take to sanity-check causal estimates. The simplest check is to compute many different ATE estimators (e.g., outcome regression, IPTW, doubly robust) using several comparably complex estimators of $Q$ and $g$. We can then check whether they agree, at least qualitatively. If they do agree then this can provide some peace of mind (although it is not a guarantee of accuracy). If they disagree, caution is warranted, particularly in choosing the specification of the $Q$ and $g$ models.
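The three-step recipe above, together with this agreement check, can be sketched as follows. Everything here is an illustrative assumption: hypothetical helper names, simulated data, linear per-arm outcome models and a Newton-fitted logistic propensity model as the nuisance models, the unnormalized form of IPTW, and no cross fitting (for brevity). In this correctly specified simulation the three estimates should roughly agree; a large spread on real data would be the warning sign described above.

```python
import numpy as np

def add_intercept(f):
    return np.column_stack([np.ones(len(f)), f])

def fit_linear(f, target):
    coef, *_ = np.linalg.lstsq(add_intercept(f), target, rcond=None)
    return lambda new: add_intercept(new) @ coef

def fit_logistic(f, labels, n_iter=25):
    d = add_intercept(f)
    w = np.zeros(d.shape[1])
    for _ in range(n_iter):                   # Newton-Raphson updates
        p = 1.0 / (1.0 + np.exp(-d @ w))
        w = w + np.linalg.solve(d.T @ (d * (p * (1 - p))[:, None]), d.T @ (labels - p))
    return lambda new: 1.0 / (1.0 + np.exp(-add_intercept(new) @ w))

rng = np.random.default_rng(8)
n = 50_000
x = rng.normal(size=(n, 1))
a = rng.binomial(1, 1.0 / (1.0 + np.exp(-x[:, 0])))
y = 1.0 * a + 2.0 * x[:, 0] + rng.normal(size=n)

# Steps 1 and 2: fit the nuisance models, then predict for every unit.
q1 = fit_linear(x[a == 1], y[a == 1])(x)
q0 = fit_linear(x[a == 0], y[a == 0])(x)
g = fit_logistic(x, a)(x)

# Step 3: combine the same predictions in three different ways.
estimates = {
    "outcome regression": np.mean(q1 - q0),
    "IPTW": np.mean(a * y / g - (1 - a) * y / (1 - g)),
    "AIPTW": np.mean(q1 - q0 + a * (y - q1) / g - (1 - a) * (y - q0) / (1 - g)),
}
spread = max(estimates.values()) - min(estimates.values())
```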

It is also important to check for failures of overlap. Often, issues such as disagreement between alternative estimators can be traced back to poor overlap. A common way to do this, particularly with high-dimensional data, is to examine the estimated (ideally cross-fitted) propensity scores $\hat{g}(x_i)$. This is a useful diagnostic, even if the intention is to use an outcome regression approach that only incorporates an estimated outcome regression function $\hat{Q}(a, x_i)$. If overlap issues are relevant, it may be better to instead estimate either the average treatment effect on the treated, or the “trimmed” estimand given by discarding units with extreme propensities.

Uncertainty quantification is also an essential part of most causal analyses. This frequently takes the form of an estimate of the estimator's variance, or a confidence interval. This may be important for downstream decision-making, and can also be a useful diagnostic. We can calculate variance either by bootstrapping the entire procedure (including refitting the models in each bootstrap replicate), or by computing analytical variance estimates from the AIPTW estimator. Generally, large variance estimates may indicate issues with the analysis. For example, poor overlap will often (although not always) manifest as extremely large variances under either of these methods. Small variance estimates should be treated with caution, unless other checks, such as overlap checks or stability across different $Q$ and $g$ models, also pass.

The previous advice only addresses the statistical problem of estimating $\tau$ from a data sample. It does not speak to whether or not $\tau$ can reasonably be interpreted as an average treatment effect. Considerable care should be devoted to whether or not the assumption that there are no unobserved confounders is reasonable. There are several methods for assessing the sensitivity of the ATE estimate to violations of this assumption; see Section 36.7. Bias due to unobserved confounding can be substantial in practice, often overwhelming bias due to estimation error, so it is wise to conduct such an analysis.

36.5. INSTRUMENTAL VARIABLE STRATEGIES

Figure 36.5: Causal graph illustrating the Instrumental Variable setup. The treatment $A$ and outcome $Y$ are both influenced by the unobserved confounder $U$. Nevertheless, identification is sometimes possible due to the presence of the instrument $Z$. We also allow for observed covariates $X$ that we may need to adjust for. The dashed arrow between $U$ and $X$ indicates a statistical dependency where we remain agnostic to the particular causal relationship.


36.5 Instrumental Variable Strategies 


Adjustment-based methods rely on observing all confounders affecting the treatment and outcome. In some situations, it is possible to identify interesting causal effects even when there are unobserved confounders. We now consider strategies based on instrumental variables. The instrumental variable graph is shown in Figure 36.5. The key ingredient is the instrumental variable $Z$, a variable that has a causal effect on $Y$ only through its causal effect on $A$. Informally, the identification strategy is to determine the causal effect of $Z$ on $Y$, the causal effect of $Z$ on $A$, and then combine these into an estimate of the causal effect of $A$ on $Y$.

For this identification strategy to work, the instrument must satisfy three conditions. There must be observed variables (confounders) $X$ such that:


1. Instrument Relevance: $Z \not\perp A \mid X$. The instrument must actually affect the treatment assignment.


2. Instrument Unconfoundedness: Any backdoor path between $Z$ and $Y$ is blocked by $X$, even conditional on $A$.


3. Exclusion Restriction: All directed paths from $Z$ to $Y$ pass through $A$. That is, the instrument affects the outcome only through its effect on $A$.


(It may help conceptually to first think through the case where X is the empty set—i.e., where 
the only confounder is the unobserved U). These assumptions are necessary for using instrumental 
variables for causal identification, but they are not quite sufficient. In practice, they must be 
supplemented by an additional assumption that depends more closely on the details of the problem 
at hand. Historically, this additional assumption was usually that both the instrument-treatment and 
treatment-outcome relationships are linear. We’ll examine some less restrictive alternatives below. 
Before moving on to how to use instrumental variables for identification, let's consider how we might encounter instruments in practice. The key is that it's often possible to find, and measure, variables that affect treatment and that are assigned (as if) at random. For example, suppose we are interested in measuring the effect of taking a drug $A$ on some health outcome $Y$. The challenge is that whether a study participant actually takes the drug can be confounded with $Y$; e.g., sicker people may be more likely to take their medication, but have worse outcomes. However, the assignment of treatments to patients can be randomized and this random assignment can be viewed as an instrument. This random assignment with non-compliance scenario is common in practice. The random assignment (the instrument) satisfies relevance (so long as assigning the drug affects the probability of the patient taking the drug). It also satisfies unconfoundedness (because the instrument is randomized). And, it plausibly satisfies the exclusion restriction: telling (or not telling) a patient to take a drug has no effect on their health outcome except through influencing whether or not they actually take the drug. As a second example, the judge fixed effects research design uses the identity of the judge assigned to each criminal case to infer the effect of incarceration on some life outcome of interest (e.g., total lifetime earnings). Relevance will be satisfied so long as different judges have different propensities to hand out severe sentences. The assignment of trial judges to cases is randomized, so unconfoundedness will also be satisfied. And, the exclusion restriction is also plausible: the particular identity of the judge assigned to your case has no bearing on your years-later life outcomes, except through the particular sentence that you're subjected to.

It's important to note that these assumptions require some care, particularly the exclusion restriction. Relevance can be checked directly from the data, by fitting a model to predict the treatment from the instrument (or vice versa). Unconfoundedness is often satisfied by design: the instrument is randomly assigned. Even when literal random assignment doesn't hold, we often restrict to instruments where unconfoundedness is “obviously” satisfied, e.g., using number of rainy days in a month as an instrument for sun exposure. The exclusion restriction is trickier. For example, it might fail in the drug assignment case if patients who are not told to take a drug respond by seeking out alternative treatment. Or, it might fail in the judge fixed effects case if judges hand out additional, unrecorded punishments in addition to incarceration. Assessing the plausibility of the exclusion restriction requires careful consideration based on domain expertise.
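Checking relevance amounts to testing whether the instrument predicts the treatment. A minimal sketch, assuming a hypothetical linear first stage with a single binary instrument and no covariates: it computes the $F$ statistic for the instrument's coefficient (equal to the squared $t$ statistic when there is one instrument). A commonly quoted rule of thumb flags $F < 10$ as a weak instrument; treat that threshold as a heuristic rather than a guarantee.

```python
import numpy as np

def first_stage_f(z, a):
    """F statistic for H0: the coefficient on Z is zero in an OLS fit of A on (1, Z)."""
    n = len(a)
    design = np.column_stack([np.ones(n), z])
    coef, *_ = np.linalg.lstsq(design, a, rcond=None)
    resid = a - design @ coef
    sigma2 = resid @ resid / (n - design.shape[1])             # residual variance
    var_coef = sigma2 * np.linalg.inv(design.T @ design)[1, 1]
    return coef[1] ** 2 / var_coef

rng = np.random.default_rng(9)
n = 5_000
z = rng.binomial(1, 0.5, size=n).astype(float)   # randomized assignment
u = rng.normal(size=n)                            # unobserved confounder
a_relevant = 0.5 * z + u + rng.normal(size=n)     # Z shifts the treatment
a_irrelevant = u + rng.normal(size=n)             # Z has no effect on the treatment
f_relevant = first_stage_f(z, a_relevant)
f_irrelevant = first_stage_f(z, a_irrelevant)
```

No analogous data-driven check exists for the exclusion restriction, which is why it has to rest on domain knowledge.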

We now return to the question of how to make use of an instrument once we have it in hand. As previously mentioned, getting causal identification using instrumental variables requires supplementing the IV assumptions with some additional assumption about the causal process.


36.5.1 Additive Unobserved Confounding


We first consider additive unobserved confounding. That is, we assume that the structural causal model for the outcome has the form:⁷


$$Y \leftarrow f(A, X) + f_U(U). \quad (36.48)$$

In words, we assume that there are no interaction effects between the treatment and the unobserved confounder: everyone responds to treatment in the same way. With this additional assumption, we see that $\mathbb{E}[Y|X, do(A = a)] - \mathbb{E}[Y|X, do(A = a')] = f(a, X) - f(a', X)$. In this setting, our goal is to learn this contrast.

⁷ We roll the unit-specific variables $\xi$ into $U$ to avoid notational overload.
47 
Theorem 3 (Additive Confounding Identification). If the instrumental variables assumptions hold and also additive unobserved confounding holds, then there is a function $\tilde{f}(a, x)$ where

$$\mathbb{E}[Y|x, do(A = a)] - \mathbb{E}[Y|x, do(A = a')] = \tilde{f}(a, x) - \tilde{f}(a', x), \quad (36.49)$$

for all $x, a, a'$ and such that $\tilde{f}$ satisfies

$$\mathbb{E}[Y|z, x] = \int \tilde{f}(a, x)\,p(a|z, x)\,da. \quad (36.50)$$

Here, $p(a|z, x)$ is the conditional probability density of treatment.
In particular, if there is a unique function $g$ that satisfies

$$\mathbb{E}[Y|z, x] = \int g(a, x)\,p(a|z, x)\,da, \quad (36.51)$$

then $g = \tilde{f}$ and this relation identifies the target causal effect.


Before giving the proof, let's understand the point of this identification result. The key insight is that both the left hand side of Equation (36.51) and $p(a|z, x)$ (appearing in the integrand) are identified by the data, since they involve only observational relationships between observed variables. So, $\tilde{f}$ is identified implicitly as one of the functions that makes Equation (36.51) true. If there is a unique such function, then this fully identifies the causal effect.


Proof. With the additive unobserved confounding assumption, instrument unconfoundedness implies that $U \perp Z \mid X$. Then, we have that:

$$\mathbb{E}[Y|Z, X] = \mathbb{E}[f(A, X)|Z, X] + \mathbb{E}[f_U(U)|Z, X] \quad (36.52)$$
$$= \mathbb{E}[f(A, X)|Z, X] + \mathbb{E}[f_U(U)|X] \quad (36.53)$$
$$= \mathbb{E}[\tilde{f}(A, X)|Z, X], \quad (36.54)$$

where $\tilde{f}(A, X) = f(A, X) + \mathbb{E}[f_U(U)|X]$. Now, identifying just $\tilde{f}$ would suffice for us, because we could then identify contrasts between treatments: $\tilde{f}(a, x) - \tilde{f}(a', x) = f(a, x) - f(a', x)$. (The term $\mathbb{E}[f_U(U)|x]$ cancels out.) Accordingly, we rewrite Equation (36.54) as:

$$\mathbb{E}[Y|z, x] = \int \tilde{f}(a, x)\,p(a|z, x)\,da. \quad (36.55)$$

It's worth dwelling briefly on how the IV assumptions come into play here. The exclusion restriction is implied by the additive unobserved confounding assumption, which we use explicitly. We also use the unconfoundedness assumption to conclude $U \perp Z \mid X$. However, we do not use relevance. The role of relevance here is in ensuring that few functions solve the relation in Equation (36.51). Informally, the solution $g$ is constrained by the requirement that it hold for all values of $Z$. However, different values of $Z$ only add non-trivial constraints if $p(a|z, x)$ differs depending on the value of $z$; this is exactly the relevance condition.
Estimation. The basic estimation strategy is to fit models for $\mathbb{E}[Y|z, x]$ and $p(a|z, x)$ from the data, and then solve the implicit equation, Equation (36.51), to find $g$ consistent with the fitted models. The procedures for doing this can vary considerably depending on the particulars of the data (e.g., if $Z$ is discrete or continuous) and the choice of modeling strategy. We omit a detailed discussion, but [see e.g. NP03; Dar+11; Har+17; SSG19; BKS19; Mua+20; Dik+20] for various concrete approaches.

It's also worth mentioning an additional nuance to the general procedure. Even if relevance holds, there will often be more than one function that satisfies Equation (36.51). So, we have only identified $\tilde{f}$ as a member of this set of functions. In practice, this ambiguity is resolved by making some additional structural assumption about $\tilde{f}$. For example, we might model $\tilde{f}$ with a neural network, and then choose the network satisfying Equation (36.51) that has minimum $\ell_2$-norm on the parameters (i.e., we pick the $\ell_2$-regularized solution).
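In the classical special case mentioned earlier, the extra structural assumption is linearity. If we assume (hypothetically, for this sketch) that $\tilde{f}(a) = \beta a + c$ with no covariates, Equation (36.51) becomes $\mathbb{E}[Y|z] = \beta\,\mathbb{E}[A|z] + c$, so regressing $Y$ on a fitted estimate of $\mathbb{E}[A|z]$ recovers $\beta$. This is two-stage least squares. The simulation below is hypothetical, with an additive unobserved confounder that biases ordinary least squares.

```python
import numpy as np

def two_stage_least_squares(z, a, y):
    """Linear IV: regress A on Z, then regress Y on the fitted values of A."""
    ones = np.ones(len(y))
    stage1 = np.column_stack([ones, z])
    c1, *_ = np.linalg.lstsq(stage1, a, rcond=None)
    a_fitted = stage1 @ c1                       # estimate of E[A | z]
    stage2 = np.column_stack([ones, a_fitted])
    c2, *_ = np.linalg.lstsq(stage2, y, rcond=None)
    return c2[1]                                 # estimate of beta

rng = np.random.default_rng(10)
n = 100_000
u = rng.normal(size=n)                           # unobserved confounder
z = rng.normal(size=n)                           # instrument
a = z + u + rng.normal(size=n)
beta = 2.0
y = beta * a + 3.0 * u + rng.normal(size=n)      # additive confounding: f_U(U) = 3U

ols = np.polyfit(a, y, 1)[0]                     # biased: about beta + 1 in this model
iv = two_stage_least_squares(z, a, y)            # about beta
```

The point estimate from the second stage is fine, but standard errors reported by a naive second-stage regression are not valid for the IV estimate.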


36.5.2 Instrument Monotonicity and Local Average Treatment Effect


We now consider an alternative assumption to additive unobserved confounding that is applicable when both the instrument and treatment are binary. It will be convenient to conceptualize the instrument as assignment-to-treatment. Then, the population divides into four subpopulations:

1. Compliers, who take the treatment if assigned to it, and who don't take the treatment otherwise.

2. Always takers, who take the treatment no matter their assignment.

3. Never takers, who refuse the treatment no matter their assignment.

4. Defiers, who refuse the treatment if assigned to it, and who take the treatment if not assigned.


Our goal in this setting will be to identify the average treatment effect among the compliers. The local average treatment effect (or complier average treatment effect) is defined to be⁸

$$\text{LATE} = \mathbb{E}[Y|do(A = 1), \text{complier}] - \mathbb{E}[Y|do(A = 0), \text{complier}]. \quad (36.56)$$

The LATE requires an additional assumption for identification. Namely, instrument monotonicity: being assigned (not assigned) the treatment only increases (decreases) the probability that each unit will take the treatment. Equivalently, $P(\text{defier}) = 0$.

We can then write down the identification result.


Theorem 4. Given the instrumental variable assumptions and instrument monotonicity, the local average treatment effect is identified as a parameter $\tau^{\text{LATE}}$ of the observational distribution; that is, $\text{LATE} = \tau^{\text{LATE}}$. Namely,

$$
\tau^{\text{LATE}} = \frac{\mathbb{E}\big[\mathbb{E}[Y \mid X, Z = 1] - \mathbb{E}[Y \mid X, Z = 0]\big]}{\mathbb{E}\big[P(A = 1 \mid X, Z = 1) - P(A = 1 \mid X, Z = 0)\big]}. \tag{36.57}
$$


Proof. We now show that, given the IV assumptions and monotonicity, $\text{LATE} = \tau^{\text{LATE}}$. First, notice that

$$
\tau^{\text{LATE}} = \frac{\mathbb{E}[Y \mid do(Z = 1)] - \mathbb{E}[Y \mid do(Z = 0)]}{P(A = 1 \mid do(Z = 1)) - P(A = 1 \mid do(Z = 0))}. \tag{36.58}
$$


⁸We follow the econometrics literature in using “LATE” because “CATE” is already commonly used for conditional average treatment effect.
36.5. INSTRUMENTAL VARIABLE STRATEGIES


This follows from backdoor adjustment, Theorem 2, applied to the numerator and denominator 
separately. Our strategy will be to decompose E[Y|do(Z = z)] into the contributions from the 
compliers, the units that ignore the instrument (the always/never takers), and the defiers. To that 
end, note that P(complier|do(Z = z)) = P(complier) and similarly for always/never takers and 
defiers—interventions on the instrument don’t change the composition of the population. Then, 


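\begin{align}
&\mathbb{E}[Y \mid do(Z = 1)] - \mathbb{E}[Y \mid do(Z = 0)] \tag{36.59}\\
&= \big(\mathbb{E}[Y \mid \text{complier}, do(Z = 1)] - \mathbb{E}[Y \mid \text{complier}, do(Z = 0)]\big)\, P(\text{complier}) \tag{36.60}\\
&\quad + \big(\mathbb{E}[Y \mid \text{always/never}, do(Z = 1)] - \mathbb{E}[Y \mid \text{always/never}, do(Z = 0)]\big)\, P(\text{always/never}) \tag{36.61}\\
&\quad + \big(\mathbb{E}[Y \mid \text{defier}, do(Z = 1)] - \mathbb{E}[Y \mid \text{defier}, do(Z = 0)]\big)\, P(\text{defier}). \tag{36.62}
\end{align}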


The key is the effect on the complier subpopulation, Equation (36.60). First, by definition of the complier population, we have that:

$$
\mathbb{E}[Y \mid \text{complier}, do(Z = z)] = \mathbb{E}[Y \mid \text{complier}, do(A = z)]. \tag{36.63}
$$


That is, the causal effect of the treatment is the same as the causal effect of the instrument in this 
subpopulation—this is the core reason why access to an instrument allows identification of the local 
average treatment effect. This means that 


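$$
\text{LATE} = \mathbb{E}[Y \mid \text{complier}, do(Z = 1)] - \mathbb{E}[Y \mid \text{complier}, do(Z = 0)]. \tag{36.64}
$$

Further, we have that $P(\text{complier}) = P(A = 1 \mid do(Z = 1)) - P(A = 1 \mid do(Z = 0))$. The reason is simply that, by definition of the subpopulations and because there are no defiers,

\begin{align}
P(A = 1 \mid do(Z = 1)) &= P(\text{complier}) + P(\text{always taker}) \tag{36.65}\\
P(A = 1 \mid do(Z = 0)) &= P(\text{always taker}). \tag{36.66}
\end{align}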


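Now, plugging the expression for $P(\text{complier})$ and Equation (36.64) into Equation (36.60), we have that:

\begin{align}
&\big(\mathbb{E}[Y \mid \text{complier}, do(Z = 1)] - \mathbb{E}[Y \mid \text{complier}, do(Z = 0)]\big)\, P(\text{complier}) \tag{36.67}\\
&= \text{LATE} \times \big(P(A = 1 \mid do(Z = 1)) - P(A = 1 \mid do(Z = 0))\big). \tag{36.68}
\end{align}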


This gives us an expression for the local average treatment effect in terms of the effect of the instrument on the compliers and the probability that a unit takes the treatment when assigned/not assigned.

The next step is to show that the remaining instrument effect decomposition terms, Equations (36.61) and (36.62), are both 0. Equation (36.61) is the causal effect of the instrument on the always/never takers. It’s equal to 0 because, by definition of this subpopulation, the instrument has no causal effect in the subpopulation: they ignore the instrument! Mathematically, this is just $\mathbb{E}[Y \mid \text{always/never}, do(Z = 1)] = \mathbb{E}[Y \mid \text{always/never}, do(Z = 0)]$. Finally, Equation (36.62) is 0 by the instrument monotonicity assumption: we assumed that $P(\text{defier}) = 0$.

Altogether, the decomposition in Equations (36.59) to (36.62) reduces to:

\begin{align}
&\mathbb{E}[Y \mid do(Z = 1)] - \mathbb{E}[Y \mid do(Z = 0)] \tag{36.69}\\
&= \text{LATE} \times \big(P(A = 1 \mid do(Z = 1)) - P(A = 1 \mid do(Z = 0))\big) + 0 + 0. \tag{36.70}
\end{align}

Rearranging for LATE and plugging into Equation (36.58) gives the claimed identification result.
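As a sanity check on Theorem 4, the following minimal sketch simulates a hypothetical population of compliers, always takers, and never takers (no defiers, so monotonicity holds) with a randomized binary instrument. All proportions and effect sizes are made up for illustration. Because Z is randomized, the do-expressions in Equation (36.58) reduce to ordinary conditional means, and this sample ratio (often called the Wald estimator) should recover the complier effect, while a naive treated-versus-untreated comparison should not.

```python
import numpy as np

def simulate_iv_population(n, seed=0):
    """Hypothetical population with a randomized binary instrument Z."""
    rng = np.random.default_rng(seed)
    # Compliance type: 0 = complier, 1 = always taker, 2 = never taker.
    # There are no defiers, so instrument monotonicity holds.
    kind = rng.choice(3, size=n, p=[0.5, 0.2, 0.3])
    z = rng.integers(0, 2, size=n)
    a = np.where(kind == 0, z, np.where(kind == 1, 1, 0))
    # Type-specific baselines and treatment effects (made-up numbers).
    # The complier effect, i.e. the LATE, is 2.0; never takers are never
    # treated, so their effect value is irrelevant.
    baseline = np.array([0.0, 3.0, 1.0])[kind]
    effect = np.array([2.0, 5.0, -1.0])[kind]
    y = baseline + effect * a + rng.normal(size=n)
    return z, a, y

def wald_estimate(z, a, y):
    """Sample analogue of Equation (36.58) when Z is randomized."""
    effect_on_y = y[z == 1].mean() - y[z == 0].mean()
    effect_on_a = a[z == 1].mean() - a[z == 0].mean()
    return effect_on_y / effect_on_a

z, a, y = simulate_iv_population(200_000)
late_hat = wald_estimate(z, a, y)
naive = y[a == 1].mean() - y[a == 0].mean()
print(f"Wald estimate of LATE: {late_hat:.2f} (true complier effect: 2.0)")
print(f"naive treated minus untreated: {naive:.2f}")
```

The naive comparison is pulled upward because, in this made-up population, always takers have both a higher baseline and a larger effect than compliers.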


Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license 



36.5.2.1 Estimation

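For estimating the local average treatment effect under the monotone instrument assumption, there is a double machine learning approach that works with generic supervised learning methods. Here, we want an estimator $\hat{\tau}^{\text{LATE}}$ for the parameter

$$
\tau^{\text{LATE}} = \frac{\mathbb{E}\big[\mathbb{E}[Y \mid X, Z = 1] - \mathbb{E}[Y \mid X, Z = 0]\big]}{\mathbb{E}\big[P(A = 1 \mid X, Z = 1) - P(A = 1 \mid X, Z = 0)\big]}. \tag{36.71}
$$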

To define the estimator, it’s convenient to introduce some additional notation. First, we define the nuisance functions:

\begin{align}
\mu(z, x) &= \mathbb{E}[Y \mid z, x] \tag{36.72}\\
m(z, x) &= P(A = 1 \mid z, x) \tag{36.73}\\
p(x) &= P(Z = 1 \mid x). \tag{36.74}
\end{align}
16 
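We also define the score $\phi$ by:

\begin{align}
\phi_{Z \to Y}(X; \mu, p) &= \mu(1, X) - \mu(0, X) + \frac{Z\,(Y - \mu(1, X))}{p(X)} - \frac{(1 - Z)(Y - \mu(0, X))}{1 - p(X)} \tag{36.75}\\
\phi_{Z \to A}(X; m, p) &= m(1, X) - m(0, X) + \frac{Z\,(A - m(1, X))}{p(X)} - \frac{(1 - Z)(A - m(0, X))}{1 - p(X)} \tag{36.76}\\
\phi(X; \mu, m, p, \tau) &= \phi_{Z \to Y}(X; \mu, p) - \phi_{Z \to A}(X; m, p) \times \tau. \tag{36.77}
\end{align}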
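Then, the estimator is defined by a two stage procedure:

1. Fit models $\hat{\mu}, \hat{m}, \hat{p}$ for each of $\mu, m, p$ (using supervised machine learning).
2. Define $\hat{\tau}^{\text{LATE}}$ as the solution to $\frac{1}{n}\sum_i \phi(X_i; \hat{\mu}, \hat{m}, \hat{p}, \hat{\tau}^{\text{LATE}}) = 0$. That is,

$$
\hat{\tau}^{\text{LATE}} = \frac{\frac{1}{n}\sum_i \phi_{Z \to Y}(X_i; \hat{\mu}, \hat{p})}{\frac{1}{n}\sum_i \phi_{Z \to A}(X_i; \hat{m}, \hat{p})}. \tag{36.78}
$$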


It may help intuition to notice that the double machine learning estimator of the LATE is effectively the double machine learning estimator of the average treatment effect of Z on Y divided by the double machine learning estimator of the average treatment effect of Z on A.
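The ratio in Equation (36.78) is simple to compute once nuisance predictions are available. The sketch below is a minimal illustration, assuming a single binary covariate so that cell means can stand in for the supervised learners; the data-generating process and the helper names (late_dml, cell_means) are hypothetical. For clarity it fits the nuisances on the full sample rather than using cross-fitting.

```python
import numpy as np

def late_dml(y, a, z, mu1, mu0, m1, m0, p):
    """Ratio estimator of Equation (36.78) from nuisance predictions.

    mu1, mu0: predictions of E[Y | Z=1, X_i] and E[Y | Z=0, X_i]
    m1, m0:   predictions of P(A=1 | Z=1, X_i) and P(A=1 | Z=0, X_i)
    p:        predictions of P(Z=1 | X_i)
    """
    # Score for the effect of Z on Y, Equation (36.75).
    phi_zy = mu1 - mu0 + z * (y - mu1) / p - (1 - z) * (y - mu0) / (1 - p)
    # Score for the effect of Z on A, Equation (36.76).
    phi_za = m1 - m0 + z * (a - m1) / p - (1 - z) * (a - m0) / (1 - p)
    # Solving mean(phi_zy - tau * phi_za) = 0 for tau gives a ratio of means.
    return phi_zy.mean() / phi_za.mean()

def cell_means(target, z, x):
    """table[zz, xx] = mean of target among units with Z = zz and X = xx."""
    return np.array([[target[(z == zz) & (x == xx)].mean() for xx in (0, 1)]
                     for zz in (0, 1)])

# Hypothetical data: X shifts the instrument, compliance types and outcome.
rng = np.random.default_rng(1)
n = 100_000
x = rng.integers(0, 2, size=n)
z = rng.binomial(1, 0.3 + 0.4 * x)
u = rng.random(n)
complier = u < 0.6                                 # 60% compliers for every x
always = (u >= 0.6) & (u < 0.7 + 0.2 * x)          # the rest: never takers
a = np.where(complier, z, always.astype(int))
y = x + 2.0 * always + np.where(complier, 1.5, 3.0) * a + rng.normal(size=n)

mu, m = cell_means(y, z, x), cell_means(a, z, x)
p = np.array([z[x == xx].mean() for xx in (0, 1)])
tau_hat = late_dml(y, a, z, mu[1, x], mu[0, x], m[1, x], m[0, x], p[x])
print(f"estimated LATE: {tau_hat:.2f} (complier effect is 1.5)")
```

With saturated cell-mean models the correction terms average to zero, so this reduces to a plug-in ratio; with flexible learners such as neural networks or gradient boosting, the correction terms matter.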


Similarly to Section 36.4, the nuisance functions can be estimated by:

1. fitting a model $\hat{\mu}$ that predicts Y from Z, X by minimizing mean squared error;
2. fitting a model $\hat{m}$ that predicts A from Z, X by minimizing mean cross-entropy;
3. fitting a model $\hat{p}$ that predicts Z from X by minimizing mean cross-entropy.
42 


As in Section 36.4, reusing the same data for model fitting and computing the estimator can potentially cause problems. This can be avoided with a cross-fitting procedure as described in Section 36.4.2.4. In this case, we split the data into K folds and, for each fold k, use all but the kth fold to compute estimates $\hat{\mu}_{-k}, \hat{m}_{-k}, \hat{p}_{-k}$ of the nuisance parameters. Then we compute the nuisance estimates for each datapoint i in fold k by predicting the required quantity using the nuisance model fit on the other folds. That is, if unit i is in fold k, we compute $\hat{\mu}(z_i, x_i) = \hat{\mu}_{-k}(z_i, x_i)$ and so forth.
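The fold bookkeeping in cross-fitting is easy to get wrong, so here is a minimal sketch of it. It assumes a generic fit/predict pair (the cell-mean learner below is a hypothetical stand-in for any supervised model) and shows how to obtain out-of-fold predictions at counterfactual instrument values, such as $\hat{\mu}(1, x_i)$ and $\hat{\mu}(0, x_i)$, which the scores (36.75) and (36.76) require.

```python
import numpy as np

def fit_per_fold(fit, features, target, folds):
    """models[k] is fit on every unit outside fold k."""
    n_folds = folds.max() + 1
    return [fit(features[folds != k], target[folds != k]) for k in range(n_folds)]

def out_of_fold(predict, models, folds, query):
    """Predict query[i] with the model that never saw unit i's fold."""
    preds = np.empty(len(folds))
    for k, model in enumerate(models):
        in_fold = folds == k
        preds[in_fold] = predict(model, query[in_fold])
    return preds

# Hypothetical learner: the mean of the target within each discrete cell.
def fit_cell_means(features, target):
    return {c: target[features == c].mean() for c in np.unique(features)}

def predict_cell_means(model, features):
    return np.array([model[c] for c in features])

rng = np.random.default_rng(1)
n, n_folds = 600, 3
x = rng.integers(0, 3, size=n)                 # discrete covariate
z = rng.integers(0, 2, size=n)                 # binary instrument
y = x + 1.5 * z + rng.normal(size=n)
folds = rng.permutation(n) % n_folds           # balanced random fold labels

cell = 2 * x + z                               # encode (z, x) as one integer
models = fit_per_fold(fit_cell_means, cell, y, folds)
mu1_hat = out_of_fold(predict_cell_means, models, folds, 2 * x + 1)  # mu(1, x_i)
mu0_hat = out_of_fold(predict_cell_means, models, folds, 2 * x)      # mu(0, x_i)
```

The same two helpers would be reused for $\hat{m}$ (with A as the target) and for $\hat{p}$ (with X alone as the features); the resulting out-of-fold arrays then plug directly into the scores.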

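The key result is that if we use the cross-fit version of the estimator and the estimators for the nuisance functions converge to their true values in the sense that

1. $\mathbb{E}[(\hat{\mu}(Z, X) - \mu(Z, X))^2] \to 0$, $\mathbb{E}[(\hat{m}(Z, X) - m(Z, X))^2] \to 0$, and $\mathbb{E}[(\hat{p}(X) - p(X))^2] \to 0$,
2. $\sqrt{\mathbb{E}[(\hat{p}(X) - p(X))^2]} \times \Big(\sqrt{\mathbb{E}[(\hat{\mu}(Z, X) - \mu(Z, X))^2]} + \sqrt{\mathbb{E}[(\hat{m}(Z, X) - m(Z, X))^2]}\Big) = o(1/\sqrt{n})$,

then (with some omitted technical conditions) we have asymptotic normality at the $\sqrt{n}$-rate:

$$
\sqrt{n}\,\big(\hat{\tau}^{\text{LATE-cf}} - \tau^{\text{LATE}}\big) \xrightarrow{d} \text{Normal}\left(0, \frac{\mathbb{E}[\phi(X; \mu, m, p, \tau^{\text{LATE}})^2]}{\mathbb{E}[m(1, X) - m(0, X)]^2}\right). \tag{36.79}
$$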


As with double machine learning for the confounder adjustment strategy, the key point here is that 
we can achieve the (optimal) $\sqrt{n}$ rate for estimating the LATE under a relatively weak condition on
how well we estimate the nuisance functions—what matters is the product of the error in p and the 
errors in u, m. So, for example, a very good model for how the instrument is assigned (p) can make 
up for errors in the estimation of the treatment-assignment (m) and outcome (u) models. 


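The double machine learning estimator also gives a recipe for quantifying uncertainty. To that end, define

\begin{align}
\hat{J}_{Z \to A} &= \frac{1}{n} \sum_i \phi_{Z \to A}(X_i; \hat{m}, \hat{p}) \tag{36.80}\\
\hat{\mathbb{V}}[\hat{\tau}^{\text{LATE}}] &= \frac{1}{\hat{J}_{Z \to A}^2} \, \frac{1}{n} \sum_i \phi(X_i; \hat{\mu}, \hat{m}, \hat{p}, \hat{\tau}^{\text{LATE}})^2. \tag{36.81}
\end{align}

Then, subject to suitable technical conditions, $\hat{\mathbb{V}}[\hat{\tau}^{\text{LATE}}]$ can be used as an estimate of the (asymptotic) variance of the estimator. More precisely,

$$
\sqrt{n}\,\big(\hat{\tau}^{\text{LATE}} - \tau^{\text{LATE}}\big) \xrightarrow{d} \text{Normal}\big(0, \hat{\mathbb{V}}[\hat{\tau}^{\text{LATE}}]\big). \tag{36.82}
$$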


Then, confidence intervals or p-values can be computed using this variance in the usual way. The main 
extra condition required for the variance estimator to be valid is that the nuisance parameters must 
all converge at rate $O(n^{-1/4})$ (so an excellent estimator for one can’t fully compensate for terrible
estimators of the others). In fact, even this condition is unnecessary in certain special cases—e.g., 
when p is known exactly, which occurs when the instrument is randomly assigned. See Chernozhukov 
et al. [Che+17e] for technical details. 
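Given per-unit scores evaluated at the (cross-fit) nuisance estimates, Equations (36.77), (36.78), (36.80), and (36.81) give a point estimate and a Wald-type confidence interval in a few lines. The minimal sketch below assumes the two score arrays have already been computed; the synthetic scores in the usage example are made up purely to exercise the function, and 1.96 is the usual normal quantile for a 95% interval.

```python
import numpy as np

def late_with_ci(phi_zy, phi_za, z_crit=1.96):
    """Point estimate, standard error and interval from per-unit scores.

    phi_zy[i], phi_za[i]: the scores (36.75) and (36.76) for unit i.
    """
    n = len(phi_zy)
    j_hat = phi_za.mean()                        # Equation (36.80)
    tau_hat = phi_zy.mean() / j_hat              # Equation (36.78)
    phi = phi_zy - tau_hat * phi_za              # Equation (36.77)
    v_hat = np.mean(phi ** 2) / j_hat ** 2       # Equation (36.81)
    se = np.sqrt(v_hat / n)                      # sqrt(n) scaling from (36.82)
    return tau_hat, se, (tau_hat - z_crit * se, tau_hat + z_crit * se)

# Synthetic scores, only to exercise the function.
rng = np.random.default_rng(0)
phi_za = rng.normal(0.6, 1.0, size=5_000)
phi_zy = 1.5 * phi_za + rng.normal(0.0, 1.0, size=5_000)
tau_hat, se, (lo, hi) = late_with_ci(phi_zy, phi_za)
print(f"tau_hat = {tau_hat:.2f}, 95% CI = ({lo:.2f}, {hi:.2f})")
```

Note the division of $\hat{\mathbb{V}}$ by n: Equation (36.82) describes $\sqrt{n}$ times the estimation error, so the standard error of $\hat{\tau}^{\text{LATE}}$ itself shrinks like $1/\sqrt{n}$.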


36.5.3 Two Stage Least Squares 


Commonly, the IV assumptions are supplemented with the following linear model assumptions:

\begin{align}
A_i &\leftarrow \alpha_0 + \alpha Z_i + \delta_A X_i + \gamma_A U_i + \varepsilon^A_i \tag{36.83}\\
Y_i &\leftarrow \beta_0 + \beta A_i + \delta_Y X_i + \gamma_Y U_i + \varepsilon^Y_i \tag{36.84}
\end{align}

That is, we assume that the real-world processes for treatment assignment and the outcome are both linear. In this case, plugging Equation (36.83) into Equation (36.84) yields

$$
Y_i \leftarrow \tilde{\beta}_0 + \beta\alpha Z_i + \tilde{\delta} X_i + \tilde{\gamma} U_i + \tilde{\varepsilon}_i. \tag{36.85}
$$


The point is that β, the average treatment effect of A on Y, is equal to the coefficient βα of the instrument in the outcome-instrument model divided by the coefficient α of the instrument in the treatment-instrument model. So, to estimate the treatment effect, we simply fit both linear models and divide the estimated coefficients. This procedure is called two stage least squares.

The simplicity of this procedure is seductive. However, the required linearity assumptions are hard to satisfy in practice and frequently lead to severe issues. A particularly pernicious version of this is that linear-model misspecification together with weak relevance can yield standard errors for the estimate that are far too small. In practice, this can lead us to find large, significant estimates from two stage least squares when the truth is actually a weak or null effect. See [Rei16; You19; ASS19; Lal+21] for critical evaluations of two stage least squares in practice.
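A minimal sketch of the procedure, on simulated data that satisfies the linear model of Equations (36.83) and (36.84) exactly (so none of the pathologies above arise). All coefficients are made up. It computes the ratio of instrument coefficients and, for comparison, the explicit two-stage version (regress A on Z and X, then regress Y on the fitted values and X); with a single instrument the two coincide up to floating-point error. An ordinary regression of Y on A and X is included to show the bias from the unobserved U.

```python
import numpy as np

def ols(design, target):
    """Least-squares coefficients for target ~ design."""
    coef, *_ = np.linalg.lstsq(design, target, rcond=None)
    return coef

rng = np.random.default_rng(0)
n = 50_000
x = rng.normal(size=n)                       # observed covariate
u = rng.normal(size=n)                       # unobserved confounder
z = rng.normal(size=n)                       # instrument, independent of u
# Made-up coefficients; the true effect of A on Y is beta = 2.
a = 0.5 + 0.8 * z + 0.3 * x + 1.0 * u + rng.normal(size=n)   # (36.83)
y = 1.0 + 2.0 * a - 0.5 * x + 1.5 * u + rng.normal(size=n)   # (36.84)

ones = np.ones(n)
first_stage = np.column_stack([ones, z, x])
alpha_hat = ols(first_stage, a)[1]           # coefficient of Z in (36.83)
beta_alpha_hat = ols(first_stage, y)[1]      # coefficient of Z in (36.85)
ratio_est = beta_alpha_hat / alpha_hat

# Explicit two stages: regress Y on the first-stage fitted values of A.
a_fitted = first_stage @ ols(first_stage, a)
tsls_est = ols(np.column_stack([ones, a_fitted, x]), y)[1]

naive_est = ols(np.column_stack([ones, a, x]), y)[1]   # ignores u: biased
print(f"ratio {ratio_est:.3f}, 2SLS {tsls_est:.3f}, naive OLS {naive_est:.3f}")
```

The standard errors that a generic OLS routine would report for the second stage are not valid for the two-stage estimate, because its residuals use the fitted rather than the observed treatment; dedicated IV routines correct for this.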
36.6 Difference in Differences


Unsurprisingly, time plays an important role in causality. Causes precede effects, and we should be able to incorporate this knowledge into causal identification. We now turn to a particular strategy for causal identification that relies on observing each unit at multiple time points. Data of this kind is sometimes called panel data. We’ll consider the simplest case. There are two time periods. In the first period, none of the units are treated, and we observe an outcome Y₀ᵢ for each unit. Then, a subset of the units are treated, denoted by Aᵢ = 1. In the second time period, we again observe the outcomes Y₁ᵢ for each unit, where now the outcomes of the treated units are affected by the treatment. Our goal is to determine the average effect that receiving the treatment had on the treated units. That is, we want to know the average difference between the outcomes we actually observed for the treated units, and the outcomes we would have observed on those same units if they had not been treated. The general strategy we look at is called difference in differences.⁹

As a concrete motivating example, consider trying to determine the effect of raising the minimum wage on employment. The concern here is that, in an efficient labor market, increasing the price of workers will reduce the demand for them, thereby driving down employment. As such, it seems increasing the minimum wage may hurt the people the policy is nominally intended to help. The question is: how strong is this effect in practice? Card and Krueger [CK94a] studied this effect using difference in differences. The Philadelphia metropolitan area includes regions in both Pennsylvania and New Jersey (different US states). On April 1st 1992, New Jersey raised its minimum wage from $4.25 to $5.05. In Pennsylvania, the wage remained constant at $4.25. The strategy is to collect employment data from fast food restaurants (which pay many employees minimum wage) in each state before and after the change in minimum wage. In this case, for restaurant i, we have Y₀ᵢ, the number of full time employees in February 1992, and Y₁ᵢ, the number of full time employees in November 1992. The treatment is simply Aᵢ = 1 if the restaurant was located in New Jersey, and Aᵢ = 0 if located in Pennsylvania. Our goal is to estimate the average effect of the minimum wage hike on employment at the restaurants affected by it (i.e., the ones in New Jersey).


⁹See github.com/vveitch/causality-tutorials/blob/main/difference_in_differences.ipynb.
36.6. DIFFERENCE IN DIFFERENCES


The assumption in classical difference-in-differences is the following structural equation:

$$
Y_{ti} \leftarrow W_i + S_t + \tau A_i \mathbb{1}[t = 1] + \varepsilon_{ti}, \tag{36.86}
$$

with $\mathbb{E}[\varepsilon_{ti} \mid W_i, S_t, A_i] = 0$. Here, $W_i$ is a unit-specific effect that is constant across time (e.g., the location of the restaurant or competence of the management) and $S_t$ is a time-specific effect that applies to all units (e.g., the state of the US economy at each time). Both of these quantities are treated as unobserved, and not explicitly accounted for. The parameter τ captures the target causal effect. The (strong) assumption here is that unit, time, and treatment effects are all additive. This assumption is called parallel trends, because it is equivalent to assuming that, in the absence of treatment, the trend over time would be the same in both groups. It’s easy to see that under this assumption, we have:

$$
\tau = \mathbb{E}[Y_{1i} - Y_{0i} \mid A_i = 1] - \mathbb{E}[Y_{1i} - Y_{0i} \mid A_i = 0]. \tag{36.87}
$$


That is, the estimand first computes the difference across time for both the treated and untreated 
group, and then computes the difference between these differences across the groups. The obvious 
estimator is then 


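$$
\hat{\tau} = \frac{1}{n_A} \sum_{i: A_i = 1} (Y_{1i} - Y_{0i}) - \frac{1}{n - n_A} \sum_{i: A_i = 0} (Y_{1i} - Y_{0i}), \tag{36.88}
$$

where $n_A$ is the number of treated units.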
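The estimator (36.88) is essentially a one-liner. The simulation below is hypothetical (it is not the Card and Krueger data): it follows the additive unit, time, and treatment model of classical difference in differences, with unit effects $W_i$ that differ systematically between the treated and untreated groups. Comparing post-period outcomes directly is then confounded by $W_i$, while the difference in differences is not.

```python
import numpy as np

def did_estimate(y0, y1, a):
    """Equation (36.88): mean change among treated minus mean change among untreated."""
    change = y1 - y0
    return change[a == 1].mean() - change[a == 0].mean()

rng = np.random.default_rng(0)
n = 2_000
a = rng.integers(0, 2, size=n)                     # 1 = treated group
w = 20.0 + 3.0 * a + rng.normal(0.0, 4.0, size=n)  # unit effects differ by group
s0, s1 = 0.0, -1.0                                 # common time effects
tau = 0.5                                          # made-up treatment effect
y0 = w + s0 + rng.normal(size=n)
y1 = w + s1 + tau * a + rng.normal(size=n)

tau_hat = did_estimate(y0, y1, a)
post_only = y1[a == 1].mean() - y1[a == 0].mean()  # confounded by w
print(f"DiD: {tau_hat:.2f}, post-period comparison: {post_only:.2f}")
```

Differencing within each unit removes $W_i$ exactly, and differencing across groups removes the common time effect, which is why the estimate is centered on the made-up value of τ.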

The root identification problem addressed by difference-in-differences is that $\mathbb{E}[W_i \mid A_i = 1] \neq \mathbb{E}[W_i \mid A_i = 0]$. That is, restaurants in New Jersey may be systematically different from restaurants in Pennsylvania in unobserved ways that affect employment.¹⁰ This is why we can’t simply compare average outcomes for the treated and untreated. The identification assumption is that this unit-specific effect is the only source of statistical association with treatment; in particular we assume the time-specific effect has no such issue: $\mathbb{E}[S_{1i} - S_{0i} \mid A_i = 1] = \mathbb{E}[S_{1i} - S_{0i} \mid A_i = 0]$. Unfortunately, this assumption can be too strong. For instance, administrative data shows employment in Pennsylvania falling relative to employment in New Jersey between 1993 and 1996 [AP08, §5.2]. Although this doesn’t directly contradict the parallel trends assumption used for identification, which needs to hold only in 1992, it does make it seem less credible.

To weaken the assumption, we’ll look at a version that requires parallel trends to hold only after 
adjusting for covariates. To motivate this, we note that there were several different types of fast 
food restaurant included in the employment data. These vary, e.g., in the type of food they serve, 
and in cost per meal. Now, it seems reasonable that the trend in employment may depend on the type of restaurant. For example, more expensive chains (such as Kentucky Fried Chicken) might be more affected by recessions than cheaper chains (such as McDonald’s). If expensive chains are more common in New Jersey than in Pennsylvania, this effect can create a violation of parallel trends: if there’s a recession affecting both states, we’d expect employment to go down more in New Jersey than
in Pennsylvania. However, we may find it credible that McDonald’s restaurants in New Jersey have 
the same trend as McDonald’s in Pennsylvania, and similarly for Kentucky Fried Chicken. 


¹⁰This is similar to the issue that arises from unobserved confounding, except $W_i$ need not be a cause of the treatment assignment.
Figure 36.6: Causal graph assumed for the difference-in-differences setting. Here, the outcome of interest is the difference between the pre- and post-treatment period, $Y_1 - Y_0$. This difference is influenced by the treatment, unobserved factors U, and observed covariates X. The dashed arrow between U and A indicates a statistical dependency between the variables, but where we remain agnostic to the precise causal mechanism. For example, in the minimum wage example, U might be the average income in the restaurant’s neighbourhood, which is dependent on the state, and hence also the treatment.

The next step is to give a definition of the target causal effect that doesn’t depend on a parametric model, and a non-parametric statement of the identification assumption to go with it. In words, the causal estimand will be the average treatment effect on the units that received the treatment. To make sense of this mathematically, we’ll introduce a new piece of notation:

\begin{align}
P^{A=1}(Y \mid do(A = a)) &\triangleq \int P(Y \mid A = a, \text{parents of } Y)\, dP(\text{parents of } Y \mid A = 1) \tag{36.89}\\
\mathbb{E}^{A=1}[Y \mid do(A = a)] &\triangleq \mathbb{E}_{P^{A=1}(Y \mid do(A = a))}[Y]. \tag{36.90}
\end{align}
30 


In words: recall that the ordinary do operator works by replacing P(parents | A = a) by the marginal distribution P(parents), thereby breaking the backdoor associations. Now, we’re replacing the distribution P(parents | A = a) by P(parents | A = 1), irrespective of the actual treatment value. This still breaks all backdoor associations, but is a better match for our target of estimating the treatment effect only among the treated units.

To formalize a causal estimand using the do calculus, we need to assume some partial causal structure. We’ll use the graph in Figure 36.6. With this in hand, our causal estimand is the average treatment effect on the units that received the treatment, namely:

$$
\text{ATT}^{\text{DiD}} = \mathbb{E}^{A=1}[Y_1 - Y_0 \mid do(A = 1)] - \mathbb{E}^{A=1}[Y_1 - Y_0 \mid do(A = 0)]. \tag{36.91}
$$


In the minimum wage example, this is the average effect of the minimum wage hike on employment in the restaurants affected by it (i.e., the ones in New Jersey).

Finally, we formalize the identification assumption that, conditional on X, the trends in the treated and untreated groups are the same. The conditional parallel trends assumption is:

$$
\mathbb{E}^{A=1}[Y_1 - Y_0 \mid X, do(A = 0)] = \mathbb{E}[Y_1 - Y_0 \mid X, A = 0]. \tag{36.92}
$$
47 


In words, this says that for treated units with covariates X, the trend we would have seen had we not 
assigned treatment is the same as the trend we actually saw for the untreated units with covariates 
X. That is, if New Jersey had not raised its minimum wage, then McDonald’s in New Jersey would 
have the same expected change in employment as McDonald’s in Pennsylvania. 

With this in hand, we can give the main identification result: 


Theorem 5 (Difference in Differences Identification). We observe $A, Y_0, Y_1, X \sim P$. Suppose that

1. (Causal Structure) The data follows the causal graph in Figure 36.6.
2. (Conditional Parallel Trends) $\mathbb{E}^{A=1}[Y_1 - Y_0 \mid X, do(A = 0)] = \mathbb{E}[Y_1 - Y_0 \mid X, A = 0]$.
3. (Overlap) $P(A = 1) > 0$ and $P(A = 1 \mid X = x) < 1$ for all values of x in the sample space. That is, there are no covariate values that only exist in the treated group.

Then, the average treatment effect on the treated is identified as $\text{ATT}^{\text{DiD}} = \tau^{\text{DiD}}$, where

$$
\tau^{\text{DiD}} = \mathbb{E}\big[\mathbb{E}[Y_1 - Y_0 \mid A = 1, X] - \mathbb{E}[Y_1 - Y_0 \mid A = 0, X] \,\big|\, A = 1\big]. \tag{36.93}
$$


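Proof. First, by unrolling definitions, we have that

$$
\mathbb{E}^{A=1}[Y_1 - Y_0 \mid do(A = 1), X] = \mathbb{E}[Y_1 - Y_0 \mid A = 1, X]. \tag{36.94}
$$

The interpretation is the near-tautology that the average effect among the treated under treatment is equal to the actually observed average effect among the treated. Next,

$$
\mathbb{E}^{A=1}[Y_1 - Y_0 \mid do(A = 0), X] = \mathbb{E}[Y_1 - Y_0 \mid A = 0, X] \tag{36.95}
$$

is just the conditional parallel trends assumption. The result follows by subtracting Equation (36.95) from Equation (36.94) and averaging over X given A = 1, which gives Equation (36.91) on the left-hand side and Equation (36.93) on the right-hand side. (The overlap assumption is required to make sure all the conditional expectations are well defined.)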


36.6.1 Estimation 


With the identification result in hand, the next task is to estimate the observational estimand 
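Equation (36.93). To that end, we define $\tilde{Y} \triangleq Y_1 - Y_0$. Then, we’ve assumed that $\tilde{Y}, X, A \overset{\text{iid}}{\sim} P$ for
some unknown distribution P, and our target estimand is $\mathbb{E}\big[\mathbb{E}[\tilde{Y} \mid A = 1, X] - \mathbb{E}[\tilde{Y} \mid A = 0, X] \,\big|\, A = 1\big]$.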
We can immediately recognize this as the observational estimand that occurs in estimating the 
average treatment effect through adjustment, described in Section 36.4.5.3. That is, even though 
the causal situation and the identification argument are different between the adjustment setting 
and the difference in differences setting, the statistical estimation task we end up with is the same. 
Accordingly, we can use all of the estimation tools we developed for adjustment. That is, all of the 
techniques there—expected outcome modeling, propensity score methods, double machine learning, 
and so forth—were purely about the statistical task, which is the same between the two scenarios. 
So, we’re left with the same general recipe for estimation we saw in Section 36.4.6. Namely, 


1. Fit statistical or machine-learning models $\hat{Q}(a, x)$ as a predictor for $\tilde{Y} = Y_1 - Y_0$, and/or $\hat{g}(x)$ as a predictor for A,
2. Compute the predictions $\hat{Q}(0, x_i), \hat{Q}(1, x_i), \hat{g}(x_i)$ for each data point, and
3. Combine these predictions into an estimate of the average treatment effect on the treated.


The estimator in the third step can be the expected outcome model estimator, the propensity weighted 
estimator, the double machine learning estimator, or any other strategy that’s valid in the adjustment 
setting. 
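As an illustration of the recipe, here is a minimal sketch of the expected outcome model estimator of the average treatment effect on the treated, assuming a single discrete covariate so that $\hat{Q}(0, x)$ can be a cell mean of $\tilde{Y}$ among untreated units. The estimate is the average over treated units of $\tilde{Y}_i - \hat{Q}(0, x_i)$; because $\hat{Q}(1, x)$ would also be a cell mean here, this equals the plug-in version of Equation (36.93). The data are hypothetical and mimic the restaurant-chain story: expensive chains are more common among treated units and have a worse trend, so unconditional parallel trends fails while conditional parallel trends holds.

```python
import numpy as np

def att_outcome_model(y_diff, a, x):
    """Expected-outcome-model estimate of the ATT for a discrete covariate x.

    y_diff = Y1 - Y0. Q(0, x) is estimated by the mean of y_diff among
    untreated units with the same x (requires overlap for every treated x).
    """
    treated = a == 1
    q0 = {c: y_diff[(a == 0) & (x == c)].mean() for c in np.unique(x[treated])}
    q0_treated = np.array([q0[c] for c in x[treated]])
    return np.mean(y_diff[treated] - q0_treated)

rng = np.random.default_rng(0)
n = 20_000
a = rng.integers(0, 2, size=n)
# Hypothetical chain type: 1 = expensive chain, more common among treated units.
x = rng.binomial(1, np.where(a == 1, 0.7, 0.3))
trend = np.where(x == 1, -2.0, 0.0)               # expensive chains decline more
tau = 0.5                                         # made-up effect on the treated
y_diff = trend + tau * a + rng.normal(size=n)

att_hat = att_outcome_model(y_diff, a, x)
unadjusted = y_diff[a == 1].mean() - y_diff[a == 0].mean()
print(f"adjusted ATT: {att_hat:.2f}, unadjusted DiD: {unadjusted:.2f}")
```

Swapping the cell means for a flexible regression model, or replacing the final average with a propensity-weighted or double machine learning combination, follows the same three-step recipe.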
36.7 Credibility Checks

Once we’ve chosen an identification strategy, fit our models, and produced an estimate, we’re faced with a basic question: should we believe it? Whether the reported estimate succeeds in capturing the true causal effect depends on whether the assumptions required for causal identification hold, the quality of the machine learning models, and the variability in the estimate due to only having access to a finite data sample. The latter two problems are already familiar from machine learning and statistical practice. We should, e.g., assess our models by checking performance on held out data, examining feature importance, and so forth. Similarly, we should report measures of the uncertainty due to finite samples (e.g., in the form of confidence intervals). Because these procedures are already familiar practice, we will not dwell on them further. However, model evaluation and uncertainty quantification are key parts of any credible causal analysis.

Assessing the validity of identification assumptions is trickier. First, there are assumptions that can in fact be checked from data. For example, overlap should be checked in analysis using backdoor adjustment or difference in differences, and relevance should be checked in the instrumental variable setting. Again, checking these conditions is absolutely necessary for a credible causal analysis. But, again, this involves only familiar data analysis, so we will not discuss it further. Next, there are the causal assumptions that cannot be verified from data; e.g., no unobserved confounding in backdoor adjustment, the exclusion restriction in IV, and conditional parallel trends in DiD. Ultimately, the validity of these assumptions must be assessed using substantive causal knowledge of the particular problem under consideration. However, it is possible to conduct some supplementary analyses that make the required judgement easier. We now discuss two such techniques.
32 


36.7.1 Placebo Checks


In many situations we may be able to find a variable that can be interpreted as a “treatment” that is known to have no effect on the outcome, but which we expect to be confounded with the outcome in a very similar fashion to the true treatment of interest. For example, if we’re trying to estimate the efficacy of a COVID vaccine in preventing symptomatic COVID, we might take our placebo treatment to be vaccination against HPV. We do not expect that there’s any causal effect here. However, it seems plausible that latent factors that cause an individual to seek (or avoid) HPV vaccination and COVID vaccination are similar; e.g., health conscientiousness, fear of needles, and so forth. Then, if our identification strategy is valid for the COVID vaccine, we’d also expect it to be valid for HPV vaccination. Accordingly, the estimation procedure we use for estimating the COVID effect should, when applied to HPV, yield $\hat{\tau} \approx 0$. Or, more precisely, the confidence interval should contain 0. If this does not happen, then we may suspect that there are still some confounding factors lurking that are not adequately handled by the identification procedure.

36.7. CREDIBILITY CHECKS


A similar procedure works when there is a variable that can be interpreted as an outcome which 
is known to not be affected by the treatment, but that shares confounders with the outcome we’re 
actually interested in. For example, in the COVID vaccination case, we might take the null outcome 
to be symptomatic COVID within 7 days of vaccination [Dag+21]. Our knowledge of both the 
biological mechanism of vaccination and the amount of time it takes to develop symptoms after
COVID infection (at least 2 days) leads us to conclude that it’s unlikely that the treatment has a
causal effect on the outcome. However, the properties of the treated people that affect how likely they 
are to develop symptomatic COVID are largely the same in the 7 day and, e.g., 6 month window. 
That includes factors such as risk aversion, baseline health, and so forth. Again, we can apply our 
identification strategy to estimate the causal effect of the treatment on the null outcome. If the 
confidence interval does not include 0, then we should doubt the credibility of the analysis. 


36.7.2 Sensitivity Analysis to Unobserved Confounding 


We now specialize to the case of estimating the average causal effect of a binary treatment by 
adjusting for confounding variables, as described in Section 36.4. In this case, causal identification 
is based on the assumption of ‘no unobserved confounding’; i.e., the assumption that the observed 
covariates include all common causes of the treatment assignment and outcome. This assumption is 
fundamentally untestable from observed data, but its violation can induce bias in the estimation of 
the treatment effect—the unobserved confounding may completely or in part explain the observed 
association. Our aim in this part is to develop a sensitivity analysis tool to aid in reasoning about 
potential bias induced by unobserved confounding. 

Intuitively, if we estimate a large positive effect then we might expect the real effect is also 
positive, even in the presence of mild unobserved confounding. For example, consider the association 
between smoking and lung cancer. One could argue that this association arises from a hormone that predisposes carriers to both an increased desire to smoke and to a greater risk of lung cancer. However, the association between smoking and lung cancer is large: is it plausible that some unknown hormonal association could have a strong enough influence to explain the association? Cornfield et al. [Cor+59] showed that, for a particular observational dataset, such an unmeasured hormone would need to increase the probability of smoking by at least a factor of nine. This is an unreasonable effect size for a hormone, so they concluded it’s unlikely the causal effect can be
explained away. 

We would like a general procedure to allow domain experts to make judgments about whether 
plausible confounding is “mild” relative to the “large” effect. In particular, the domain expert must 
translate judgments about the strength of the unobserved confounding into judgments about the 
bias induced in the estimate of the effect. Accordingly, we must formalize what is meant by strength 
of unobserved confounding, and show how to translate judgments about confounding strength into
judgments about bias. 

A prototypical example, due to Imbens [Imb03] (building on [RR83]), illustrates the broad approach. 
As above, the observed data consists of a treatment A, an outcome Y, and covariates X that may 
causally affect the treatment and outcome. Imbens [Imb03] then posits an additional unobserved 
binary confounder U for each patient, and supposes that the observed data and unobserved confounder 
were generated according to the following assumption, known as Imbens’ Sensitivity Model:

\begin{align}
U_i &\overset{\text{iid}}{\sim} \text{Bern}(1/2) \tag{36.96}\\
A_i \mid X_i, U_i &\overset{\text{ind}}{\sim} \text{Bern}(\text{sig}(\gamma X_i + \alpha U_i)) \tag{36.97}\\
Y_i \mid X_i, A_i, U_i &\overset{\text{ind}}{\sim} \mathcal{N}(\tau A_i + \beta X_i + \delta U_i, \sigma^2), \tag{36.98}
\end{align}

where sig is the sigmoid function.

If we had observed $U_i$, we could estimate $(\tau, \gamma, \beta, \alpha, \delta, \sigma^2)$ from the data and report τ as the estimate of the average treatment effect. Since $U_i$ is not observed, it is not possible to identify the parameters from the data. Instead, we make (subjective) judgments about plausible values of α (how strongly $U_i$ affects the treatment assignment) and δ (how strongly $U_i$ affects the outcome). Contingent on plausible $\alpha = \alpha^*$ and $\delta = \delta^*$, the other parameters can be estimated. This yields an estimate of the treatment effect $\hat{\tau}(\alpha^*, \delta^*)$ under the presumed values of the sensitivity parameters.
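To build intuition for how the sensitivity parameters translate into bias, the sketch below simulates from Equations (36.96) to (36.98) with made-up values of the other parameters (and the extra assumption that X is standard normal), and compares the least-squares coefficient on A with and without U in the regression. This is not Imbens’ estimation procedure, which fits the remaining parameters for fixed α* and δ*; it only shows, by brute-force simulation, that the bias from ignoring U appears when U affects both the treatment (α) and the outcome (δ), and vanishes when either is zero.

```python
import numpy as np

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def simulate_imbens(n, alpha, delta, tau=1.0, gamma=0.5, beta=0.5, sigma=1.0, seed=0):
    """Draw (X, U, A, Y) from Equations (36.96)-(36.98); X ~ N(0, 1) is an assumption."""
    rng = np.random.default_rng(seed)
    x = rng.normal(size=n)
    u = rng.binomial(1, 0.5, size=n)                                     # (36.96)
    a = rng.binomial(1, sigmoid(gamma * x + alpha * u))                  # (36.97)
    y = tau * a + beta * x + delta * u + rng.normal(0.0, sigma, size=n)  # (36.98)
    return x, u, a, y

def coef_on_first(columns, y):
    """OLS coefficient on columns[0], with an intercept and the other columns."""
    design = np.column_stack([np.ones(len(y))] + list(columns))
    coef, *_ = np.linalg.lstsq(design, y, rcond=None)
    return coef[1]

def regression_bias(alpha, delta, tau=1.0, n=50_000):
    """(bias ignoring U, bias adjusting for U) of the regression estimate of tau."""
    x, u, a, y = simulate_imbens(n, alpha, delta, tau=tau)
    return coef_on_first([a, x], y) - tau, coef_on_first([a, x, u], y) - tau

for alpha, delta in [(0.0, 2.0), (2.0, 0.0), (1.0, 1.0), (2.0, 2.0)]:
    b_ignore, b_adjust = regression_bias(alpha, delta)
    print(f"alpha={alpha}, delta={delta}: bias ignoring U {b_ignore:+.2f}, "
          f"adjusting for U {b_adjust:+.2f}")
```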

The approach just outlined has a major drawback: it relies on a parametric model for the full data generating process. The assumed model is equivalent to assuming that, had U been observed, it would have been appropriate to use logistic regression to model treatment assignment, and linear regression to model the outcome. This assumption also implies a simple, parametric model for the relationships governing the observed data. This restriction is out of step with modern practice, where we use flexible machine-learning methods to model these relationships. For example, the assumption forbids the use of neural networks or random forests, though such methods are often state-of-the-art for causal effect estimation.

Austen plots. We now turn to an adaptation of Imbens’ approach that fully decouples sensitivity analysis and modeling of the observed data, namely the Austen plots of [VZ20]. An example Austen plot is shown in Figure 36.7. The high-level idea is to posit a generative model that uses a simple, interpretable parametric form for the influence of the unobserved confounder, but that puts no constraints on the model for the observed data. We then use the parametric part of the model to formalize “confounding strength” and to compute the induced bias as a function of the confounding.

Austen plots further adapt two strategies pioneered by Imbens [Imb03]. First, we find a parameterization of the model so that the sensitivity parameters, measuring strength of confounding, are on a standardized, unitless scale. This allows us to compare the strength of hypothetical unobserved confounding to the strength of observed covariates, measured from data. Second, we plot the curve of all values of the sensitivity parameter that would yield a given level of bias. This moves the analyst’s judgment from “what are plausible values of the sensitivity parameters?” to “are sensitivity parameters this extreme plausible?”

Figure 36.7, an Austen plot for an observational study of the effect of combination medications on diastolic blood pressure, illustrates the idea. A bias of 2 would suffice to undermine the qualitative conclusion that the blood-pressure treatment is effective. Examining the plot, an unobserved confounder as strong as age could induce this amount of bias, but no other (group of) observed confounders has so much influence. Accordingly, if a domain expert thinks an unobserved confounder as strong as age is unlikely, then they may conclude that the treatment is likely effective. Or, if such a confounder is plausible, they may conclude that the study fails to establish efficacy.
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Figure 36.7: Austen plot showing how strong an unobserved confounder would need to be to induce a bias 
of 2 in an observational study of the effect of combination blood pressure medications on diastolic blood 
pressure [Dor+16]. We chose this bias to equal the nominal average treatment effect estimated from the data. 
We model the outcome with Bayesian Additive Regression Trees and the treatment assignment with logistic 
regression. The curve shows all values treatment and outcome influence that would induce a bias of 2. The 
colored dots show the influence strength of (groups of) observed covariates, given all other covariates. For 
example, an unobserved confounder with as much influence as the patient’s age might induce a bias of about 
2. 


Setup The data are generated independently and identically, $(Y_i, A_i, X_i, U_i) \overset{\text{iid}}{\sim} P$, where $U_i$ is not observed and $P$ is some unknown probability distribution. The approach in Section 36.4 assumes that the observed covariates $X$ contain all common causes of $Y$ and $A$. If this ‘no unobserved confounding’ assumption holds, then the ATE is equal to the parameter $\tau$ of the observed data distribution, where

$$\tau = \mathbb{E}\big[\mathbb{E}[Y\mid X, A=1] - \mathbb{E}[Y\mid X, A=0]\big]. \tag{36.99}$$

This observational parameter is then estimated from a finite data sample. Recall from Section 36.4 that this involves estimating the conditional expected outcome $Q(A,X) = \mathbb{E}[Y\mid A,X]$ and the propensity score $g(X) = P(A=1\mid X)$, then plugging these into an estimator $\hat\tau$.

We are now concerned with the case of possible unobserved confounding, that is, where $U$ causally affects both $Y$ and $A$. If there is unobserved confounding, then the parameter $\tau$ is not equal to the ATE, so $\hat\tau$ is a biased estimate. Inference about the ATE then divides into two tasks. First, the statistical task: estimating $\tau$ as accurately as possible from the observed data. And, second, the causal (domain-specific) problem of assessing bias $= \mathrm{ATE} - \tau$. We emphasize that our focus here is bias due to causal misidentification, not the statistical bias of the estimator. Our aim is to reason about the bias induced by unobserved confounding (the second task) in a way that imposes no constraints on the modeling choices for $Q$, $g$ and $\hat\tau$ used in the statistical analysis.
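A small simulation can make the distinction between $\tau$ and the ATE concrete. The sketch below assumes a hypothetical linear data-generating process (one observed covariate $X$, one unobserved confounder $U$, logistic treatment assignment) and uses ordinary least squares as a stand-in plug-in estimator; none of these choices come from the text. Adjusting for $X$ alone targets $\tau$, which here differs from the true ATE of 1, while the infeasible regression that also adjusts for $U$ recovers it.

```python
import numpy as np

def simulate(n, rng):
    """Hypothetical data-generating process (illustration only): U confounds A and Y."""
    x = rng.normal(size=n)                           # observed covariate
    u = rng.normal(size=n)                           # unobserved confounder
    p = 1.0 / (1.0 + np.exp(-(x + u)))              # propensity depends on both X and U
    a = rng.binomial(1, p)                           # treatment
    y = 1.0 * a + x + 2.0 * u + rng.normal(size=n)   # true ATE is 1.0
    return x, u, a, y

def regression_effect(y, a, covariates):
    """Coefficient on A from an OLS fit of Y on [1, A, covariates]."""
    design = np.column_stack([np.ones_like(y), a, *covariates])
    coef, *_ = np.linalg.lstsq(design, y, rcond=None)
    return coef[1]

rng = np.random.default_rng(0)
x, u, a, y = simulate(20_000, rng)
tau_hat = regression_effect(y, a, [x])     # feasible: adjusts for X only
oracle = regression_effect(y, a, [x, u])   # infeasible in practice: also adjusts for U
print(f"tau_hat = {tau_hat:.2f}, oracle ATE estimate = {oracle:.2f}")
```

Because $U$ pushes both $A$ and $Y$ upward here, the $X$-only estimate is biased upward; with a different sign for either effect of $U$ the bias would flip.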


Sensitivity Model Our sensitivity analysis should impose no constraints on how the observed data is modeled. However, sensitivity analysis demands some assumption on the relationship between the observed data and the unobserved confounder. It is convenient to formalize such assumptions by specifying a probabilistic model for how the data is generated. The strength of confounding is then formalized in terms of the parameters of the model (the sensitivity parameters). Then, the bias induced by the confounding can be derived from the assumed model. Our task is to posit a generative model that both yields a useful and easily interpretable sensitivity analysis, and that avoids imposing any assumptions about the observed data.
To begin, consider the functional form of the sensitivity model used by Imbens [Imb03]:

$$\mathrm{logit}\, P(A=1\mid x,u) = h(x) + \alpha u, \tag{36.100}$$
$$\mathbb{E}[Y\mid a,x,u] = l(a,x) + \delta u, \tag{36.101}$$

for some functions $h$ and $l$. That is, the propensity score is logit-linear in the unobserved confounder, and the conditional expected outcome is linear in it.

By rearranging Equation (36.100) to solve for $u$ and plugging into Equation (36.101), we see that it is equivalent to assume $\mathbb{E}[Y\mid a,x,u] = \tilde l(a,x) + \tilde\delta\, \mathrm{logit}\, P(A=1\mid x,u)$, where $\tilde l(a,x) = l(a,x) - (\delta/\alpha)\, h(x)$ and $\tilde\delta = \delta/\alpha$. That is, the unobserved confounder $u$ only influences the outcome through the propensity score. Accordingly, by positing a distribution on $P(A=1\mid x,u)$ directly, we can circumvent the need to explicitly articulate $U$ (and $h$).


Definition 36.7.1. Let $\tilde g(x,u) = P(A=1\mid x,u)$ denote the propensity score given observed covariates $x$ and the unobserved confounder $u$.

The insight is that we can posit a sensitivity model by defining a distribution on $\tilde g$ directly. We choose:

$$\tilde g(X,U)\mid X \sim \mathrm{Beta}\big(g(X)(1/\alpha - 1),\ (1 - g(X))(1/\alpha - 1)\big).$$

That is, the full propensity score $\tilde g(X,U)$ for each unit is assumed to be sampled from a Beta distribution centered at the observed propensity score $g(X)$. The sensitivity parameter $\alpha$ plays the same role as in Imbens’ model: it controls the influence of the unobserved confounder $U$ on treatment assignment. When $\alpha$ is close to 0, $\tilde g(X,U)\mid X$ is tightly concentrated around $g(X)$, and the unobserved confounder has little influence. That is, $U$ minimally affects our belief about who is likely to receive treatment. Conversely, when $\alpha$ is close to 1, $\tilde g$ concentrates near 0 and 1; i.e., knowing $U$ would let us accurately predict treatment assignment. Indeed, it can be shown that $\alpha$ is the change in our belief about how likely a unit was to have gotten the treatment, given that they were actually observed to be treated (or not):

$$\alpha = \mathbb{E}[\tilde g(X,U)\mid A=1] - \mathbb{E}[\tilde g(X,U)\mid A=0]. \tag{36.102}$$
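Equation (36.102) can be checked by Monte Carlo at a single covariate value $x$. The sketch below uses illustrative values $g(x) = 0.3$ and $\alpha = 0.2$ (assumptions, not from the text); note that the two Beta parameters sum to $1/\alpha - 1$, so a smaller $\alpha$ means a more concentrated Beta.

```python
import numpy as np

rng = np.random.default_rng(1)
g, alpha, n = 0.3, 0.2, 200_000        # illustrative g(x) and alpha at a single x
conc = 1.0 / alpha - 1.0               # the two Beta parameters sum to 1/alpha - 1
g_full = rng.beta(g * conc, (1.0 - g) * conc, size=n)  # draws of g~(x, U)
treated = rng.random(n) < g_full       # A | x, U ~ Bern(g~(x, U))

print(g_full.mean())                   # close to g = 0.3 (the Beta is centered at g(x))
gap = g_full[treated].mean() - g_full[~treated].mean()
print(gap)                             # close to alpha = 0.2, as in Equation (36.102)
```

Because the Beta is conjugate to the Bernoulli, observing $A=1$ adds one to the first Beta parameter and observing $A=0$ adds one to the second, so with parameters $a$, $b$ the exact gap is $1/(a+b+1) = \alpha$.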


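With the $\tilde g$ model in hand, we define the Austen sensitivity model as follows:

$$\tilde g(X,U)\mid X \sim \mathrm{Beta}\big(g(X)(1/\alpha - 1),\ (1 - g(X))(1/\alpha - 1)\big) \tag{36.103}$$
$$A\mid X,U \sim \mathrm{Bern}(\tilde g(X,U)) \tag{36.104}$$
$$\mathbb{E}[Y\mid A,X,U] = Q(A,X) + \delta\big(\mathrm{logit}\,\tilde g(X,U) - \mathbb{E}[\mathrm{logit}\,\tilde g(X,U)\mid A,X]\big). \tag{36.105}$$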


This model has been constructed to satisfy the requirement that the propensity score and conditional expected outcome are the $g$ and $Q$ actually present in the observed data:

$$P(A=1\mid X) = \mathbb{E}\big[\mathbb{E}[A\mid X,U]\mid X\big] = \mathbb{E}[\tilde g(X,U)\mid X] = g(X),$$
$$\mathbb{E}[Y\mid A,X] = \mathbb{E}\big[\mathbb{E}[Y\mid A,X,U]\mid A,X\big] = Q(A,X).$$


The sensitivity parameters are $\alpha$, controlling the dependence between the unobserved confounder and the treatment assignment, and $\delta$, controlling the relationship with the outcome.
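The sketch below simulates Equations (36.103) to (36.105) at a single covariate value and checks that the observed-data quantities come out as $g$ and $Q$. The values of $g(x)$, $Q(0,x)$, $Q(1,x)$, $\alpha$, $\delta$, and the unit-variance Gaussian noise on $Y$ are hypothetical choices. The conditional mean $\mathbb{E}[\mathrm{logit}\,\tilde g\mid A,x]$ uses the Beta posterior after observing $A$ together with the identity $\mathbb{E}[\mathrm{logit}\,p] = \psi(a) - \psi(b)$ for $p \sim \mathrm{Beta}(a,b)$, with the digamma function $\psi$ taken from `scipy.special.digamma`.

```python
import numpy as np
from scipy.special import digamma, logit

rng = np.random.default_rng(2)
n, g, alpha, delta = 400_000, 0.4, 0.3, 2.0   # illustrative values at one covariate value x
Q = {0: 1.0, 1: 3.0}                           # hypothetical Q(0, x) and Q(1, x)

conc = 1.0 / alpha - 1.0                       # the Beta parameters sum to 1/alpha - 1
a_par, b_par = g * conc, (1.0 - g) * conc
g_full = rng.beta(a_par, b_par, size=n)        # Equation (36.103)
a = (rng.random(n) < g_full).astype(int)       # Equation (36.104)

# E[logit g~ | A, x]: after observing A the Beta becomes Beta(a_par + A, b_par + 1 - A),
# and E[logit p] = digamma(first param) - digamma(second param) for a Beta variable p
post_mean_logit = digamma(a_par + a) - digamma(b_par + 1 - a)
y_mean = np.where(a == 1, Q[1], Q[0]) + delta * (logit(g_full) - post_mean_logit)  # (36.105)
y = y_mean + rng.normal(size=n)                # hypothetical unit-variance outcome noise

print(a.mean())                                # close to g = 0.4
print(y[a == 1].mean(), y[a == 0].mean())      # close to Q(1, x) = 3 and Q(0, x) = 1
```

The centering term in (36.105) is what makes the confounder's contribution average to zero within each treatment arm, so the observed arm means stay at $Q$ no matter how large $\delta$ is.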


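Bias We now turn to calculating the bias induced by unobserved confounding. By assumption, $X$ and $U$ together suffice to render the average treatment effect identifiable as:

$$\mathrm{ATE} = \mathbb{E}\big[\mathbb{E}[Y\mid A=1,X,U] - \mathbb{E}[Y\mid A=0,X,U]\big].$$

Plugging in our sensitivity model yields

$$\mathrm{ATE} = \mathbb{E}[Q(1,X) - Q(0,X)] + \delta\, \mathbb{E}_X\big[\mathbb{E}[\mathrm{logit}\,\tilde g(X,U)\mid X, A=1] - \mathbb{E}[\mathrm{logit}\,\tilde g(X,U)\mid X, A=0]\big].$$

The first term is the observed-data parameter $\tau$, so

$$\mathrm{bias} = \delta\, \mathbb{E}_X\big[\mathbb{E}[\mathrm{logit}\,\tilde g(X,U)\mid X, A=1] - \mathbb{E}[\mathrm{logit}\,\tilde g(X,U)\mid X, A=0]\big].$$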


Then, by invoking Beta-Bernoulli conjugacy and standard Beta identities,¹¹ we arrive at the following result.

Theorem 6. Under the Austen sensitivity model, Equation (36.105), an unobserved confounder with influence $\alpha$ and $\delta$ induces bias in the estimated treatment effect equal to

$$\mathrm{bias} = \frac{\delta}{1/\alpha - 1}\, \mathbb{E}\left[\frac{1}{g(X)} + \frac{1}{1 - g(X)}\right].$$


That is, the amount of bias is determined by the sensitivity parameters and by the realized propensity score. Notice that more extreme propensity scores lead to more extreme bias in response to unobserved confounding. This means, in particular, that conditioning on a covariate that affects the treatment but that does not directly affect the outcome (an instrument) will increase any bias due to unobserved confounding. This general phenomenon is known as z-bias.
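Theorem 6 is straightforward to evaluate once propensity scores are available. This sketch assumes NumPy and SciPy and replaces the expectation over $X$ by a sample mean over supplied propensity scores. It implements the closed form and, as a check, the digamma expression it comes from; the two agree because of the recurrence $\psi(x+1) - \psi(x) = 1/x$.

```python
import numpy as np
from scipy.special import digamma

def austen_bias(alpha, delta, g):
    """Closed form of Theorem 6; g holds propensity scores g(x_i)."""
    g = np.asarray(g, dtype=float)
    return delta / (1.0 / alpha - 1.0) * np.mean(1.0 / g + 1.0 / (1.0 - g))

def austen_bias_digamma(alpha, delta, g):
    """Same quantity as delta * E[E[logit g~|X,A=1] - E[logit g~|X,A=0]],
    using E[logit p] = digamma(a) - digamma(b) for p ~ Beta(a, b)."""
    g = np.asarray(g, dtype=float)
    conc = 1.0 / alpha - 1.0
    a, b = g * conc, (1.0 - g) * conc
    diff = (digamma(a + 1) - digamma(b)) - (digamma(a) - digamma(b + 1))
    return delta * np.mean(diff)

g = np.array([0.1, 0.5, 0.9])
print(austen_bias(0.1, 1.0, g), austen_bias_digamma(0.1, 1.0, g))  # both about 0.971

# z-bias: a more extreme propensity score gives more bias for the same alpha and delta
print(austen_bias(0.1, 1.0, [0.02]), austen_bias(0.1, 1.0, [0.5]))
```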


11. We also use the recurrence relation $\psi(x + 1) - \psi(x) = 1/x$, where $\psi$ is the digamma function.
Sensitivity Parameters The Austen model provides a formalization of confounding strength in terms of the parameters $\alpha$ and $\delta$ and tells us how much bias is induced by a given strength of confounding. This lets us translate judgments about confounding strength to judgments about bias. However, it is not immediately obvious how to translate qualitative judgements such as “I think any unobserved confounder would be much less important than Age” to judgements about the possible values of the sensitivity parameters.

First, because the scale of $\delta$ is not fixed, it may be difficult to compare the influence of potential unobserved confounders to the influence of reference variables. To resolve this, we reexpress the outcome-confounder strength in terms of the (non-parametric) partial coefficient of determination:

$$R^2_{Y\cdot U\mid A,X}(\alpha,\delta) = 1 - \frac{\mathbb{E}\big[(Y - \mathbb{E}[Y\mid A,X,U])^2\big]}{\mathbb{E}\big[(Y - Q(A,X))^2\big]}.$$


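The key to computing the reparameterization is the following result.

Theorem 7. Under the Austen sensitivity model, Equation (36.105), the outcome influence is

$$R^2_{Y\cdot U\mid A,X}(\alpha,\delta) = \frac{\delta^2}{\mathbb{E}\big[(Y - Q(A,X))^2\big]}\, \mathbb{E}\Big[\psi_1\big(g(X)(1/\alpha - 1) + \mathbb{1}[A=1]\big) + \psi_1\big((1 - g(X))(1/\alpha - 1) + \mathbb{1}[A=0]\big)\Big],$$

where $\psi_1$ is the trigamma function.

See Veitch and Zaveri [VZ20] for the proof.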
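Theorem 7 can be evaluated in the same plug-in fashion. The sketch below assumes NumPy and SciPy (`polygamma(1, x)` is the trigamma function $\psi_1$). The `resid_sq_mean` argument is a hypothetical stand-in for an estimate of $\mathbb{E}[(Y - Q(A,X))^2]$, such as the mean squared residual of the fitted outcome model. The block also checks by Monte Carlo the identity $\mathrm{Var}(\mathrm{logit}\,p) = \psi_1(a) + \psi_1(b)$ for $p \sim \mathrm{Beta}(a,b)$, which is where the trigamma terms come from (given $A$ and $X$, $\tilde g$ is again Beta distributed).

```python
import numpy as np
from scipy.special import polygamma

def trigamma(x):
    return polygamma(1, x)

def outcome_r2(alpha, delta, g, a, resid_sq_mean):
    """Theorem 7 with the expectation over (A, X) replaced by a sample mean.
    g: propensities g(x_i); a: treatments a_i in {0, 1};
    resid_sq_mean: an estimate of E[(Y - Q(A, X))^2]."""
    g = np.asarray(g, dtype=float)
    a = np.asarray(a, dtype=float)
    conc = 1.0 / alpha - 1.0
    var_logit = trigamma(g * conc + a) + trigamma((1.0 - g) * conc + (1.0 - a))
    return delta**2 * np.mean(var_logit) / resid_sq_mean

# Monte Carlo check: Var(logit p) = trigamma(a) + trigamma(b) for p ~ Beta(a, b)
rng = np.random.default_rng(3)
p = rng.beta(2.0, 3.0, size=500_000)
mc_var = np.var(np.log(p / (1.0 - p)))
exact_var = trigamma(2.0) + trigamma(3.0)
print(mc_var, exact_var)  # the two numbers should be close

r2 = outcome_r2(alpha=0.1, delta=0.5, g=[0.3, 0.6], a=[1, 0], resid_sq_mean=1.0)
print(r2)
```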

By design, $\alpha$ (the strength of confounding influence on treatment assignment) is already on a fixed, unitless scale. However, because the measure is tied to the model, it may be difficult to interpret, and it is not obvious how to compute reference confounding strength values from the observed data. The next result clarifies these issues.


Theorem 8. Under the Austen sensitivity model, Equation (36.105),

$$\alpha = 1 - \frac{\mathbb{E}\big[\tilde g(X,U)(1 - \tilde g(X,U))\big]}{\mathbb{E}\big[g(X)(1 - g(X))\big]}.$$

See Veitch and Zaveri [VZ20] for the proof. That is, the sensitivity parameter $\alpha$ measures how much more extreme the propensity scores become when we condition on $U$. In other words, $\alpha$ is a measure of the extra predictive power $U$ adds for $A$, above and beyond the predictive power in $X$. It may also be insightful to notice that

$$\alpha = 1 - \frac{\mathbb{E}\big[(A - \tilde g(X,U))^2\big]}{\mathbb{E}\big[(A - g(X))^2\big]} = R^2_{A\cdot U\mid X}.$$

That is, $\alpha$ is just the (non-parametric) partial coefficient of determination of $U$ on $A$, the same measure used for the outcome influence. (To see this, just expand the expectations conditional on $A=1$ and $A=0$.)
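Both expressions for $\alpha$ can be verified by simulation. The sketch draws hypothetical propensity scores $g(x_i)$ uniformly on $[0.1, 0.9]$, samples $\tilde g$ and $A$ from Equations (36.103) and (36.104) with $\alpha = 0.25$, and recomputes $\alpha$ from sample averages; the distributional choices are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(4)
n, alpha = 400_000, 0.25
g = rng.uniform(0.1, 0.9, size=n)                 # hypothetical observed propensities g(x_i)
conc = 1.0 / alpha - 1.0
g_full = rng.beta(g * conc, (1.0 - g) * conc)     # g~(x_i, u_i) from Equation (36.103)
a = (rng.random(n) < g_full).astype(float)        # A from Equation (36.104)

# Theorem 8: relative shrinkage of g~(1 - g~) versus g(1 - g), i.e. how much more extreme g~ is
alpha_thm8 = 1.0 - np.mean(g_full * (1.0 - g_full)) / np.mean(g * (1.0 - g))
# Partial R^2 form: 1 - E[(A - g~)^2] / E[(A - g)^2]
alpha_r2 = 1.0 - np.mean((a - g_full) ** 2) / np.mean((a - g) ** 2)
print(alpha_thm8, alpha_r2)                       # both close to alpha = 0.25
```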
Estimating bias In combination, Theorems 6 and 7 yield an expression for the bias in terms of $\alpha$ and $R^2_{Y\cdot U\mid A,X}$. In practice, we can estimate the bias induced by confounding by fitting models for $Q$ and $g$ and replacing the expectations by means over the data.
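A plug-in version of the combined result might look like the sketch below: invert Theorem 7 for $\delta$ at a chosen $(\alpha, R^2_{Y\cdot U\mid A,X})$, then apply Theorem 6, with all expectations replaced by sample means. The inputs `g_hat`, `a`, and `resid` are hypothetical stand-ins for the fitted propensity scores, the observed treatments, and the outcome-model residuals; in a real analysis they would come from the fitted $g$ and $Q$. The nonnegative root for $\delta$ is taken, since flipping the sign of $\delta$ only flips the sign of the bias.

```python
import numpy as np
from scipy.special import polygamma

def bias_from_alpha_r2(alpha, r2, g_hat, a, resid):
    """Plug-in bias for sensitivity parameters (alpha, R^2_{Y.U|A,X}).
    g_hat: fitted propensities; a: treatments in {0, 1};
    resid: residuals y_i - Q_hat(a_i, x_i)."""
    g_hat, a, resid = (np.asarray(v, dtype=float) for v in (g_hat, a, resid))
    conc = 1.0 / alpha - 1.0
    trig = polygamma(1, g_hat * conc + a) + polygamma(1, (1.0 - g_hat) * conc + (1.0 - a))
    delta = np.sqrt(r2 * np.mean(resid**2) / np.mean(trig))           # Theorem 7 solved for delta
    return delta / conc * np.mean(1.0 / g_hat + 1.0 / (1.0 - g_hat))  # Theorem 6

# Hypothetical stand-ins for fitted-model outputs on 1,000 units
rng = np.random.default_rng(7)
g_hat = rng.uniform(0.2, 0.8, size=1_000)
a = (rng.random(1_000) < g_hat).astype(float)
resid = rng.normal(scale=2.0, size=1_000)

for r2 in (0.05, 0.2):
    print(r2, bias_from_alpha_r2(0.1, r2, g_hat, a, resid))
```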


36.7.2.1 Calibration using observed data 


The analyst must make judgments about the influence a hypothetical unobserved confounder might have on treatment assignment and outcome. To calibrate such judgments, we’d like to have a reference point for how much the observed covariates influence the treatment assignment and outcome. In the sensitivity model, the degree of influence is measured by the partial $R^2_{Y\cdot U\mid A,X}$ and $\alpha$. We want to measure the degree of influence of an observed covariate $Z$ given the other observed covariates $X\setminus Z$.

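For the outcome, this can be measured as:

$$R^2_{Y\cdot Z\mid A,X\setminus Z} = 1 - \frac{\mathbb{E}\big[(Y - Q(A,X))^2\big]}{\mathbb{E}\big[(Y - \mathbb{E}[Y\mid A, X\setminus Z])^2\big]}.$$

In practice, we can estimate the quantity by fitting a new regression model $\hat Q_Z$ that predicts $Y$ from $A$ and $X\setminus Z$. Then we compute

$$\hat R^2_{Y\cdot Z\mid A,X\setminus Z} = 1 - \frac{\frac{1}{n}\sum_i \big(y_i - \hat Q(a_i, x_i)\big)^2}{\frac{1}{n}\sum_i \big(y_i - \hat Q_Z(a_i, x_i\setminus z_i)\big)^2}.$$

Using Theorem 8, we can measure the influence of observed covariate $Z$ on treatment assignment given $X\setminus Z$ in an analogous fashion to the outcome. We define $g_{X\setminus Z}(X\setminus Z) = P(A=1\mid X\setminus Z)$, then fit a model for $g_{X\setminus Z}$ by predicting $A$ from $X\setminus Z$, and estimate

$$\hat\alpha_{Z\mid X\setminus Z} = 1 - \frac{\frac{1}{n}\sum_i \hat g(x_i)\big(1 - \hat g(x_i)\big)}{\frac{1}{n}\sum_i \hat g_{X\setminus Z}(x_i\setminus z_i)\big(1 - \hat g_{X\setminus Z}(x_i\setminus z_i)\big)}.$$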

Grouping covariates The estimated values $\hat\alpha_{Z\mid X\setminus Z}$ and $\hat R^2_{Y\cdot Z\mid A,X\setminus Z}$ measure the influence of $Z$ conditioned on all the other confounders. In some cases, this can be misleading. For example, if some piece of information is important but there are multiple covariates providing redundant measurements, then the estimated influence of each covariate will be small. To avoid this, group together related or strongly dependent covariates and compute the influence of the entire group in aggregate; for example, treat income, location, and race together as ‘socioeconomic variables’.


36.7.2.2 Practical Use 


We now have sufficient results to produce Austen plots such as Figure 36.7. At a high level, the 
procedure is: 


1. Produce an estimate $\hat\tau$ using any modeling tools. As a component of this, estimate the propensity score $g$ and conditional outcome model $Q$.

2. Pick a level of bias that would suffice to change the qualitative interpretation of the estimate (e.g., the lower bound of a 95% confidence interval).

3. Plot the values of $\alpha$ and $R^2_{Y\cdot U\mid A,X}$ that would suffice to induce that much bias. This is the black curve on the plot. To calculate these values, use Theorems 6 and 7 together with the estimated $g$ and $Q$.

4. Finally, compute reference influence levels for (groups of) observed covariates. In particular, this requires fitting reduced models for the conditional expected outcome and propensity that do not use the reference covariate as a feature.
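Step 3 amounts to solving Theorem 6 for $\delta$ at each $\alpha$ on a grid and substituting into Theorem 7. The sketch below assumes NumPy and SciPy and uses simulated stand-ins for the fitted propensities, treatments, and residuals from step 1; the grid and the target bias of 2 (matching Figure 36.7) are illustrative. Plotting `alphas` against `r2`, for example with Matplotlib, would give a curve of the kind drawn in black in Figure 36.7; values of $R^2$ above 1 are not attainable, so those grid points would be dropped.

```python
import numpy as np
from scipy.special import polygamma

def austen_curve(target_bias, g_hat, a, resid, alphas):
    """Outcome partial R^2 needed, for each alpha, to induce `target_bias`.
    g_hat: fitted propensities; a: treatments in {0, 1}; resid: y_i - Q_hat(a_i, x_i)."""
    g_hat, a, resid = (np.asarray(v, dtype=float) for v in (g_hat, a, resid))
    r2 = []
    for alpha in alphas:
        conc = 1.0 / alpha - 1.0
        # Theorem 6 solved for delta
        delta = target_bias * conc / np.mean(1.0 / g_hat + 1.0 / (1.0 - g_hat))
        # Theorem 7 evaluated at that delta
        trig = polygamma(1, g_hat * conc + a) + polygamma(1, (1.0 - g_hat) * conc + (1.0 - a))
        r2.append(delta**2 * np.mean(trig) / np.mean(resid**2))
    return np.array(r2)

# Hypothetical stand-ins for the outputs of step 1
rng = np.random.default_rng(8)
g_hat = rng.uniform(0.2, 0.8, size=2_000)
a = (rng.random(2_000) < g_hat).astype(float)
resid = rng.normal(scale=3.0, size=2_000)

alphas = np.linspace(0.05, 0.9, 18)
r2 = austen_curve(2.0, g_hat, a, resid, alphas)
feasible = r2 <= 1.0   # only these (alpha, R^2) pairs can appear on the plot
print(np.column_stack([alphas[feasible], r2[feasible]]))
```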


In practice, an analyst only needs to do the model fitting parts themselves. The bias calculations, reference value calculations, and plotting can be done automatically with standard libraries.¹²

Austen plots are predicated on Equation (36.105). This assumption replaces the purely parametric Equation (36.98) with a version that eliminates any parametric requirements on the observed data. However, we emphasize that Equation (36.105) does, implicitly, impose some parametric assumption on the structural causal relationship between $U$ and $A, Y$. Ultimately, any conclusion drawn from the sensitivity analysis depends on this assumption, which is not justified on any substantive grounds. Accordingly, such sensitivity analyses can only be used to informally guide domain experts. They do not circumvent the need to thoroughly adjust for confounding. This reliance on a structural assumption is a generic property of sensitivity analysis.¹³ Indeed, there are now many sensitivity analysis models that allow the use of any machine learning model in the data analysis [e.g., RRS00; FDF19; She+11; HS13; BK19; Ros10; Yad+18; ZSB19; Sch+21a]. However, none of these are yet in routine use in practice. We have presented Austen plots here not because they make an especially virtuous modeling assumption, but because they are (relatively) easy to understand and interpret.

Austen plots are most useful in situations where the conclusion from the plot would be ‘obvious’ to a domain expert. For instance, in Figure 36.7, we can be confident that an unobserved confounder similar to socioeconomic status would not induce enough bias to change the qualitative conclusion. By contrast, Austen plots should not be used to draw conclusions such as, “I think a latent confounder could only be 90% as strong as ‘age’, so there is evidence of a small non-zero effect”. Such nuanced conclusions might depend on issues such as the particular sensitivity model we use, or finite-sample variation of our bias and influence estimates, or on incautious interpretation of the calibration dots. These issues are subtle, and it would be difficult to resolve them to a sufficient degree that a sensitivity analysis would make an analysis credible.
Calibration using observed data The interpretation of the observed-data calibration requires some care. The sensitivity analysis requires the analyst to make judgements about the strength of influence of the unobserved confounder $U$, conditional on the observed covariates $X$. However, we report the strength of influence of observed covariate(s) $Z$, conditional on the other observed covariates $X\setminus Z$. The difference in conditioning sets can have subtle effects.

Cinelli and Hazlett [CH20] give an example where $Z$ and $U$ are identical variables in the true model, but where the influence of $U$ given $A, X$ is larger than the influence of $Z$ given $A, X\setminus Z$. (The influence of $Z$ given $X\setminus Z, U$ would be the same as the influence of $U$ given $X$.) Accordingly, an analyst is not justified in a judgement such as, “I know that U and Z are very similar. I see Z has substantial influence, but the dot is below the line. Thus, U will not undo the study conclusions.” In essence, if the domain expert suspects a strong interaction between $U$ and $Z$, then naively eyeballing the dot-vs-line position may be misleading. A particularly subtle case is when $U$ and $Z$ are independent variables that both strongly influence $A$ and $Y$. The joint influence on $A$ creates an interaction effect between them when $A$ is conditioned on (the treatment is a collider). This affects the interpretation of $R^2_{Y\cdot U\mid X,A}$. Indeed, we should generally be skeptical of sensitivity analysis interpretation when it is expected that a strong confounder has been omitted. In such cases, our conclusions may depend substantively on the particular form of our sensitivity model, or other unjustifiable assumptions.

12. See github.com/vveitch/causality-tutorials/blob/main/SensitivityAnalysis.ipynb.

13. In extreme cases, there can be so little unexplained variation in A or Y that only a very weak confounder could be compatible with the data. In this case, essentially assumption-free sensitivity analysis is possible [Man90].

36.8. THE DO CALCULUS

Although the interaction problem is conceptually important, its practical significance is unclear. We often expect the opposite effect: if $U$ and $Z$ are dependent (e.g., race and wealth), then omitting $U$ should increase the apparent importance of $Z$, leading to a conservative judgement (a dot artificially towards the top right part of the plot).


36.8 The Do Calculus 


We have seen several strategies for identifying causal effects as parameters of observational distribu- 
tions. Confounder adjustment (Section 36.4) relied only on the assumed causal graph (and overlap), 
which specified that we observe all common causes of A and Y. On the other hand, instrumental 
variable methods and difference-in-differences each relied on both an assumed causal graph and 
partial functional form assumptions about the underlying structural causal model. Because functional 
form assumptions can be quite difficult to justify on substantive grounds, it’s natural to ask when 
causal identification is possible from the causal graph alone. That is, when can we be agnostic to the 
particular functional form of the structural causal models? 

There is a general “calculus of intervention”, known as the do-calculus, that gives a general 
recipe for determining when the causal assumptions expressed in a causal graph can be used to 
identify causal effects [Pea09c]. The do-calculus is a set of three rewrite rules that allows us to 
replace statements where we condition on variables being set by intervention, e.g. P(Y|do(A = a)), 
with statements involving only observational quantities, e.g. $\mathbb{E}_X[P(Y\mid A=a, X)]$. When causal
identification is possible, we can repeatedly apply the three rules to boil down our target causal 
parameter into an expression involving only the observational distribution. 


36.8.1 The three rules 


To express the rules, let X, Y, Z, and W be arbitrary disjoint sets of variables in a causal DAG G. 


Rule 1 The first rule allows us to insert or delete observations $z$:

$$p(y\mid \mathrm{do}(x), z, w) = p(y\mid \mathrm{do}(x), w) \quad \text{if } (Y \perp Z\mid X, W)_{G_{\overline{X}}}, \tag{36.107}$$

where $G_{\overline{X}}$ denotes cutting edges going into $X$, and $(Y \perp Z\mid X, W)_{G_{\overline{X}}}$ denotes conditional independence in the mutilated graph. The rule follows from d-separation in the mutilated graph. This rule just says that conditioning on irrelevant variables leaves the distribution invariant (as we would expect).


Rule 2 The second rule allows us to replace do(z) with conditioning on (seeing) z. The simplest case where we can do this is: if $Z$ is a root of the causal graph (i.e., it has no causal parents) then $p(y\mid \mathrm{do}(z)) = p(y\mid z)$. The reason is that the do operator is equivalent to conditioning in the mutilated causal graph where all the edges into $Z$ are removed, but, because $Z$ is a root, the mutilated graph is just the original causal graph. The general form of this rule is:

$$p(y\mid \mathrm{do}(x), \mathrm{do}(z), w) = p(y\mid \mathrm{do}(x), z, w) \quad \text{if } (Y \perp Z\mid X, W)_{G_{\overline{X}\underline{Z}}}, \tag{36.108}$$

where $G_{\overline{X}\underline{Z}}$ cuts edges going into $X$ and out of $Z$. Intuitively, we can replace do(z) by z as long as there are no backdoor (non-directed) paths between $z$ and $y$. If there are in fact no such paths, then cutting all the edges going out of $Z$ will mean there are no paths connecting $Z$ and $Y$, so that $Y \perp Z$. The rule just generalizes this line of reasoning to allow for extra observed and intervened variables.


Rule 3 The third rule allows us to insert or delete actions do(z):

$$p(y\mid \mathrm{do}(x), \mathrm{do}(z), w) = p(y\mid \mathrm{do}(x), w) \quad \text{if } (Y \perp Z\mid X, W)_{G_{\overline{X}\,\overline{Z^*}}}, \tag{36.109}$$

where $G_{\overline{X}\,\overline{Z^*}}$ cuts edges going into $X$ and $Z^*$, and where $Z^*$ is the set of $Z$-nodes that are not ancestors of any $W$-node in $G_{\overline{X}}$. Intuitively, this condition corresponds to intervening on $X$, and checking whether the distribution of $Y$ is invariant to any intervention that we could apply on $Z$.


36.8.2 Revisiting Backdoor Adjustment

We begin with a more general form of the adjustment formula we used in Section 36.4.

First, suppose we observe all of A’s parents, call them $X$. For notational simplicity, we’ll assume for the moment that $X$ is discrete. Then,

$$p(Y=y\mid \mathrm{do}(A=a)) = \sum_x p(Y=y\mid x, \mathrm{do}(A=a))\, p(x\mid \mathrm{do}(A=a)) \tag{36.110}$$
$$= \sum_x p(Y=y\mid x, A=a)\, p(x). \tag{36.111}$$


The first line is just a standard probability relation (marginalizing over $x$). We are using causal assumptions in two ways in the second line. First, $p(x\mid \mathrm{do}(A=a)) = p(x)$: the treatment has no causal effect on $X$, so interventions on $A$ don’t change the distribution of $X$. This is rule 3, Equation (36.109). Second, $p(Y=y\mid x, \mathrm{do}(A=a)) = p(Y=y\mid x, A=a)$. This equality holds because conditioning on the parents blocks all non-directed paths from $A$ to $Y$, reducing the causal effect to be the same as the observational effect. The equality is an application of rule 2, Equation (36.108).

Now, what if we don’t observe all the parents of $A$? The key issue is backdoor paths: paths between $A$ and $Y$ that contain an arrow into $A$. These paths are the general form of the problem that occurs when $A$ and $Y$ share a common cause. Suppose that we can find a set of variables $S$ such that (1) no node in $S$ is a descendant of $A$; and (2) $S$ blocks every backdoor path between $A$ and $Y$. Such a set is said to satisfy the backdoor criterion. In this case, we can use $S$ instead of the parents of $A$ in the adjustment formula, Equation (36.111). That is,

$$p(Y=y\mid \mathrm{do}(A=a)) = \mathbb{E}_S\big[p(Y=y\mid S, A=a)\big]. \tag{36.112}$$
The proof follows the invocation of rules 3 and 2, in the same way as for the case where $S$ is just the parents of $A$. Notice that requiring $S$ to not contain any descendants of $A$ means that we don’t risk conditioning on any variables that mediate the effect, nor any variables that might be colliders; either would undermine the estimate.

The backdoor adjustment formula generalizes the adjust-for-parents approach and adjust-for-all-common-causes approach of Section 36.4. That’s because both the parents of $A$ and the common causes satisfy the backdoor criterion.
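The sketch below checks the adjustment formula, Equation (36.111), on a hypothetical binary structural model in which $X$ is the only parent of $A$ and also causes $Y$; all probabilities are made up for illustration. It compares the adjusted contrast with the naive difference in conditional means and with a simulation of the intervention itself. By construction the true ATE is 0.3.

```python
import numpy as np

rng = np.random.default_rng(5)

def sample(n, do_a=None):
    """Hypothetical binary SCM with X -> A, X -> Y, A -> Y (X is a common cause)."""
    x = rng.random(n) < 0.3
    if do_a is None:
        a = rng.random(n) < np.where(x, 0.8, 0.2)   # observational treatment assignment
    else:
        a = np.full(n, bool(do_a))                  # intervention do(A = do_a)
    y = rng.random(n) < 0.1 + 0.3 * a + 0.5 * x
    return x, a, y

def backdoor(x, a, y, a_val):
    """Equation (36.111): sum_x p(Y=1 | x, A=a_val) p(x), using empirical frequencies."""
    return sum(np.mean(y[(x == xv) & (a == a_val)]) * np.mean(x == xv) for xv in (False, True))

x, a, y = sample(200_000)
adjusted = backdoor(x, a, y, True) - backdoor(x, a, y, False)
naive = y[a].mean() - y[~a].mean()
_, _, y1 = sample(200_000, do_a=1)
_, _, y0 = sample(200_000, do_a=0)
interventional = y1.mean() - y0.mean()
print(adjusted, naive, interventional)  # adjusted and interventional near 0.3; naive is larger
```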


Figure 36.8: Causal graph illustrating the frontdoor criterion setup. The effect of the treatment A on outcome Y is entirely mediated by mediator M. This allows us to infer the causal effect even if the treatment and outcome are confounded by U.


In practice, the full distribution $p(Y=y\mid \mathrm{do}(A=a))$ is rarely used as the causal target. Instead, we try to estimate a low-dimensional parameter of this distribution, such as the average treatment effect. The adjustment formula immediately translates in the obvious way. If we define

$$\tau = \mathbb{E}_S\big[\mathbb{E}[Y\mid A=1, S] - \mathbb{E}[Y\mid A=0, S]\big],$$

then we have that ATE $= \tau$ whenever $S$ satisfies the backdoor criterion. The parameter $\tau$ can then be estimated from finite data using the methods described in Section 36.4, using $S$ in place of the common causes $X$.


36.8.3 Frontdoor Adjustment 


Backdoor adjustment is applicable if there’s at least one observed variable on every backdoor path 
between A and Y. As we have seen, identification is sometimes still possible even when this condition 
doesn’t hold. Frontdoor adjustment is another strategy of this kind. Figure 36.8 shows the causal 
structure that allows this kind of adjustment strategy. Suppose we’re interested in the effect of 
smoking A on developing cancer Y, but we’re concerned about some latent genetic confounder U. 

Suppose that all of the directed paths from A to Y pass through some set of variables M. Such 
variables are called mediators. For example, the effect of smoking on lung cancer might be entirely 
mediated by the amount of tar in the lungs and measured tissue damage. It turns out that if all 
such mediators are observed, and the mediators do not have an unobserved common cause with A or 
Y, then causal identification is possible. To understand why this is true, first notice that we can 
identify the causal effect of A on M and the causal effect of M on Y, both by backdoor adjustment.
Further, the mechanism of action of A on Y is: A changes M which in turn changes Y. Then, we 
can combine these as: 


$$p(Y\mid \mathrm{do}(A=a)) = \sum_m p(Y\mid \mathrm{do}(M=m))\, p(M=m\mid \mathrm{do}(A=a)) \tag{36.113}$$
$$= \sum_m \sum_{a'} p(Y\mid a', m)\, p(a')\, p(m\mid a). \tag{36.114}$$
The second line is just backdoor adjustment applied to identify each of the do expressions (note that A blocks the M-Y backdoor path through U).

Equation (36.114) is called the front-door formula [Pea09b, §3.3.2]. To state the result in more 
general terms, let us introduce a definition. We say a set of variables M satisfies the front-door 
criterion relative to an ordered pair of variables (A, Y) if (1) M intercepts all directed paths from 
A to Y; (2) there is no unblocked backdoor path from A to M; and (3) all backdoor paths from M 
to Y are blocked by A. If M satisfies this criterion, and if p(A, M) > 0 for all values of A and M, 
then the causal effect of A on Y is identifiable and is given by Equation (36.114). 
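To see the front-door formula at work, the sketch below builds a hypothetical binary model with the structure of Figure 36.8 (all conditional probabilities are arbitrary illustrative choices, with $p(a, m) > 0$ everywhere). It computes the exact interventional probability $p(Y=1\mid \mathrm{do}(A=a))$ using the latent $U$, and compares it with Equation (36.114) evaluated from the observed margin $p(a, m, y)$ alone.

```python
import itertools
import numpy as np

# Hypothetical binary SCM with the structure of Figure 36.8: U -> A, U -> Y, A -> M, M -> Y
p_u = np.array([0.6, 0.4])                    # p(u)
p_a_u = np.array([[0.8, 0.2], [0.3, 0.7]])    # p(a | u), indexed [u, a]
p_m_a = np.array([[0.9, 0.1], [0.2, 0.8]])    # p(m | a), indexed [a, m]
p_y1_mu = np.array([[0.1, 0.5], [0.4, 0.9]])  # p(y=1 | m, u), indexed [m, u]

# Observational joint p(u, a, m, y) implied by the SCM
joint = np.zeros((2, 2, 2, 2))
for u, a, m, y in itertools.product(range(2), repeat=4):
    p_y = p_y1_mu[m, u] if y == 1 else 1.0 - p_y1_mu[m, u]
    joint[u, a, m, y] = p_u[u] * p_a_u[u, a] * p_m_a[a, m] * p_y

p_amy = joint.sum(axis=0)  # what we can observe: U is marginalized out

def truth(a):
    """p(Y=1 | do(A=a)) computed directly from the structural model (needs U)."""
    return sum(p_u[u] * p_m_a[a, m] * p_y1_mu[m, u] for u in range(2) for m in range(2))

def frontdoor(a):
    """Equation (36.114) evaluated from the observed margin p(a, m, y) only."""
    p_a = p_amy.sum(axis=(1, 2))
    p_m_given_a = p_amy.sum(axis=2) / p_a[:, None]
    p_y1_given_am = p_amy[:, :, 1] / p_amy.sum(axis=2)
    return sum(p_m_given_a[a, m] * sum(p_y1_given_am[ap, m] * p_a[ap] for ap in range(2))
               for m in range(2))

def naive(a):
    """Observational p(Y=1 | A=a), which is confounded by U."""
    return p_amy[a, :, 1].sum() / p_amy[a].sum()

print(truth(1), frontdoor(1), naive(1))  # about 0.532, 0.532, 0.676
```

The naive conditional differs because $U$ confounds $A$ and $Y$, whereas the front-door expression matches the interventional value exactly.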

Let us interpret this theorem in terms of our smoking example. Condition 1 means that smoking $A$ should have no effect on cancer $Y$ except via tar and tissue damage $M$. Conditions 2 and 3 mean that the genotype $U$ cannot have any effect on $M$ except via smoking $A$. Finally, the requirement that $p(A, M) > 0$ for all values implies that high levels of tar in the lungs must arise not only due to smoking, but also due to other factors (e.g., pollutants). In other words, we require $p(A=0, M=1) > 0$ so we can assess the impact of the mediator in the untreated setting.

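We can now use the do-calculus to derive the frontdoor criterion, following [PM18b, p236]. Assuming the causal graph $G$ shown in Figure 36.8:

$$
\begin{aligned}
p(y\mid \mathrm{do}(a)) &= \sum_m p(y\mid \mathrm{do}(a), m)\, p(m\mid \mathrm{do}(a)) && \text{(probability axioms)}\\
&= \sum_m p(y\mid \mathrm{do}(a), \mathrm{do}(m))\, p(m\mid \mathrm{do}(a)) && \text{(rule 2 using } G_{\overline{A}\underline{M}})\\
&= \sum_m p(y\mid \mathrm{do}(a), \mathrm{do}(m))\, p(m\mid a) && \text{(rule 2 using } G_{\underline{A}})\\
&= \sum_m p(y\mid \mathrm{do}(m))\, p(m\mid a) && \text{(rule 3 using } G_{\overline{A}\,\overline{M}})\\
&= \sum_{a'} \sum_m p(y\mid \mathrm{do}(m), a')\, p(a'\mid \mathrm{do}(m))\, p(m\mid a) && \text{(probability axioms)}\\
&= \sum_{a'} \sum_m p(y\mid m, a')\, p(a'\mid \mathrm{do}(m))\, p(m\mid a) && \text{(rule 2 using } G_{\underline{M}})\\
&= \sum_{a'} \sum_m p(y\mid m, a')\, p(a')\, p(m\mid a) && \text{(rule 3 using } G_{\overline{M}})
\end{aligned}
$$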


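Estimation To estimate the causal distribution from data using the frontdoor criterion, we need to estimate each of $p(y\mid m,a)$, $p(a)$, and $p(m\mid a)$. In practice, we can fit models $\hat p(y\mid m,a)$ by predicting $Y$ from $M$ and $A$, and $\hat p(m\mid a)$ by predicting $M$ from $A$. Then, using the empirical distribution to estimate $p(a)$, the final estimate is:

$$\hat p(y\mid \mathrm{do}(a)) = \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} \sum_m \hat p(y\mid m, a')\, \hat p(m\mid a), \tag{36.115}$$

where $\mathcal{A}$ is the collection of observed treatments (one per unit) and $|\mathcal{A}|$ is the number of treatments.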
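A sample-based version of Equation (36.115) for binary variables can use empirical frequencies as the fitted models $\hat p(m\mid a)$ and $\hat p(y\mid m, a')$; averaging over the observed treatments then implements the empirical $p(a')$. The sketch below simulates a hypothetical binary model with the structure of Figure 36.8, under which $p(Y=1\mid \mathrm{do}(A=1)) = 0.532$ while $p(Y=1\mid A=1) = 0.676$. Frequency tables are an assumption that only suits small discrete $M$ and $A$; fitted regression or classification models would take their place otherwise.

```python
import numpy as np

rng = np.random.default_rng(6)
n = 200_000
# Hypothetical binary SCM with the structure of Figure 36.8; U is latent and unused below
u = rng.random(n) < 0.4
a = rng.random(n) < np.where(u, 0.7, 0.2)
m = rng.random(n) < np.where(a, 0.8, 0.1)
y = rng.random(n) < np.where(m, np.where(u, 0.9, 0.4), np.where(u, 0.5, 0.1))

def frontdoor_estimate(a_obs, m_obs, y_obs, a_val):
    """Equation (36.115) for binary A, M, Y, with empirical frequencies as the fitted models."""
    p_m1 = np.mean(m_obs[a_obs == a_val])                    # p_hat(M=1 | a_val)
    p_y1 = {(mv, av): np.mean(y_obs[(m_obs == mv) & (a_obs == av)])
            for mv in (False, True) for av in (False, True)}  # p_hat(Y=1 | m, a')
    est = 0.0
    for mv, p_mv in ((True, p_m1), (False, 1.0 - p_m1)):
        # average p_hat(Y=1 | m, a') over the observed treatments a' (empirical p(a'))
        est += p_mv * np.mean(np.where(a_obs, p_y1[(mv, True)], p_y1[(mv, False)]))
    return est

est = frontdoor_estimate(a, m, y, True)
naive = y[a].mean()
print(est, naive)  # est near the true 0.532; naive near the confounded 0.676
```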
We usually have more modest targets than the full distribution $p(y\mid \mathrm{do}(a))$. For instance, we may be content with just estimating the average treatment effect. It’s straightforward to derive a formula for this using the frontdoor adjustment. Similarly to backdoor adjustment, more advanced estimators of the ATE through frontdoor adjustment are possible in principle. For example, we might combine fitted models for $\mathbb{E}[Y\mid m, a]$ and $P(M\mid a)$. See Fulcher et al. [Ful+20] for an approach to robust estimation via front door adjustment, as well as a generalization of the front door approach to more general settings.

36.9. FURTHER READING


36.9 Further Reading 


There is an enormous and growing literature on the intersection of causality and machine learning. 
First, there are many textbooks on theoretical and practical elements of causal inference. These 
include Pearl [Pea09c], focused on causal graphs, Angrist and Pischke [AP08], focused on econometrics, 
Hernan and Robins [HR20b], with roots in epidemiology, Imbens and Rubin [IR15], with origin in 
statistics, and Morgan and Winship [MW15], for a social sciences perspective. The introduction to 
causality in Shalizi [Sha22, §7] is also recommended, particularly the treatment of matching. 

Double machine-learning has featured prominently in this chapter. This is a particular instantiation 
of non-parametric estimation. This topic has substantial theoretical and practical importance in 
modern causal inference. The double machine learning work includes estimators for many commonly 
encountered scenarios [Che+17e; Che+17d]. Good references for a lucid explanation of how and 
why non-parametric estimation works include [Ken16; Ken17; FK21]. Usually, the key guarantees of 
non-parametric estimators are asymptotic. Generally, there are many estimators that share optimal
asymptotic guarantees (e.g. the AIPTW estimator given in Equation (36.30)). Although these are 
asymptotically equivalent, in finite samples their behavior can be very different. There are estimators 
that preserve asymptotic guarantees but aim to improve performance in practical finite sample 
regimes [e.g., vR11]. 

There is also considerable interest in the estimation of heterogeneous treatment effects. The 
question here is: what effect would this treatment have when applied to a unit with such-and-such 
specific characteristics? E.g., what is the effect of this drug on women over the age of 50? The causal 
identification arguments used here are more-or-less the same as for the estimation of average case 
effects. However, the estimation problems can be substantially more involved. Some reading includes 
[Kün+19; NW20; Ken20; Yad+21].

There are several commonly applicable causal identification and estimation strategies beyond the 
ones we’ve covered in this chapter. Regression discontinuity designs rely on the presence of 
some sharp, arbitrary non-linearity in treatment assignment. For example, eligibility for some aid 
programs is determined by whether an individual has income below or above a fixed amount. The 
effect of the treatment can be studied by comparing units just below and just above this threshhold. 
Synthetic controls are a class of methods that try to study the effect of a treatment on a given 
unit by constructing a synthetic version of that unit that acts as a control. For example, to study the 
effect of legislation banning smoking indoors in California, we can construct a synthetic California 
as a weighted average of other states, with weights chosen to balance demographic characteristics. 
Then, we can compare the observed outcome of California with the outcome of the synthetic control, 
constructed as the weighted average of the outcomes of the donor states. See Angrist and Pischke 
[APO08] for a textbook treatment of both strategies. Closely related are methods that use time series 
modeling to create synthetic outcomes. For example, to study the effect of an advertising campaign 
beginning at time $T$ on product sales $Y_t$, we might build a time series model for $Y_t$ using data in the
$t < T$ period, and then use this model to predict the values $(Y_t)_{t \geq T}$ we would have seen had the
campaign not been run. We can estimate the causal effect by comparing the factual, realized $Y_t$ to
the predicted, counterfactual $Y_t$ for $t \geq T$. See Brodersen et al. [Bro+15] for an instantiation of this idea.
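Both synthetic-outcome ideas in this paragraph can be sketched in plain Python. For the synthetic control, we make a simplifying assumption: the donor weights are chosen to match the treated unit's pre-intervention outcomes, not demographic covariates. The weights are constrained to be nonnegative and to sum to one, and we solve this least squares problem by projected gradient descent. For the time-series variant, we assume a straight-line trend fitted to the pre-period. This is far simpler than the Bayesian structural time series models of [Bro+15]. All data and function names are hypothetical.

```python
import math
import random


def project_to_simplex(v):
    """Euclidean projection of v onto {w : w_i >= 0, sum_i w_i = 1} (sort-based)."""
    cumsum, theta = 0.0, 0.0
    for j, u in enumerate(sorted(v, reverse=True), start=1):
        cumsum += u
        if u - (cumsum - 1.0) / j > 0:
            theta = (cumsum - 1.0) / j
    return [max(vi - theta, 0.0) for vi in v]


def synthetic_control_weights(treated_pre, donors_pre, steps=5000):
    """Minimize mean squared pre-period mismatch over simplex-constrained weights."""
    k, T = len(donors_pre), len(treated_pre)
    # 1 / trace(Hessian) never exceeds 1 / (largest eigenvalue), so this step is safe.
    step = 1.0 / (2.0 / T * sum(x * x for d in donors_pre for x in d))
    w = [1.0 / k] * k
    for _ in range(steps):
        resid = [sum(w[i] * donors_pre[i][t] for i in range(k)) - treated_pre[t]
                 for t in range(T)]
        grad = [2.0 / T * sum(resid[t] * donors_pre[i][t] for t in range(T))
                for i in range(k)]
        w = project_to_simplex([w[i] - step * grad[i] for i in range(k)])
    return w


def trend_forecast_effect(y, T):
    """Fit y_t = b0 + b1 * t on t < T, forecast t >= T, return mean(realized - forecast)."""
    ts = list(range(T))
    mt, my = sum(ts) / T, sum(y[:T]) / T
    b1 = (sum((t - mt) * (yt - my) for t, yt in zip(ts, y[:T]))
          / sum((t - mt) ** 2 for t in ts))
    b0 = my - b1 * mt
    gaps = [y[t] - (b0 + b1 * t) for t in range(T, len(y))]
    return sum(gaps) / len(gaps)


# Hypothetical synthetic-control data: 12 pre-period and 8 post-period steps, 3 donors.
T0, T1 = 12, 8
periods = range(T0 + T1)
donors = [[10 + 2 * math.cos(math.pi * t / 2) for t in periods],
          [10 + 2 * math.sin(math.pi * t / 2) for t in periods],
          [10 - 2 * math.cos(math.pi * t / 2) for t in periods]]
treated = [0.7 * donors[0][t] + 0.3 * donors[1][t] + (5.0 if t >= T0 else 0.0)
           for t in periods]
w = synthetic_control_weights(treated[:T0], [d[:T0] for d in donors])
synthetic = [sum(wi * d[t] for wi, d in zip(w, donors)) for t in periods]
sc_effect = sum(treated[t] - synthetic[t] for t in range(T0, T0 + T1)) / T1
print("weights:", [round(wi, 3) for wi in w], "effect:", round(sc_effect, 3))

# Hypothetical sales series: a campaign starting at T = 30 adds about 15 per step.
random.seed(3)
sales = [100 + 2 * t + (15 if t >= 30 else 0) + random.gauss(0.0, 0.5) for t in range(40)]
ts_effect = trend_forecast_effect(sales, 30)
print("trend-forecast effect:", round(ts_effect, 2))
```

The projected-gradient solver is only one convenient choice for this small constrained least squares problem; any quadratic-programming routine would also work. In both cases, the effect is read off as the post-period gap between the realized series and its synthetic or forecast counterpart, as in the California and advertising examples.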

In this chapter, our focus has been on using machine learning tools to estimate causal effects. 
There is also a growing interest in using the ideas of causality to improve machine learning tools. 
This is mainly aimed at building predictors that are robust when deployed in new domains [SS18b; 
SCS19; Arj+20; Mei18b; PBM16a; RC+18; Zha+13a; Sch+12b; Vei+21] or that do not rely on
particular ‘spurious’ correlations in the training data [RPH21; Wu+21; Gar+19; Mit+20; WZ19; 
KCC20; KHL20; TAH20; Vei+21]. 
Automatic Relevance Determination, 645
automatic relevance determination, 905
automatic relevancy determination, 573, 612
automatic speech recognition, 344, 736, 956
autoregressive, 733
autoregressive bijection, 801
autoregressive flows, 800
autoregressive models, 876
auxiliary latent variables, 440
auxiliary variable deep generative model, 441
auxiliary variables, 483
average causal effect, 1183
Average Treatment Effect, 1174
average treatment effect, 1182
average treatment effect on the treated, 1194
axis aligned, 14

BA lower bound, 220
back translation, 706
backbone, 603
backcasting, 988
backdoor criterion, 1218
backdoor paths, 1218
backoff smoothing, 105
backpropagation, 242
backup diagram, 1126
backwards Kalman gain, 353

backwards kernel, 534 

backwards transfer, 722 

BACS, 711 

bagging, 621 

balance the dataset, 700 

BALD, 1131 

Bambi, 589 

BAMDP, 1139 

bandit problem, 1110 

bandwidth, 738 

Barlow Twins, 1061 

base distribution, 793 

base measure, 30, 1010 

base-rate, 1037 

baseline, 465 

baseline function, 252 

basic random variables, 182 

basin flooding, 283 

basis functions, 933 

batch active learning, 1131 

batch ensemble, 622 

batch normalization, 597 

batch optimization, 251 

batch reinforcement learning, 1161 

BatchBALD, 1132 

batched Bayesian optimization, 280

Baum-Welch, 944, 946 

Baum-Welch algorithm, 137 

Bayes ball algorithm, 125 

Bayes by backprop, 615 

Bayes estimator, 1099 

Bayes factor, 108 

Bayes filter, 335 

Bayes nets, 119 

Bayes’ rule, 66 

Bayes’ rule for Gaussians, 20 

Bayes-adaptive MDP, 1139 

Bayes-Newton, 373 

BayesBiNN, 274 

Bayesian active learning by disagreement, 1131

Bayesian approach, 69 

Bayesian dark knowledge, 624 

Bayesian decision theory, 1099 

Bayesian deep learning, 609 

Bayesian factor regression, 909 

Bayesian hypothesis testing, 108 

Bayesian inference, 1, 66 

Bayesian information criterion, 113

Bayesian lasso, 571 

Bayesian learning rule, 267 

Bayesian model selection, 108 

Bayesian multi net, 140 

Bayesian networks, 119 

Bayesian nonparametric, 1009 

Bayesian nonparametric models, 509

Bayesian Occam’s razor, 108 

Bayesian online changepoint detection, 961

Bayesian optimization, 274, 288, 643

Bayesian Optimization Algorithm, 288


Bayesian p-value, 115 

Bayesian quadrature, 467

Bayesian statistics, 65 

Bayesian structural time series, 982

Bayesian transfer learning, 612 

BayesOpt, 274 

BBB, 615 

BBMM, 675 

BBVI, 430 

BDL, 609 

beam search, 1020 

Bean Machine, 187 

bearings only tracking problem, 976

behavior cloning, 1170 

behavior policy, 1161 

behavior-agnostic off-policy, 1161 

belief networks, 119 

belief propagation, 385 

belief state, 335, 552, 1113, 1139 

belief states, 385 

belief-state MDP, 1114 

Bellman backup, 1124 

Bellman error, 1122 

Bellman residual, 1122 

Bellman’s optimality equations, 1122

Berkson’s paradox, 122, 127 

Bernoulli bandit, 1114 

Bernoulli distribution, 5 

Bernoulli mixture model, 888 

BERT, 606, 769, 1056 

Bessel function, 28, 645 

best arm identification, 274, 1108 

best-arm identification, 1112 

best-first search, 1154 

beta distribution, 13 

beta function, 8 

Beta process, 1027 

beta-VAE, 760 

bi-directed graph, 177 

BIC, 113, 683 

BIC loss, 114 

BIC score, 113 

big data, 72 

BiGAN, 868, 1054 

bigram model, 49, 50, 213 

bigram statistics, 51 

bijection, 45, 794 

bilinear form, 601 

binary entropy function, 210 

binary logistic regression, 577 

binary neural network, 274 

binomial coefficient, 5 

binomial distribution, 5 

binomial regression, 560 

bit error, 390 

bits, 200, 210 

bits back coding, 226 

bits per dimension, 742 

BIVA, 775 

bivariate Gaussian, 14 

black box attack, 727 

black box shift estimation, 710 

black box variational inference, 430

blackbox, 430

blackbox EP, 448 

blackbox matrix-matrix multiplication, 675

Blackwell-MacQueen, 1013 

blind inverse problem, 893 

blind source separation, 928 

block length, 227 

block stacking, 1157 

blocked Gibbs sampling, 481 

BLOG, 187 

BN, 597 

BNP, 1009 

BOA, 288 

Bochner’s theorem, 648 

BOCPD, 961 

Boltzmann machine, 147 

Boltzmann policy, 1138 

bond variables, 485 

Bonnet’s theorem, 249, 254 

boolean satisfiability problems, 282

bootstrap filter, 519 

bootstrap sampling, 620 

bootstrapping, 1137, 1141 

borrow statistical strength, 97 

bottom-up inference model, 775 

bound optimization, 257, 260 

Box-Muller, 457 

BP, 385 

branching factor, 213 

Bregman divergence, 208, 209 

Brier score, 548 

BRMS, 589 

Brownian motion, 492, 493, 646, 1034

Brownian noise, 503 

BSS, 928 

BSTS, 982 

bucket elimination, 401 

BUGS, 475 

building blocks, 288, 593 

burn-in phase, 494 

burn-in time, 469 

burstiness, 1036 

BYOL, 1062 


calculus of intervention, 1217 

calculus of variations, 39 

calibrated, 548 

calibration set, 556 

canonical correlation analysis, 911, 1047

canonical form, 16, 30 

canonical link function, 562 

canonical parameters, 16, 30, 34, 153

CAQL, 1123 

cart-pole swing-up, 1156 

casino HMM, 937 

CASP, 146 

catastrophic forgetting, 721 

categorical, 6 

categorical distribution, 74 

categorical PCA, 913, 923 

CatPCA, 913 

Cauchy, 9 
Cauchy sequence, 659

causal convolution, 787 

causal DAGs, 697 

causal discovery, 1007 

Causal graphs, 1176 

causal hierarchy, 191 

causal impact, 193, 988 

Causal inference, 1173 

causal Markov assumption, 188 

causal models, 188 

causal prediction, 697 

causal representation learning, 1064

causally sufficient, 188 

causes of effects, 191 

CAVI, 413 

cavity distribution, 447 

CCA, 911 

cdf, 7 

CEB, 231 

CelebA, 740, 756

Centered kernel alignment (CKA), 1047

centering matrix, 82 

central composite design, 680 

central limit theorem, 455, 689 

central moment, 9 

certify, 729 

ceteris paribus, 192 

chain components, 174 

chain compositions, 238 

chain graph, 158, 174, 176 

chain rule, 237 

chance nodes, 1102 

change of variables, 45 

changepoint detection, 700, 959, 961

channel coding, 195, 227 

channel coding theorem, 399 

Chapman-Kolmogorov, 49 

Chapman-Kolmogorov equation, 336

characteristic length scale, 645 

Chernoff-Hoeffding inequality, 1115

chi-squared distance, 58 

Chi-squared distribution, 11 

children, 119 

Chinese restaurant process, 1014 

choice theory, 584 

Cholesky decomposition, 457 

Chomsky normal form, 165, 166 

chordal, 171 

Chow-Liu algorithm, 288 

chromosomes, 285 

CI, 119 

circuits, 240 

circular flow, 808 

circular normal, 28 

citation matching, 187 

CKF, 369 

clamped phase, 159, 817 

class incremental learning, 721 

classical statistics, 65 

classifier guidance, 842 

classifier-free guidance, 843 

Claude Shannon, 223 


clausal form, 185 

click through rate, 1109, 1112 

clinical trials, 1112 

CLIP, 791, 1060 

clique, 141 

clique tree, 407 

cliques, 404 

closed world assumption, 186, 974 

closing the loop, 532 

closure, 155 

cluster variational method, 397 

clustering, 886 

clutter problem, 330 

CMA-ES, 290, 1158 

CMGF, 372 

CNN, 603 

co-information, 217 

co-parents, 129 

coagulation, 1024 

coalescence, 520 

cocktail party problem, 928 

code words, 223 

codebook, 780 

codebook loss, 781 

codewords, 195 

coffee, lemon, milk, and tea, 299 

cold posterior, 624 

cold start problem, 72 

collapsed, 137 

collapsed Gibbs sampler, 481, 1016

collapsed particles, 526 

collective classification, 184 

collider, 125, 1177 

collocation, 1155 

coloring, 475 

commitment loss, 782 

common corruptions, 695 

common random number, 1157 

common random numbers, 464 

common random numbers trick, 437

compact support, 675 

Compactness, Sparsity, 1081 

compatible, 1153 

complementary log-log, 562 

complete, 659 

complete data, 158 

completely random measures, 1031

Completeness, 1081, 1083 

completing the square, 84 

complexity penalty, 112 

complier average treatment effect, 1200

components, 979 

composite likelihood, 160 

compositional pattern-producing network, 727

Compression Lemma, 203 

computation graph, 593 

computation tree, 394 

concave, 38 

concentration inequality, 1115 

concentration of measure, 730 

concentration parameter, 1010 

concept drift, 719 
concept shift, 637, 699
concrete distribution, 255
condensation, 519
conditional entropy bottleneck, 231
conditional expected outcome, 1186
conditional GAN, 867
conditional generative model, 733, 736
Conditional generative models, 866
conditional independence, 119
conditional KL divergence, 199
conditional log marginal likelihood, 110
conditional moments, 371
conditional moments Gaussian filter, 372
conditional parallel trends, 1206
conditional probability distribution, 120
conditional probability table, 121
conditional random field, 162, 1101
conditional random fields, 141, 142
Conditional shift, 699
conditional value at risk, 1135
conditionally conjugate, 82
conditioner, 800, 801
conditioning case, 121
conditioning matrix, 243
conductance, 495
confidence score, 702
conformal prediction, 555, 701, 702
conformal score, 555
conformalized quantile regression, 558
confounder, 177
confounders, 1184
conical combination, 798
conjugate, 72, 323
conjugate gradients, 674
conjugate prior, 20, 30, 72
conjunction of features, 814
conjunctive normal form, 185
consensus sequence, 942
conservative policy iteration, 1151
consistent, 1010
constant symbols, 185
contact map, 146
content addressable memory, 147
content constrained, 728
context free grammar, 165
context variables, 714
Context, 1069
contextual bandit, 1111
continual learning, 632, 705, 718
continuation method, 511
continuing task, 1120
continuous task-agnostic learning, 721
continuous-ranked probability score, 992
continuous-time flows, 808
contraction, 1124
contrastive divergence, 160, 816
contrastive learning, 1058

Contrastive Multiview Coding, 1060

Contrastiveness, 1083 

control, 1105 

control as inference, 1167 

control theory, 1120 

control variate, 465, 504 

control variates, 252, 431 

controller, 1120 

converge, 495 

conversions, 1109 

convex BP, 397 

convex combination, 73 

ConvNeXt, 603 

convolution, 595 

convolutional layer, 596 

convolutional Markov model, 787 

convolutional neural network, 603, 786

convolutional neural networks, 167 

cooling schedule, 512 

cooperative cut, 305 

coordinate ascent variational inference, 413

coreset, 666 

correlated topic model, 923 

correlation coefficient, 14 

correspondence, 973 

cosine distance, 28 

cosine kernel, 647 

count-based exploration, 1138 

Counterfactual queries, 1180 

counterfactual question, 191 

counterfactual reasoning, 987 

coupled HMM, 958 

coupling flows, 799 

coupling layer, 799 

covariance function, 641 

covariance graph, 177, 222 

covariance matrix, 14 

covariate shift, 698 

coverage, 556 

Cox process, 664 

CPD, 120 

CPPN, 727 

CPT, 121 

CQR, 558 

credible interval, 68, 80 

CRF, 141, 162 

CRFs, 142 

critic, 850 

critical temperature, 144, 486 

cross correlation, 595 

cross entropy, 206, 213 

cross entropy method, 287, 1155 

cross fitting, 1190 

cross validation, 109 

cross-entropy method, 289, 539 

cross-stitch networks, 712 

crossover, 285 

crowd computing, 1127 

CRP, 1014 

CRPS, 992 

CTR, 1112 

cubature, 453 

cubature Kalman filter, 369 


cubatures, 368 

cumulants, 35 

cumulative distribution function, 7 
cumulative regret, 1118 
cumulative reward, 1111 

curse of dimensionality, 738, 1124 
curse of horizon, 1164 

curved exponential family, 30 

CV, 431 

CVI, 270 

cycle consistency, 875 

cyclical annealing, 772 


d-separated, 125 

D4PG, 1153 

DAGs, 119 

DALL-E, 790 

DALL-E 2, 792, 844 

damped updates, 417 
damping, 395, 449 

dark knowledge, 624 

DARN, 123 

data assimilation, 374 

data association, 973 

data augmentation, 586, 706 
data cleaning, 701 

Data compression, 223 

data compression, 195, 737, 742 
data generating process, 1, 698 
data processing inequality, 204, 215

data tempering, 515 
data-driven MCMC, 474 
dataset shift, 695

daydream phase, 446 

DBL, 609 

DBN, 150, 959 

DCGAN, 869 

DDIM, 844 

DDP, 1155 

DDPG, 1153 

de Finetti’s theorem, 66 

dead leaves, 892 

decision diagram, 1102 
decision nodes, 1102 

decision tree, 1104 
declarative approach, 187 
decoder, 604, 749, 885 
decomposable, 159, 171, 405 
decompose, 133 

decoupled EKF, 634 

deep autoregressive network, 123 
deep Bayesian learning, 609 
deep belief network, 150 

deep Boltzmann machine, 150 
deep Boltzmann network, 150 
deep CCA, 911 

deep deterministic policy gradient, 1153

deep ensembles, 620 

Deep Factors, 992 

deep fakes, 736, 873 

deep Gaussian process, 691 
deep generative model, 123 
deep generative models, 733 
deep image prior, 613 

Deep kernel learning, 686 
deep latent Gaussian model, 749

deep latent variable model, 749 

deep learning, 1, 593 

deep Markov model, 993 

deep neural network, 593 

deep PILCO, 1157 

deep Q-network, 1144 

deep state-space model, 992 

deep submodular function, 305 

deep unfolding, 401 

DeepAR, 992 

DeepGLO, 992 

DeepSSM, 992 

default prior, 90 

deformable parts model, 168 

degenerate kernel, 652 

degree of normality, 8 

degrees of freedom, 8, 77, 113 

deleted interpolation, 105 

delta function, 69 

delta method, 255 

delta VAE, 772 

demand forecasting, 985 

denoising autoencoder, 1056 

denoising diffusion GAN, 844 

denoising diffusion models, 833 

denoising diffusion probabilistic model, 840

denoising diffusion probabilistic models, 824

Denoising Score Matching, 821 

DenseFlow, 735 

density estimation, 737 

density model, 703 

density ratio estimation, 62 

Derivative free optimization, 281 

derivative function, 235 

derivative operator, 235 

detailed balance, 471 

detailed balance equations, 55 

determinantal point process, 345 

Determinantal point processes, 1039

determinantal projection point processes, 1041

deterministic ADVI, 437 

deterministic annealing, 949 

deterministic inducing conditional, 668

deterministic policy gradient theorem, 1152

deterministic training conditional, 669

deviance, 112 

DFO, 281 

DGP, 1, 691 

diagonal covariance matrix, 14 

diameter, 391 

DIC, 668 

dictionary learning, 933 

diffeomorphism, 794 

difference in differences, 1204 

differentiable CEM, 289 

differentiable simulators, 848 

differential dynamic programming, 1155

differential entropy, 210 
diffuse prior, 90

diffusion matrix, 493 

diffusion models, 733, 833 

diffusion process, 833 

diffusion term, 492

DiffWave, 844

digamma function, 424 

dilated convolution, 787 

diminishing returns, 302 

DINO, 1062 

direct coupling analysis, 146 

direct method, 1161 

directed acyclic graphs, 119 

directed Gaussian graphical model, 
124 

Dirichlet, 25 

Dirichlet distribution, 74 

Dirichlet process, 1001, 1003, 1009 

Dirichlet process mixture models, 426

discount factor, 1111, 1121, 1135 

discount parameter, 1022 

discrete task-agnostic learning, 721

discrete with probability one, 1012 

discriminative model, 545 

discriminative reranking, 345 

discriminator, 850 

disease mapping, 664 

disease transmission, 999 

disentangled, 760, 932, 1063 

disentangled representation learning, 1046

dispersion parameter, 39 

distill, 844 

distillation, 624 

distortion, 224 

distributed representation, 149, 957

distribution free, 555 

distribution shift, 695, 729, 1160 

distributional particles, 526 

distributional RL, 1146, 1153 

distributionally robust optimization, 706

diverged, 489 

divergence metric, 56 

DLGM, 749 

DLM, 979 

DLVM, 749 

DM, 833 

DNN, 593 

do calculus, 1176 

do-calculus, 1217 

do-notation, 1180 

domain adversarial learning, 709 

domain drift, 718 

domain generalization, 637, 712 

domain randomization, 1159 

domain shift, 698 

domains, 712 

donor, 989 

Donsker Varadhan lower bound, 221

Donsker-Varadhan, 204 

dot product attention, 599 

dot product kernel, 649 


double DQN, 1146 

double loop algorithms, 395 
double machine-learning, 1189 
double Q-learning, 1144 

double sided exponential, 9 
doubly intractable, 158 

doubly reparameterized gradient estimator, 443

doubly robust, 1163, 1190 
doubly stochastic, 430 
downstream, 741 

DPGM, 119 

DPPs, 1039 

DQN, 1144 

DRE, 62 

Dreamer, 1159 

dropout, 597 

DTC, 669 

dual EKF, 375 

DualDICE, 1165 

dueling DQN, 1146 

Dutch book, 67 

dyna, 1155 

dynamic Bayesian network, 959 
dynamic embedded topic model, 924

dynamic linear model, 966, 979 
dynamic programming, 324, 385, 407, 1123
dynamic topic model, 923 
dynamic VAE, 768 

dynamical variational autoencoders, 992


E step, 260, 261 

EA, 284 

earning while learning, 1111 

EB, 103 

EBM, 733, 813 

ECE, 549 

ECM, 267 

ECME, 267 

EDA, 287 

edge potentials, 153 

edit distance, 944 

effective dimensionality, 627 

effective sample size, 499, 501, 523

effects of causes, 191 

EI, 278 

eigenfunction, 652 

eigengap, 495 

eight schools, 100 

Einstein summation, 408 

einsum, 405, 408 

einsum networks, 181 

EKF, 359 

elastic weight consolidation, 633, 723

ELBO, 260, 327, 412, 751 

elementwise flow, 796 

eligibility traces, 1142 

elimination order, 403 

elite set, 285 

ELPD, 111 

EM, 260 
EMNA, 288
empirical Bayes, 103, 612
empirical distribution, 206
empirical Fisher, 248
empirical MBR, 1102
empirical risk, 546
empirical risk minimization, 267, 546
emulate, 539
encoder, 604
End-task, 1069
endogenous, 188
energy, 142
energy based models, 733
energy disaggregation, 958
energy function, 284, 326, 813, 1038
energy score, 702
energy-based model, 142
Energy-based models, 813
EnKF, 374
ensemble, 597, 1130
ensemble Kalman filter, 374
entity resolution, 183
entropy, 210, 260
entropy sampling, 1129
entropy search, 279
environment, 1111, 1119
environments, 712
EP, 447
epidemiology, 520
episodic task, 1120
epistemic uncertainty, 69, 552
epistemological uncertainty, 1001
EPLL, 891
epsilon-greedy, 1138
episode, 1120
equal odds, 1102
equal opportunity, 1102
equilibrium distribution, 53
ergodic, 55
ERM, 267, 546
error correcting codes, 227, 398
error correction, 195
error-correcting codes, 180
ES-RNN, 991
ESS, 499, 523
estimated potential scale reduction, 499
estimation of distribution, 287
estimation of multivariate normal algorithm, 288
estimator, 65
Etsy, 945
EUBO, 443
Euler approximation, 503
Euler’s method, 488, 809
evidence, 66, 108, 327, 751
evidence lower bound, 260, 327, 412, 751
evidence maximization, 612
evidence upper bound, 443
Evolution strategies, 289
evolutionary algorithm, 284
evolutionary programming, 287
evolutionary search, 281
evolutionary strategies, 274
EWC, 633

excess kurtosis, 9 

exchangeable, 138, 924 

exchangeable process, 1014 

exchangeable with, 97 

Exclusion Restriction, 1197 

exclusive KL, 201 

execution traces, 187 

exogenous, 188 

exp-sine-squared kernel, 647 

expanded parameterization, 889 

expectation backpropagation, 635 

expectation maximization, 260 

expectation propagation, 383, 447 

expected calibration error, 549 

expected complete data log likelihood, 261

expected free energy, 1170 

expected improvement, 278 

expected LPPD, 111 

expected patch log likelihood, 891 

expected sufficient statistics, 136, 261

experience replay, 723, 1144

explainability, 191 

explaining away, 82, 122, 127, 131, 218

explicit duration HMM, 954 

explicit layers, 602 

explicit probabilistic models, 847 

exploration bonus, 1115, 1139 

exploration-exploitation tradeoff, 1106, 1112, 1138

exponential cooling schedule, 512 

exponential dispersion family, 39 

exponential distribution, 11 

exponential family, 30, 40, 88 

exponential family factor analysis, 912

exponential family harmonium, 149

exponential family PCA, 912 

exponential family state-space model, 978

Exponential linear unit, 595 

exponentiated quadratic, 644 

extended Kalman filter, 359, 360, 976

extended Kalman smoother, 361 

extended particle filter, 524 

extended RTS smoother, 361 

external field, 145 

external validity, 697 

extrapolation, 677 

extrinsic variables, 421 


f-divergence, 56

f-Divergence Max-Ent IRL, 1172
Facebook, 1002

factor, 141

factor analysis, 139

factor graph, 178, 392, 399
factor loading matrix, 139, 895
factor of variation, 760

factor rotations problem, 899
factorial HMM, 440, 957


factorization property, 130 

FAIRL, 1172 

fairness, 191, 1102 

Faithfulness, Fidelity, 1081 

family marginal, 136 

fan-in, 240 

fan-out, 237, 240 

fantasy data, 818 

Fast Fourier Transform, 342 

fast geometric ensembles, 618 

fast gradient sign, 726 

fast ICA, 931 

fast weights, 622 

FastSLAM, 533 

FB, 339 

feature induction, 152 

feature-based, 303 

feedback loop, 1113 

feedforward neural network, 602 

ferromagnetic, 143 

few-shot learning, 707 

Feynman-Kac, 513 

FFBS, 337 

FFG, 179 

FFJORD, 809 

FFNN, 602 

FGS, 726 

FIC, 670 

FID, 744 

fill-in edges, 404 

FiLM, 602 

filter, 595 

filter response normalization, 597 

filtering distribution, 333 

filtering SMC, 540 

filtering variational objective, 540 

FIM, 40 

fine-tuning, 1045 

finite horizon, 1111 

finite horizon problem, 1121 

finite state machine, 1120 

finite sum objective, 251 

finite-state Markov chain, 48 

first-order delta method, 255 

first-order logic, 184 

Fisher divergence, 820 

Fisher information, 92 

Fisher information matrix, 37, 40, 92

FITC, 669 

fitness, 285 

fitness proportionate selection, 285 

fitted value iteration, 1146 

FIVO, 540 

fixed effects, 589 

Fixed lag smoothing, 347 

fixed-form VI, 429 

fixed-lag smoothing distribution, 333

flat minima, 493, 625 

flow cytometry, 739 

folded over, 7 

folds, 109 

FOO-VB, 636 

fooling images, 727 

force, 506 

forest plot, 99 
fork, 125

Forney factor graph, 179 

Forney factor graphs, 178 

forward adversarial inverse RL, 1172

forward transfer, 722 

forward-mode automatic differentiation, 238

forwards algorithm, 406, 944 

forwards filtering backwards sampling, 345

forwards filtering backwards smoothing, 337

forwards kernel, 534 

forwards mapping, 883 

forwards process, 833 

forwards-backwards, 339, 407 

founder variables, 900 

Fourier basis, 933 

Fourier transform, 648 

Fréchet Inception Distance, 744 

fragmentation, 1024 

Frechet inception distance, 61 

free bits, 773 

free energy, 412 

free energy principle, 1170 

free-form VI, 415 

freeze-thaw algorithm, 281 

frequentist sampling distribution, 
fully independent training conditional, 669

function space, 656 

functional, 233 

functional causal model, 188 

fundamental problem of causal inference, 192

funnel shape, 101 

funnel transformer, 769 
Gumbel distribution, 255

Gumbel-Softmax distribution, 255

integrating out, 68

inter-causal reasoning, 127 

interaction information, 217 

interactive multiple models, 379 

Interactivity, 1082
interpolator, 653

intervention, 192 

interventions, 190, 1176 

Interventional queries, 1180

intrinsic uncertainty, 69 

intrinsic variables, 421 
invariant causal prediction, 714

invariant distribution, 53 

invariant risk minimization, 714 
inverse chi-squared distribution,
inverse Gamma, 76, 565

inverse Gamma distribution, 11 

inverse mass matrix, 487, 490 

inverse of a partitioned matrix, 17 

inverse optimal control, 1171 

inverse probability of treatment weighted estimator,
inverse probability theory, 65
inverse reinforcement learning,
Ising models, 476
iteratively reweighted least squares, 563

IWAE, 442 

IWAE bound, 442 
Jacobian vector product, 823
judge fixed effects, 1198
junction tree algorithm, 407
Kalman Filter, 348

Kalman filter, 346, 352, 964, 982 

Kalman filter algorithm, 21 
Lagrange multipliers, 39
Laplace approximation, 325, 580

Laplace distribution, 9 
linear dynamical system, 345, 962
linear Gaussian state space model,
Markov assumption, 47, 785
Markov kernel, 47
Markov model of order n, 49

Markov network, 141 
masked language modeling, 1056

matching, 1191 

Matern kernel, 645 

matrix determinant lemma, 807 

matrix inversion lemma, 17, 21, 348

matrix normal, 28 

matrix normal inverse Wishart, 576

matrix vector multiplication, 674 

max margin Markov networks, 170 

max marginals, 71, 389 

max-product belief propagation, 

maximum mean discrepancy, 58, 59, 762, 855

MBIE, 1139 

MBIE-EB, 1139 
message passing algorithms, 385

message passing schedule, 385 
meta-data, 714
midi format, 789

min-fill heuristic, 405 

min-max, 858 

min-max optimization problem, 
min-weight heuristic, 405

minibatch, 251 
minimal I-map, 125

minimal representation, 31 

minimal sufficient statistic, 216, 
minimally informative prior, 90

minimum Bayes risk, 1101 

minimum description length, 113 

minimum mean squared error, 
minorize-maximize, 257
missing completely at random, 766
missing data mechanism, 766

mixed effects model, 589 
mixing matrix, 928
mixture model, 885
mixture of factor analysers, 902
mixture of Kalman filters, 527

mixture proposal, 473 
MMD VAE, 762
Mobius inversion formula, 219
mode collapse, 861

mode connectivity, 627 
model checking, 115

model predictive control, 1154 

model-agnostic meta-learning, 717 

model-based approach, 1 

model-based RL, 1135, 1137, 1153

Modularity, 1082

MoG, 886 

molecular graph structure, 770 
moment parameters, 16, 31

moment projection, 201, 377 
Monte Carlo approximation, 328

Monte Carlo control, 1140 
Monte Carlo estimation, 1140
Monte Carlo methods, 453
multi-information, 216
multi-level model, 96
multi-scale, 778

multi-stage likelihood, 986 

multi-target tracking, 972 

multi-task learning, 711 

multiclass logistic regression, 577 
Multinomial logistic regression,
multiplicative interactions, 601

multiplicative layers, 601 
multivariate Gaussian, 14


multivariate linear regression, 575 
music transformer, 789
mutual information, 214
No-U-Turn Sampler, 490
node potentials, 153
Noise Conditional Score Network, 826
Noise Contrastive Estimation, 827
noisy channel, 227
noisy channel model, 944
noisy nets, 1146
non-centered parameterization, 101, 502, 592
non-contrastive representation learning, 1061
non-descendants, 130
non-factorial prior, 575
non-Gaussian SSM, 977
non-linear squared flow, 797
non-Markovian models, 514
non-negative matrix factorization, 916
non-null recurrent, 55
non-parametric Bayesian, 1009
non-parametric Bayesian models, 884
non-parametric BP, 394, 533
non-parametric models, 545
non-parametrically efficient, 1190
non-stationary kernel, 649, 689
non-terminals, 165
nondecreasing, 1033
noninformative, 90
nonlinear dynamical system, 975
nonlinear factor analysis, 914
nonlinear Gaussian SSM, 975
nonparametric copula, 992
nonparametric models, 641
normal distribution, 7
normal factor graph, 179
normal inverse chi-squared, 78
normal inverse Gamma, 77, 565
Normal-Inverse-Wishart, 1016
Normal-inverse-Wishart, 83
normalization layers, 597
normalized completely random measures, 1032
normalized occupancy distribution, 1127

normalized random measures 
(NRMs), 1032 

normalized stable process, 1033 

normalized weights, 462, 516 

normalizes, 793 

normalizing flow, 939 

Normalizing flows, 441 

normalizing flows, 733, 793 

not missing at random, 766 

noun phrase chunking, 164 

noun phrases, 164 

Nouveau VAE, 779 

novelty detection, 701 

NP-hard, 406 

NSSM, 977 

NTK, 691 

nuisance functions, 1189 

nuisance variables, 131, 390 

null hypothesis, 108 

numerical integration, 368, 453 

NUTS, 490 

NUV, 573 

NWJ lower bound, 221 

Nyström approximation, 667 


object detection, 167 

objective, 90 

observation function, 935 

observation model, 935, 937 

observation noise, 935, 963, 964 

observation overshooting, 995 

Occam factor, 113 

Occam’s razor, 906 

occasionally dishonest casino, 335 

occlusion, 167 

off-policy, 1143 

off-policy policy-gradient, 1162 

offline reinforcement learning, 
1161 

offspring, 285 

OGN, 273 

Olivetti faces dataset, 687 

on-policy, 1143 

one-armed bandit, 1111 

one-max, 287 

one-step-ahead predictive distribution, 376

one-to-one, 45 

online advertising system, 1112 

online Bayesian inference, 964 

online EM, 267 

online EWC, 637 

Online Gauss-Newton, 273 

online learning, 632, 723 

online structured Laplace, 637 

ontological uncertainty, 1001 

ontology, 1004 

OOD, 631, 701 

open class, 164 

open set recognition, 705 

open world, 974 

open world classification, 705 

open world recognition, 699 

open-universe probability models, 
187 


OpenGAN, 704 

opportunity cost, 1106 

optimal action-value function, 
1122 

optimal partial policy, 1125 

optimal policy, 1099, 1122 

optimal resampling, 528 

optimal state-value function, 1122 

optimal transport, 290 

optimism in the face of uncertainty, 
1114 

optimization problems, 233 

optimizer’s curse, 1143 

Optimus, 769 

oracle, 274 

ordered Markov property, 119, 130 

ordinal regression, 587 

Ornstein-Uhlenbeck process, 646 

orthodox statistics, 65 

orthogonal additive kernel, 682 

orthogonal Monte Carlo, 468 

orthogonal random features, 653 

OUPM, 187 

out-of-distribution detection, 701 

out-of-domain, 696 

outer product method, 147 

outlier detection, 701, 737 

outlier exposure, 701 

over-complete representation, 31 

overcomplete representation, 933 

overdispersed, 496 

overfitting, 546 

overlap, 1185 


Pólya urn, 1013 

PAC-Bayes, 547, 628 

padding, 596 

PageRank, 52 

paired data, 866 

pairwise Markov property, 155 

pairwise potentials, 1038 

panel data, 986, 1204 

parallel prefix scan, 342, 354, 676, 
691 

parallel tempering, 496, 512 

parallel trends, 1205 

parallel wavenet, 805 

parameter learning, 132 

parameter tying, 47, 97, 181 

parametric Bayesian model, 1009 

parametric model, 1009 

parametric models, 545 

parametric prior, 1010 

parent, 285 

parents, 119 

Pareto distribution, 12 

pareto index, 12 

Pareto smoothed importance sampling, 112

parity check bits, 227 

part of speech, 164, 927 

part of speech tagging, 173 

parti, 792 

partial least squares, 910 

partially directed acyclic graph, 
174 
partially observable Markov decision process, 1120

partially observed data, 262 

partially observed Markov model, 
935 

partially pooled model, 588 

particle BP, 394, 533 

particle filter, 515 

particle filtering, 329, 383, 513, 
1020 

particle impoverishment, 517 

particle smoothing, 533 

partition function, 30, 142, 813 

partition of the integers, 1014 

parts, 916 

patchGAN, 874 

path consistency learning, 1166 

path degeneracy, 520 

path diagram, 189 

path sampling, 443 

pathwise derivative, 253 

patience, 431 

pattern completion, 147 

PBIL, 287 

PBP, 616, 635 

PCFG, 165 

PCL, 1166 

PDAG, 174 

peaks function, 510 

peeling algorithm, 401 

PEGASUS, 1157 

per-decision importance sampling, 
1162 

per-sample ELBO, 439 

per-step importance ratio, 1162 

per-step regret, 1118 

perceptual aliasing, 883 

perceptual distance metrics, 744 

perfect elimination ordering, 405 

perfect information, 1108 

perfect intervention, 190 

perfect map, 170 

period, 54 

periodic kernel, 647, 661 

permuted MNIST, 721 

perplexity, 213, 742 

persistent contrastive divergence, 
819 

persistent variational inference, 
158 

personalized recommendations, 72 

perturb-and-MAP, 409 

perturbation, 235 

PETS, 1157 

PGD, 726 

PGMs, 119 

phase space, 487 

phase transition, 144 

phi-exponential family, 35 

phone, 956 

phosphorylation state, 739 

Picard-Lindelöf theorem, 809

pictorial structure, 168 

PILCO, 1156 

Pilot Studies, 1091 

pinball loss, 557 

pipe, 125 
Pitman-Koopman-Darmois theorem, 38, 216

pix2pix, 874 

pixelCNN, 787 

pixelCNN++, 788

pixelRNN, 788 

PixelSNAIL, 782 

placebo, 1112 

planar flow, 808 

PlaNet, 1159 

planning, 1123, 1135 

planning horizon, 1154 

plant, 1120 

plates, 138 

Platt scaling, 550 

platykurtic, 9 

PLS, 910 

plug-in estimator, 1161 

plugin approximation, 69 

plutocratic, 12 

PoE, 814 

point estimate, 65, 68 

point process, 1032 

Poisson, 6 

Poisson process, 1032 

Poisson regression, 561 

policy, 1110, 1119 

policy evaluation, 1123, 1125 

policy gradient, 1137 

policy gradient theorem, 1147 

policy improvement, 1125 

policy iteration, 1125 

policy optimization, 1123 

policy search, 1137, 1146 

Polyak-Ruppert averaging, 618 

polymatroid function, 302 

polynomial kernel, 649, 650 

polynomial regression, 570 

polysemy, 920 

polytrees, 389 

POMDP, 1120 

pool-based-sampling, 1127 

pooled, 97 

pooling layer, 596 

population, 284 

population-based incremental 
learning, 287 

position-specific scoring matrix, 
943 

positive definite, 29 

positive definite kernel, 643 

positive phase, 159 

possible worlds, 182, 185 

post-order, 387 

posterior collapse, 426, 771 

posterior distribution, 66 

posterior expected loss, 1099 

posterior inference, 66, 333 

posterior marginal, 131 

posterior mean, 1100 

posterior predictive check, 115 

posterior predictive distribution, 
68 

posterior-predictive p-value, 115 

potential energy, 487 

potential function, 22, 141 

potential outcome, 192 


potential outcomes, 1180 

Potts model, 146 

Potts models, 476 

power EP, 450 

power law, 12 

power posterior, 624 

PPCA, 900 

PPL, 187 

PPO, 1152 

pre-order, 387 

pre-train and fine-tune, 707 

precision, 7, 75 

precision matrix, 16, 34 

precision-weighted mean, 16 

preconditioned SGLD, 504 

predict-update, 335 

prediction, 193 

prediction step, 336 

predictive coding, 1170 

predictive distribution, 333 

predictive entropy search, 279 

predictive model, 545 

predictive sparse decomposition, 933

predictive state representation, 951

predictive uncertainty, 68

prequential analysis, 110

prescribed probabilistic models, 847

pretext tasks, 1056 

prevalence shift, 699 

Price’s theorem, 249, 254 

primitive nodes, 240 

primitive operations, 238 

principle of insufficient reason, 91 

prior distribution, 66 

prior linearization filter, 370 

prior network, 620 

prior predictive distribution, 569 

prior shift, 699 

prioritized experience replay, 1146 

Probabilistic backpropagation, 635

probabilistic backpropagation, 616

probabilistic circuit, 181

probabilistic ensembles with trajectory sampling, 1157

Probabilistic graphical models, 
119 

probabilistic graphical models, 733

probabilistic LSA, 920

probabilistic principal components analysis, 900

probabilistic programming language, 187

probabilistic soft logic, 186

probability integral transform, 46

probability matching, 1116

probability of improvement, 277, 1129

probability simplex, 25 

probit approximation, 581

probit function, 584 

probit link function, 562 

probit regression, 584 

procedural approach, 187 




process noise, 633, 935, 963 
product density function, 1040 
product of experts, 149, 764, 814 
product partition model, 961 
production rules, 165 
profile HMM, 944 
projected gradient descent, 726 
projecting, 376 
projection, 175 
projection pursuit, 931 
Projection-weighted CCA 
(PWCCA), 1048 
prompt, 707
prompt tuning, 707
propensity score, 1187
propensity score matching, 1192
proper scoring rule, 548, 851
Properties, 1070
Prophet, 991
proposal distribution, 329, 458, 460, 470, 511
propose, 470
protein sequence alignment, 942
protein structure prediction, 146
protein-protein interaction networks, 999
prototypical networks, 718
proximal policy optimization, 1152
pseudo counts, 68, 72
pseudo inputs, 667
pseudo likelihood, 160, 160
pseudo random number generator, 456
PSIS, 112
pure exploration, 1118
pushforward, 794
pushing sums inside products, 401
Pyro, 187


Q-learning, 1137, 1143
QKF, 369
QT-Opt, 1123
quadratic approximation, 325
quadratic kernel, 649, 652
quadratic loss, 1100
quadrature, 453
quadrature Kalman filter, 369
quadratures, 368
Qualitative Studies, 1090
quantile loss, 557
quantile regression, 557, 992
quantization, 211
Quasi Monte Carlo, 467
quasi-Newton EM algorithm, 267
queries, 274
query, 599
query by committee (QBC), 1130
query nodes, 131


R-hat, 499
radar, 972
radial basis function, 644
radon, 590
rainbow, 1146, 1153
random accelerations model, 964
random assignment with non-compliance, 1198

random effects, 589 

random finite sets, 975 

random Fourier features, 653 

random measure, 1027 

random prior deep ensemble, 620 

random restart, 282 

random restart hill climbing, 282 

random search, 283 

random walk kernel, 650 

random walk Metropolis, 329, 472 

random walk on the integers, 55 

random walk proposal, 511 

randomized control trials, 1182 

Randomized QMC, 467 

Rao-Blackwellisation, 464 

Rao-Blackwellised particle filtering, 
527 

rare event, 538 

raster scan, 787 

rate, 224 

rate distortion curve, 224, 760 

rational quadratic, 648, 692 

rats, 99 

RBF, 644 

RBM, 149 

RBPF, 527 

Real NVP, 810 

real-time dynamic programming, 
1125 

receding horizon control, 1154

recognition network, 439, 749 

recognition weights, 929 

recombination, 285 

recommender system, 182 

reconstruction error, 224, 704 

record linkage, 183 

Rectified linear unit, 595 

recurrent, 54 

recurrent layer, 600 

recurrent neural network, 606, 785 

recurrent neural networks, 600 

recurrent SSM, 994 

recursive, 188 

recursive least squares, 347, 633, 
964 

redundancy, 218, 227 

reference prior, 96 

refractoriness, 1036 

regime switching Markov model, 
939, 971 

Regression discontinuity designs, 
1221 

regression estimator, 1161 

regression model, 545 

regret, 724, 1108, 1117, 1118 

regular, 37, 54 

regularization methods, 723 

regularized evolution, 285 

rehearsal, 723 

REINFORCE, 252, 430, 1137, 
1148 

Reinforcement learning, 1135 

reject action, 1100 

rejection sampling, 457 

relational probability models, 181 


relational UGMs, 183 

relational uncertainty, 183 

relative entropy, 197 

relative risk, 664 

relevance network, 222 

relevance vector machine, 575 

reliability diagram, 549 

Renewal processes, 1036 

reparameterization gradient, 253 

reparameterization trick, 432, 615, 
753 

reparameterized VI, 432 

reparametrization trick, 861 

repeated trials, 65 

representation, 228 

Representation learning, 741, 
1043 

representation learning, 677 

Representational similarity analysis (RSA), 1046

representer theorem, 660 

Reproducing Kernel Hilbert Space, 
659 

reproducing property, 659 

resample, 519 

resample-move, 533 

residual belief propagation, 392, 
396 

residual block, 806 

residual connections, 596, 806 

residual error, 348 

residual flows, 806 

residues, 146 

ResNet, 603 

resource allocation, 1112 

response surface model, 274 

responsibility, 887 

restless bandit, 1112 

restricted Boltzmann machine, 149 

return, 1121 

reverse process, 833 

reverse-mode automatic differentiation, 239

reversible jump MCMC, 506, 906 

reward, 1105, 1110 

reward function, 1120 

reward model, 1119 

reward-to-go, 1121 

reweighted wake-sleep, 444, 445 

RFF, 653 

rich get richer, 426, 1014 

ridge regression, 564, 567, 658 

Riemann Manifold HMC, 491 

Riemannian manifold, 245 

risk, 1099 

RJMCMC, 506 

RKHS, 659 

RL, 1135 

RLS, 964 

RM-HMC, 491 

RMSprop, 273, 273 

RNA-Seq, 985 

RNADE, 786 

RNN, 606 

RNN-HSMM, 955 

robust IEKF, 361 

robust IEKS, 362 


Draft of “Probabilistic Machine Learning: Advanced Topics”. August 15, 2022 




INDEX 


robust optimization, 729 
robust priors, 89 

robust regression, 664 
robustness, 705 

robustness analysis, 89 

roll call, 917 

roulette wheel selection, 285 
row stochastic matrix, 121, 937 
RPMs, 181 

RStanARM, 589 

RTS Smoother, 352 

RTSS, 352 

Rubin Causal Model, 193 

run length, 960 
Russian-roulette estimator, 807 
RVI, 432 


SAC, 1168 

safe policy iteration, 1151 

SAGA-LD, 505 

SAGAN, 869 

sample diversity, 742 

sample inefficient, 1137 

sample quality, 742 

sample standard deviation, 79 

sampling distribution, 65, 567 

sampling with replacement, 6 

SARSA, 1137, 1143 

satisfying assignment, 406 

SBEED, 1166 

scale-invariant prior, 95 

scaled inverse chi-squared, 11 

scaled likelihood trick, 940 

scaling-binning calibrator, 550 

scatter matrix, 82 

SCFGs, 956 

Schrödinger bridge, 844

Schur complement, 17, 18 

SCM, 188 

score, 820 

score function, 40, 42, 816 

score function estimator, 252, 430, 
877 

score matching, 820 

score-based generative models, 
824, 833 

seasonality, 981 

second order EKF, 361 

second-order delta method, 255 

segment, 335 

segmental HMM, 955 

selection bias, 128, 700 

selection function, 285 

selective prediction, 704, 1100 

self attention, 788 

Self Attention GAN, 869 

self-attention, 600 

self-normalized importance sampling, 461

self-train, 711 

semantic network, 1004 

semantic segmentation, 163, 166 

semi-amortized VI, 440 

semi-Markov model, 954 

semi-Markovian SCM, 188 

semi-parametric model, 656 

semi-parametrically efficient, 1189 
semi-supervised learning, 711, 767

semilocal linear trend, 980 

sensible PCA, 900 

sensitive attribute, 1102 

sensitivity analysis, 89 

sensor fusion, 21 

sequence memoizer, 1024 

sequence-to-sequence, 737, 737 

sequential Bayesian inference, 329, 
632 

sequential Bayesian updating, 335, 
964 

sequential decision problem, 1110 

sequential importance sampling, 
517 

sequential importance sampling 
with resampling, 518 

sequential model-based optimization, 275

sequential Monte Carlo, 329, 513 

sequential VAE, 768 

sFA, 917 

SFE, 252 

SG-HMC, 505 

SGD, 243 

SGLD, 492, 504 

SGLD-Adam, 504 

SGLD-CV, 504 

SGPR, 673 

SGRLD, 504 

shaded nodes, 138 

shared trunk network, 712 

sharp minima, 625 

Sherman-Morrison-Woodbury formula, 17

shift equivariance, 596 

shift invariance, 596 

shift-invariant kernels, 644 

shooting, 1155 

shortcut features, 696 

shortest path problems, 1125 

shrinkage, 76, 99 

sigma point filter, 363 

sigma points, 363, 368 

sigmoid, 577 

sigmoid belief net, 123 

signal to noise ratio, 834 

signed measure, 214 

silent state, 945 

SimCLR, 1059 

simple regret, 1118 

simplex factor analysis, 917 

Simplicity, 1083 

Simulability, 1082 

Simulated annealing, 284 

simulated annealing, 274, 510 

simulation-based inference, 539, 
848 

simultaneous localization and mapping, 530

single site updating, 480 

single world intervention graph, 
193 

singular statistical model, 114 

Singular vector CCA (SVCCA), 
1048 

SIS, 517 


SISR, 518 

site potential, 447 

SIXO, 540 

sketch-rnn, 770 

SKI, 675, 676 

SKIP, 676 

skip connections, 596, 773, 775 

skip-chain CRF, 165 

skip-VAE, 773 

SLAM, 530 

SLDS, 971 

sleep phase, 445 

slice sampling, 483 

sliced Fisher divergence, 822 

Sliced Score Matching, 822 

sliding window detector, 167 

slippage, 532 

slot machines, 1111 

slow weights, 622 

SLR, 369 

SLS, 282 

SMBO, 275 

SMC, 329, 513 

SMC sampler, 513 

SMC samplers, 533 

SMC², 539

SMC-ABC, 539 

SMILES, 770 

smoothed Bellman error embedding, 1166

smoothing, 337 

smoothing distribution, 333 

snapshot ensembles, 618 

SNGP, 617 

Sobol sequence, 467 

social networks, 999 

soft actor-critic, 1168 

soft clustering, 887, 1001 

soft constraint, 814 

soft Q-learning, 1169 

soft-thresholding, 426 

softmax, 32 

softmax function, 578 

Softplus, 595 

SOLA, 637 

SOR, 669 

Soundness, 1083 

source coding, 195, 223 

source coding theorem, 223 

source distribution, 695 

source distributions, 711 

source domain, 874 

space filling, 467 

SPADA, 974 

sparse, 26, 570 

sparse Bayesian learning, 573 

sparse coding, 933 

sparse factor analysis, 900 

sparse GP, 668 

sparse GP regression, 673 

sparse variational GP, 671 

sparsity promoting priors, 612 

spectral density, 648, 684 

spectral estimation, 951 

spectral estimation method, 969 

spectral mixture kernel, 684 

spectral mixture kernels, 648 
speech-to-text, 736
spelling correction, 944
sphere the data, 930
spherical covariance matrix, 14
spherical cubature integration, 369
spherical K-means algorithm, 28
spherical topic model, 28
sphering, 930
spike and slab, 889
spike-and-slab, 570
spin, 143
splines, 798
split conformal prediction, 556
split MNIST, 721
split-Rhat, 499
spurious correlations, 696
SQF-RNN, 992
square root filter, 352
square root information filter, 352
square-integrable functions, 659
squared error, 1100
squared exponential, 644
SSID, 969
SSM, 935
Stability, 1081
stacking, 622
standard Brownian motion, 1034
standard error, 68, 455
standard error of the mean, 80
standard Normal, 7
state estimation, 333
state of nature, 1099
state transition diagram, 48
state-space model, 935
state-space models, 333
state-value function, 1122
stateful, 600
static calibration error, 550
stationary, 47
stationary distribution, 52, 53
stationary kernels, 644
statistical estimand, 1179
statistical linear regression, 369
statistical parity, 1102
Statistics, 65
steepest ascent, 282
steepest descent, 243
stepping out, 484
stepwise EM, 267
stick-breaking construction, 1012
stick-breaking process, 1013
sticking the landing, 433
sticky, 473
stochastic approximation, 266, 274
stochastic approximation EM, 266
stochastic automaton, 48
stochastic averaged gradient acceleration, 505
stochastic bandit, 1111
stochastic block model, 1000
stochastic computation graph, 256
stochastic context free grammars, 956
stochastic differential equations, 833
stochastic EP, 450
stochastic gradient descent, 243

Stochastic Gradient Langevin Descent, 492

stochastic gradient Langevin dynamics, 504

Stochastic Gradient Riemannian 
Langevin Dynamics, 
504 

stochastic hill climbing, 282 

stochastic Lanczos quadrature, 
674 

stochastic local search, 281, 282, 
284 

stochastic matrix, 48 

stochastic meta descent, 169 

Stochastic MuZero, 1155 

stochastic process, 1009 

stochastic relaxation, 274 

stochastic RNN, 993 

stochastic variance reduced gradient, 504

stochastic variational inference, 
430, 674 

stochastic video generation, 997 

stochastic volatility model, 978 

stochastic weight averaging, 618 

stochastically ordered, 1013 

stop gradient, 781 

stop words, 922 

straight-through estimator, 256, 
780 

stratified resampling, 522 

streaks, 334 

stream-based selective sampling, 
1127 

streaming variational Bayes, 636 

strict overlap, 1185 

strictly monotonic scalar function, 
797 

string kernel, 650 

structural causal models, 188, 190, 1175

structural equation model, 176, 
189 

structural support vector machine, 
170 


structural time series, 979 

structural zeros, 153 

structured kernel interpolation, 
675 

structured mean field, 440 

structured prediction, 162 

Structured Prediction Energy Networks, 170

structured prediction model, 1101 

STS, 979 

Student distribution, 8 

Student network, 121 

student network, 122, 172, 401 

Student t distribution, 8 

style transfer, 875 

sub-Gaussian, 9 

subjective probability, 119 

Submodular, 301 

submodular, 1132 

subphones, 956 

Subscale Pixel Network, 788 


subset of regressors, 668 
subspace identification, 969 
subspace neural bandit, 1114 
sufficient, 228 

sufficient statistic, 46, 216 
sufficient statistics, 30, 30 
sum of squares, 83 
sum-product algorithm, 387 
sum-product networks, 181 
SupCon, 1059 
super-Gaussian, 9 
supervised PCA, 909 
surjective, 45 

surrogate assisted EA, 287 
surrogate function, 257, 274 
survival of the fittest, 519 
subset-of-data, 666

SUTVA, 192 

SVG, 997 

SVGP, 671 

SVI, 430 

SVRG-LD, 504 

SWA, 618 

SWAG, 619 

Swendsen Wang, 485 

SWIG, 193 

Swish, 595 

swiss roll, 824 

switching linear dynamical system, 527, 971
Sylvester flow, 808 
symamd, 405 


symmetric, 470 
synchronous updates, 396 
synergy, 218 

syntactic sugar, 138 
synthetic control, 989 
Synthetic controls, 1221 
systems biology, 1006 
systems identification, 967 
systolic array, 391 


T5, 769 

tabu search, 282 

tabular representation, 1120 

tacotron, 787 

TAN, 140 

target aware Bayesian inference, 
462 

target distribution, 457, 460, 513, 
695, 711 

target domain, 874 

target function, 460 

target network, 1145 

target policy, 1161 

targeted attack, 726 

TASA corpus, 920 

task, 721 

task incremental learning, 721 

task interference, 712 

task-aware learning, 721 

Taylor series, 326 

Taylor series expansion, 359 

TD, 1141 

TD error, 1137, 1141 

TD(λ), 1142

TD3, 1153 
telescoping sum, 838

temperature, 493 

temperature scaling, 550 

tempered posterior, 624 

tempering, 637 

template, 182 

templates, 892 

temporal difference, 1137, 1141 

tensor decomposition, 951 

tensor train decomposition, 676 

TENT, 711 

terminal state, 1120 

terminals, 165 

test and roll, 1105 

test statistics, 115 

test time adaptation, 710 

text generation, 737 

text to speech, 787 

text-to-image, 736 

text-to-speech, 876 

the deadly triad, 1166 

thermodynamic integration, 443 

thermodynamic variational objective,

thin shell, 15

thinning theorem, 1035 

Thompson sampling, 279, 1116 

threat model, 728 

tilted distribution, 448 

time reversal, 534 

time reversible, 55 

time series forecasting, 979 

time update step, 348 

time-invariant, 47 

time-series forecasting, 176 

Toeplitz, 676 

top-down inference model, 775 

topic, 1024 

topic model, 919 

topic vector, 919 

topic-RNN, 927 

topological order, 130 

topological ordering, 119 

total correlation, 216, 760 

total derivative, 254 

total regret, 1118 

total variation distance, 61 

tournament selection, 285 

trace plot, 496 

trace rank plot, 497 

traceback, 343, 389 

track, 972 

tracking, 346 

tractable substructure, 440 

trajectory, 1135 

trankplot, 497 

trans-dimensional MCMC, 506 

transductive active learning, 554 

transductive learning, 701 

transfer learning, 707, 1044 

transformer, 606, 607 

transformer VAE, 769 

transformers, 769 

transient, 55 

transition, 1119 

transition function, 47, 935 

transition kernel, 47 
transition matrix, 48, 49
translation-invariant prior, 95

Translucence, 1082 

Transparency, 1079 

transportable, 697 

treatment, 1105 
two-filter smoothing, 339, 354
two-sample testing, 701
type II maximum likelihood, 103


value iteration, 1124